(* Theory Bits_Natural

   Flattened copy of the theory header (the actual header follows below):
   theory Bits_Natural
   imports Refine_Monadic Native_Word_Imperative_HOL Code_Target_Bits_Int Uint32 Uint64 More_Word
*)
theory Bits_Natural
  imports
    Refine_Monadic.Refine_Monadic
    Native_Word.Native_Word_Imperative_HOL
    Native_Word.Code_Target_Bits_Int Native_Word.Uint32 Native_Word.Uint64
     "HOL-Word.More_Word"
begin

(* Instantiation of the bit_comprehension class for nat.
   Each operation is defined by embedding the nat argument(s) into int with `int`,
   applying the corresponding int-level operation, and converting the result
   back with `nat`. *)
instantiation nat :: bit_comprehension
begin

(* Bit j of the natural number i is bit j of its int embedding. *)
definition test_bit_nat :: ‹nat ⇒ nat ⇒ bool› where
  "test_bit i j = test_bit (int i) j"

(* The least significant bit is bit 0 of the int embedding. *)
definition lsb_nat :: ‹nat ⇒ bool› where
  "lsb i = (int i :: int) !! 0"

(* Set bit n of i to b, using bin_sc on the int embedding. *)
definition set_bit_nat :: "nat ⇒ nat ⇒ bool ⇒ nat" where
  "set_bit i n b = nat (bin_sc n b (int i))"

(* Build a number from a bit predicate f. Three cases:
   - f is False from some index on: use the bits of f below the least such index;
   - f is True from some index on: use the sign-truncated value built from those
     bits with a leading True, converted with nat;
   - otherwise: 0. *)
definition set_bits_nat :: "(nat ⇒ bool) ⇒ nat" where
  "set_bits f =
  (if ∃n. ∀n'≥n. ¬ f n' then
     let n = LEAST n. ∀n'≥n. ¬ f n'
     in nat (bl_to_bin (rev (map f [0..<n])))
   else if ∃n. ∀n'≥n. f n' then
     let n = LEAST n. ∀n'≥n. f n'
     in nat (sbintrunc n (bl_to_bin (True # rev (map f [0..<n]))))
   else 0 :: nat)"

(* Left shift by n: multiplication by 2 ^ n on the int embedding.
   The explicit type matches the style of the other definitions in this block. *)
definition shiftl_nat :: "nat ⇒ nat ⇒ nat" where
  "shiftl x n = nat ((int x) * 2 ^ n)"

(* Right shift by n: integer division by 2 ^ n on the int embedding.
   The explicit type matches the style of the other definitions in this block. *)
definition shiftr_nat :: "nat ⇒ nat ⇒ nat" where
  "shiftr x n = nat (int x div 2 ^ n)"

(* Bitwise complement computed on int and converted back with nat.
   NOTE(review): the int complement of a non-negative number is presumably
   negative, in which case `nat` would map the result to 0 -- confirm. *)
definition bitNOT_nat :: "nat ⇒ nat" where
  "bitNOT i = nat (bitNOT (int i))"

(* Bitwise conjunction computed on the int embeddings. *)
definition bitAND_nat :: "nat ⇒ nat ⇒ nat" where
  "bitAND i j = nat (bitAND (int i) (int j))"

(* Bitwise disjunction computed on the int embeddings. *)
definition bitOR_nat :: "nat ⇒ nat ⇒ nat" where
  "bitOR i j = nat (bitOR (int i) (int j))"

(* Bitwise exclusive or computed on the int embeddings. *)
definition bitXOR_nat :: "nat ⇒ nat ⇒ nat" where
  "bitXOR i j = nat (bitXOR (int i) (int j))"

(* The instance proof is closed by the default method `..`. *)
instance ..

end


(* Simp rules for right shift on nat:
   - shifting by 0 is the identity;
   - shifting 0 yields 0;
   - shifting by Suc n is halving followed by shifting by n.
   Proved by unfolding shiftr_nat_def and using int division lemmas. *)
lemma nat_shiftr[simp]:
  "m >> 0 = m"
  ‹((0::nat) >> m) = 0›
  ‹(m >> Suc n) = (m div 2 >> n)› for m :: nat
  by (auto simp: shiftr_nat_def zdiv_int zdiv_zmult2_eq[symmetric])

(* Right shift by n on nat equals division by 2 ^ n.
   NOTE(review): the name reads "shifl" although the statement is about `>>`;
   the name is left unchanged since other theories may refer to it. *)
lemma nat_shifl_div: ‹m >> n = m div (2^n)› for m :: nat
  by (induction n arbitrary: m) (auto simp: div_mult2_eq)

(* Simp rules for left shift on nat:
   - shifting by 0 is the identity;
   - shifting 0 yields 0;
   - shifting by Suc n is doubling followed by shifting by n.
   NOTE(review): the division lemmas passed to auto mirror the proof of
   nat_shiftr and may be unnecessary here -- confirm before removing. *)
lemma nat_shiftl[simp]:
  "m << 0 = m"
  ‹((0::nat) << m) = 0›
  ‹(m << Suc n) = ((m * 2) << n)› for m :: nat
  by (auto simp: shiftl_nat_def zdiv_int zdiv_zmult2_eq[symmetric])

(* Right shift by one position on nat is division by 2. *)
lemma nat_shiftr_div2: ‹m >> 1 = m div 2› for m :: nat
  by auto

(* Left shift by n on nat equals multiplication by 2 ^ n.
   NOTE(review): the name says "shiftr" and "div" although the statement is about
   `<<` and multiplication; the name is left unchanged since other theories may
   refer to it. *)
lemma nat_shiftr_div: ‹m << n = m * (2^n)› for m :: nat
  by (induction n arbitrary: m) (auto simp: div_mult2_eq)

(* Left shift of a nat by exactly one position. *)
definition shiftl1 :: ‹nat ⇒ nat› where
  ‹shiftl1 n = n << 1›

(* Right shift of a nat by exactly one position. *)
definition shiftr1 :: ‹nat ⇒ nat› where
  ‹shiftr1 n = n >> 1›

(* Instantiation of the bit_comprehension class for the type natural.
   Every operation is obtained with lift_definition from the operation of the
   same name on nat. *)
instantiation natural :: bit_comprehension
begin

(* The lifting setup for natural is only needed for the lift_definitions below. *)
context includes natural.lifting begin

(* Bit test, lifted from nat. *)
lift_definition test_bit_natural :: ‹natural ⇒ nat ⇒ bool› is test_bit .

(* Least significant bit, lifted from nat. *)
lift_definition lsb_natural :: ‹natural ⇒ bool› is lsb .

(* Setting a single bit, lifted from nat. *)
lift_definition set_bit_natural :: "natural ⇒ nat ⇒ bool ⇒ natural" is
  set_bit .

(* Construction from a bit predicate, lifted from the nat instance of set_bits. *)
lift_definition set_bits_natural :: ‹(nat ⇒ bool) ⇒ natural›
  is ‹set_bits :: (nat ⇒ bool) ⇒ nat› .

(* Left shift, lifted from nat. *)
lift_definition shiftl_natural :: ‹natural ⇒ nat ⇒ natural›
  is ‹shiftl :: nat ⇒ nat ⇒ nat› .

(* Right shift, lifted from nat. *)
lift_definition shiftr_natural :: ‹natural ⇒ nat ⇒ natural›
  is ‹shiftr :: nat ⇒ nat ⇒ nat› .

(* Bitwise complement, lifted from nat. *)
lift_definition bitNOT_natural :: ‹natural ⇒ natural›
  is ‹bitNOT :: nat ⇒ nat› .

(* Bitwise conjunction, lifted from nat. *)
lift_definition bitAND_natural :: ‹natural ⇒ natural ⇒ natural›
  is ‹bitAND :: nat ⇒ nat ⇒ nat› .

(* Bitwise disjunction, lifted from nat. *)
lift_definition bitOR_natural :: ‹natural ⇒ natural ⇒ natural›
  is ‹bitOR :: nat ⇒ nat ⇒ nat› .

(* Bitwise exclusive or, lifted from nat. *)
lift_definition bitXOR_natural :: ‹natural ⇒ natural ⇒ natural›
  is ‹bitXOR :: nat ⇒ nat ⇒ nat› .

end

(* The instance proof is closed by the default method `..`. *)
instance ..
end

(* Code equations for the shifts on natural: both lemmas are declared [code]
   and relate a shift on natural to the same shift on the integer obtained via
   integer_of_natural. *)
context includes natural.lifting begin
(* integer_of_natural commutes with right shift. *)
lemma [code]:
  "integer_of_natural (m >> n) = (integer_of_natural m) >> n"
  apply transfer
  by (smt integer_of_natural.rep_eq msb_int_def msb_shiftr nat_eq_iff2 negative_zle
      shiftr_int_code shiftr_int_def shiftr_nat_def shiftr_natural.rep_eq
      type_definition.Rep_inject type_definition_integer)

(* integer_of_natural commutes with left shift. *)
lemma [code]:
  "integer_of_natural (m << n) = (integer_of_natural m) << n"
  apply transfer
  by (smt integer_of_natural.rep_eq msb_int_def msb_shiftl nat_eq_iff2 negative_zle
      shiftl_int_code shiftl_int_def shiftl_nat_def shiftl_natural.rep_eq
      type_definition.Rep_inject type_definition_integer)

end


(* XOR with 1 on nat: adds 1 to an even number and subtracts 1 from an odd one.
   The goal is moved to int (int_int_eq) and split into last bit and remaining
   bits with bin_rl_eqI; the remaining arithmetic goals go to presburger. *)
lemma bitXOR_1_if_mod_2: ‹bitXOR L 1 = (if L mod 2 = 0 then L + 1 else L - 1)› for L :: nat
  apply transfer
  apply (subst int_int_eq[symmetric])
  apply (rule bin_rl_eqI)
   apply (auto simp: bitXOR_nat_def)
  unfolding bin_rest_def bin_last_def bitXOR_nat_def
       apply presburger+
  done

(* AND with 1 on nat is the remainder modulo 2.
   The equation is moved to int (int_int_eq) before unfolding bitAND_nat_def. *)
lemma bitAND_1_mod_2: ‹bitAND L 1 = L mod 2› for L :: nat
  apply transfer
  apply (subst int_int_eq[symmetric])
  apply (subst bitAND_nat_def)
  by (auto simp: zmod_int bin_rest_def bin_last_def bitval_bin_last[symmetric])

(* Simp rule: left shift by 0 is the identity on uint32. *)
lemma shiftl_0_uint32[simp]: ‹n << 0 = n› for n :: uint32
  by transfer auto

(* On uint32, a left shift by Suc m is a left shift by m followed by a left
   shift by 1. The proof applies transfer twice before calling auto. *)
lemma shiftl_Suc_uint32: ‹n << Suc m = (n << m) << 1› for n :: uint32
  apply transfer
  apply transfer
  by auto


(* Setting bit 0 of x to b keeps bin_rest of the int embedding and appends b
   with BIT; the result is converted back with nat. *)
lemma nat_set_bit_0: ‹set_bit x 0 b = nat ((bin_rest (int x)) BIT b)› for x :: nat
  by (auto simp: set_bit_nat_def)

(* Bit 0 of a nat is set exactly when the number is odd (n mod 2 = 1). *)
lemma nat_test_bit0_iff: ‹n !! 0 ⟷ n mod 2 = 1› for n :: nat
proof -
  (* Rewriting the literal 2 as `int 2` lets zmod_int[symmetric] apply below. *)
  have 2: ‹2 = int 2›
    by auto
  (* Local simp rule relating the parity test on int to the one on nat. *)
  have [simp]: ‹int n mod 2 = 1 ⟷ n mod 2 = Suc 0›
    unfolding 2 zmod_int[symmetric]
    by auto

  show ?thesis
    unfolding test_bit_nat_def
    by (auto simp: bin_last_def zmod_int)
qed
(* For a positive index m, bit m of 2 * n is bit m - 1 of n. *)
lemma test_bit_2: ‹m > 0 ⟹ (2*n) !! m ⟷ n !! (m - 1)› for n :: nat
  by (cases m)
    (auto simp: test_bit_nat_def bin_rest_def)

(* For a positive index m, bit m of 2 * n + 1 equals bit m of 2 * n. *)
lemma test_bit_Suc_2: ‹m > 0 ⟹ Suc (2 * n) !! m ⟷ (2 * n) !! m› for n :: nat
  by (cases m)
    (auto simp: test_bit_nat_def bin_rest_def)

(* For a positive index m, bit m - 1 of bin_rest (int w), converted back to nat,
   equals bit m of w. *)
lemma bin_rest_prev_eq:
  assumes [simp]: ‹m > 0›
  shows  ‹nat ((bin_rest (int w))) !! (m - Suc (0::nat)) = w !! m›
proof -
  (* Write w as either 2 * m' or 2 * m' + 1, where m' = w div 2. *)
  define m' where ‹m' = w div 2›
  have w: ‹w = 2 * m' ∨ w = Suc (2 * m')›
    unfolding m'_def
    by auto
  (* The int-level bit test on m' agrees with the nat-level one by definition. *)
  moreover have ‹bin_nth (int m') (m - Suc 0) = m' !! (m - Suc 0)›
    unfolding test_bit_nat_def test_bit_int_def ..
  ultimately show ?thesis
    by (auto simp: bin_rest_def test_bit_2 test_bit_Suc_2)
qed

(* bin_sc preserves non-negativity: if w is non-negative, so is bin_sc n b w.
   Proved by induction on the bit index n, generalising over w. *)
lemma bin_sc_ge0: ‹w >= 0 ==> (0::int) ≤ bin_sc n b w›
  by (induction n arbitrary: w) auto

(* Two natural numbers are equal if their bit lists, each taken at a width equal
   to its own `size`, are equal. The proof uses Nat.size_nat_def and the length
   fact size_bin_to_bl. *)
lemma bin_to_bl_eq_nat:
  ‹bin_to_bl (size a) (int a) = bin_to_bl (size b) (int b) ==> a=b›
  by (metis Nat.size_nat_def size_bin_to_bl)

(* For an index n below the width m, bit n of w equals the n-th element of the
   reversed m-element bit list of int w.
   Proved by induction on n, generalising over m and w; each case distinguishes
   on the shape of m. *)
lemma nat_bin_nth_bl: "n < m ⟹ w !! n = nth (rev (bin_to_bl m (int w))) n" for w :: nat
  apply (induct n arbitrary: m w)
  (* Base case: index 0. *)
  subgoal for m w
    apply clarsimp
    apply (case_tac m, clarsimp)
    using bin_nth_bl bin_to_bl_def test_bit_int_def test_bit_nat_def apply presburger
    done
  (* Step case: index Suc n. *)
  subgoal for n m w
    apply (clarsimp simp: bin_to_bl_def)
    apply (case_tac m, clarsimp)
    apply (clarsimp simp: bin_to_bl_def)
    apply (subst bin_to_bl_aux_alt)
    apply (simp add: bin_nth_bl test_bit_nat_def)
    done
  done

(* For a non-negative int na, bit n is False whenever nat na ≤ n.
   Proved by induction on n, generalising over na. *)
lemma bin_nth_ge_size: ‹nat na ≤ n ⟹ 0 ≤ na ⟹ bin_nth na n = False›
proof (induction ‹n› arbitrary: na)
  case 0
  then show ?case by auto
next
  (* IH is the induction hypothesis; H collects the two premises for Suc n. *)
  case (Suc n na) note IH = this(1) and H = this(2-)
  have ‹na = 1 ∨ 0 ≤ na div 2›
    using H by auto
  moreover have
    ‹na = 0 ∨ na = 1 ∨ nat (na div 2) ≤ n›
    using H by auto
  (* The induction hypothesis is applied to bin_rest na. *)
  ultimately show ?case
    using IH[rule_format,  of ‹bin_rest na›] H
    by (auto simp: bin_rest_def)
qed

(* No bit of w at an index strictly greater than `size w` is set.
   Follows from bin_nth_ge_size after unfolding test_bit_nat_def. *)
lemma test_bit_nat_outside: "n > size w ⟹ ¬w !! n" for w :: nat
  unfolding test_bit_nat_def
  by (auto simp: bin_nth_ge_size)

(* Characterisation of the bit test on nat through the bit list of width
   `size a`: bit n of a is set iff n < size a and the reversed bit list holds
   True at position n. *)
lemma nat_bin_nth_bl':
  ‹a !! n ⟷ (n < size a ∧ (rev (bin_to_bl (size a) (int a)) ! n))›
  by (metis (full_types) Nat.size_nat_def bin_nth_ge_size leI nat_bin_nth_bl nat_int
      of_nat_less_0_iff test_bit_int_def test_bit_nat_def)

(* Reading bit m after setting bit n of w to x: the result is x when m = n and
   the original bit of w otherwise.
   After rewriting the bit tests with nat_bin_nth_bl', auto leaves several
   subgoals, each of which is closed by its own metis call. *)
lemma nat_set_bit_test_bit: ‹set_bit w n x !! m = (if m = n then x else w !! m)› for w n :: nat
  unfolding nat_bin_nth_bl'
  apply auto
        apply (metis bin_nth_bl bin_nth_sc bin_nth_simps(3) bin_to_bl_def int_nat_eq set_bit_nat_def)
       apply (metis bin_nth_ge_size bin_nth_sc bin_sc_ge0 leI of_nat_less_0_iff set_bit_nat_def)
      apply (metis bin_nth_bl bin_nth_ge_size bin_nth_sc bin_sc_ge0 bin_to_bl_def int_nat_eq leI
      of_nat_less_0_iff set_bit_nat_def)
     apply (metis Nat.size_nat_def bin_nth_sc_gen bin_nth_simps(3) bin_to_bl_def int_nat_eq
      nat_bin_nth_bl' set_bit_nat_def test_bit_int_def test_bit_nat_def)
    apply (metis Nat.size_nat_def bin_nth_bl bin_nth_sc_gen bin_to_bl_def int_nat_eq nat_bin_nth_bl
      nat_bin_nth_bl' of_nat_less_0_iff of_nat_less_iff set_bit_nat_def)
   apply (metis (full_types) bin_nth_bl bin_nth_ge_size bin_nth_sc_gen bin_sc_ge0 bin_to_bl_def leI of_nat_less_0_iff set_bit_nat_def)
  by (metis bin_nth_bl bin_nth_ge_size bin_nth_sc_gen bin_sc_ge0 bin_to_bl_def int_nat_eq leI of_nat_less_0_iff set_bit_nat_def)

end