Theory WB_Word_Assn

theory WB_Word_Assn
imports WB_Word WB_More_IICF_SML
theory WB_Word_Assn
imports  Refine_Imperative_HOL.IICF
  WB_Word Bits_Natural
  WB_More_Refinement WB_More_IICF_SML
begin

subsection ‹More Setup for Fixed Size Natural Numbers›

subsubsection ‹Words›

(* Assertion relating a natural number to a machine word through word_nat_rel. *)
abbreviation word_nat_assn :: "nat ⇒ 'a::len0 Word.word ⇒ assn" where
  ‹word_nat_assn ≡ pure word_nat_rel›

(* Word equality implements equality of the related natural numbers.
   Not registered as a [sepref_fr_rules] rule. *)
lemma op_eq_word_nat:
  ‹(uncurry (return oo ((=) :: 'a :: len Word.word ⇒ _)), uncurry (RETURN oo (=))) ∈
    word_nat_assn⇧k *⇩a word_nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: word_nat_rel_def br_def)


(* Assertion relating a natural number to a uint32 through uint32_nat_rel. *)
abbreviation uint32_nat_assn :: "nat ⇒ uint32 ⇒ assn" where
  ‹uint32_nat_assn ≡ pure uint32_nat_rel›

(* Equality of uint32 values implements equality of the related naturals. *)
lemma op_eq_uint32_nat[sepref_fr_rules]:
  ‹(uncurry (return oo ((=) :: uint32 ⇒ _)), uncurry (RETURN oo (=))) ∈
    uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* Identity assertion: a uint32 value is represented by itself. *)
abbreviation uint32_assn :: ‹uint32 ⇒ uint32 ⇒ assn› where
  ‹uint32_assn ≡ id_assn›

(* Equality on uint32 refined by itself; not registered as a [sepref_fr_rules] rule. *)
lemma op_eq_uint32:
  ‹(uncurry (return oo ((=) :: uint32 ⇒ _)), uncurry (RETURN oo (=))) ∈
    uint32_assn⇧k *⇩a uint32_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* Register the constants 0 and 1 at type uint32 with sepref's identification rules. *)
lemmas [id_rules] =
  itypeI[Pure.of 0 "TYPE (uint32)"]
  itypeI[Pure.of 1 "TYPE (uint32)"]

(* The uint32 constants 0 and 1 are related to themselves by Id, so sepref can import them. *)
lemma param_uint32[param, sepref_import_param]:
  "(0, 0::uint32) ∈ Id"
  "(1, 1::uint32) ∈ Id"
  by (rule IdI)+

(* max on uint32 is parametric with respect to uint32_rel. *)
lemma param_max_uint32[param,sepref_import_param]:
  "(max,max)∈uint32_rel → uint32_rel → uint32_rel" by auto

(* max on uint32 values, refined by itself. *)
lemma max_uint32[sepref_fr_rules]:
  ‹(uncurry (return oo max), uncurry (RETURN oo max)) ∈
    uint32_assn⇧k *⇩a uint32_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* Natural (truncated) subtraction implemented by uint32_safe_minus. The statement has no
   precondition, so uint32_safe_minus must represent 0 whenever the first argument is smaller
   than the second. Not registered as a [sepref_fr_rules] rule. *)
lemma uint32_nat_assn_minus:
  ‹(uncurry (return oo uint32_safe_minus), uncurry (RETURN oo (-))) ∈
     uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def nat_of_uint32_le_minus
      br_def uint32_safe_minus_def nat_of_uint32_notle_minus)

(* uint32_nat_rel is both left-unique and right-unique (single-valued). Both facts are
   registered as [safe_constraint_rules] for sepref. *)
lemma [safe_constraint_rules]:
  ‹CONSTRAINT IS_LEFT_UNIQUE uint32_nat_rel›
  ‹CONSTRAINT IS_RIGHT_UNIQUE uint32_nat_rel›
  by (auto simp: IS_LEFT_UNIQUE_def single_valued_def uint32_nat_rel_def br_def)

(* Right shift (>>) of a uint32 by a natural number of positions, refined by itself. *)
lemma shiftr1[sepref_fr_rules]:
   ‹(uncurry (return oo ((>>))), uncurry (RETURN oo (>>))) ∈ uint32_assn⇧k *⇩a nat_assn⇧k →⇩a
      uint32_assn›
  by sepref_to_hoare (sep_auto simp: shiftr1_def uint32_nat_rel_def br_def)

(* shiftl1 on natural numbers, implemented by the same function. *)
lemma shiftl1[sepref_fr_rules]: ‹(return o shiftl1, RETURN o shiftl1) ∈ nat_assn⇧k →⇩a nat_assn›
  by sepref_to_hoare sep_auto

(* Conversion from a plain uint32 to nat. *)
lemma nat_of_uint32_rule[sepref_fr_rules]:
  ‹(return o nat_of_uint32, RETURN o nat_of_uint32) ∈ uint32_assn⇧k →⇩a nat_assn›
  by sepref_to_hoare sep_auto


(* max on uint32 implements max on the related naturals. *)
lemma max_uint32_nat[sepref_fr_rules]:
  ‹(uncurry (return oo max), uncurry (RETURN oo max)) ∈ uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a
     uint32_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_max)

(* Array update with a uint32 index. The index is converted with nat_of_uint32 before
   heap_array_set is called. The array is destroyed (⇧d), and the element assertion A
   must be pure. *)
lemma array_set_hnr_u:
    ‹CONSTRAINT is_pure A ⟹
    (uncurry2 (λxs i. heap_array_set xs (nat_of_uint32 i)), uncurry2 (RETURN ∘∘∘ op_list_set)) ∈
     [pre_list_set]⇩a (array_assn A)⇧d *⇩a uint32_nat_assn⇧k *⇩a A⇧k → array_assn A›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def ex_assn_up_eq2 array_assn_def is_array_def
      hr_comp_def list_rel_pres_length list_rel_update)

(* Array lookup with a uint32 index (converted with nat_of_uint32). The element assertion A
   must be pure. The proof names the underlying relation A' with pure A' = A. *)
lemma array_get_hnr_u:
  assumes ‹CONSTRAINT is_pure A›
  shows ‹(uncurry (λxs i. Array.nth xs (nat_of_uint32 i)),
      uncurry (RETURN ∘∘ op_list_get)) ∈ [pre_list_get]⇩a (array_assn A)⇧k *⇩a uint32_nat_assn⇧k → A›
proof -
  obtain A' where
    A: ‹pure A' = A›
    using assms pure_the_pure by auto
  then have A': ‹the_pure A = A'›
    by auto
  have [simp]: ‹the_pure (λa c. ↑ ((c, a) ∈ A')) = A'›
    unfolding pure_def[symmetric] by auto
  show ?thesis
    by sepref_to_hoare
      (sep_auto simp: uint32_nat_rel_def br_def ex_assn_up_eq2 array_assn_def is_array_def
       hr_comp_def list_rel_pres_length list_rel_update param_nth A' A[symmetric] ent_refl_true
     list_rel_eq_listrel listrel_iff_nth pure_def)
qed

(* Array-list (arl) lookup with a uint32 index (converted with nat_of_uint32). The element
   assertion A must be pure. *)
lemma arl_get_hnr_u:
  assumes ‹CONSTRAINT is_pure A›
  shows ‹(uncurry (λxs i. arl_get xs (nat_of_uint32 i)), uncurry (RETURN ∘∘ op_list_get))
∈ [pre_list_get]⇩a (arl_assn A)⇧k *⇩a uint32_nat_assn⇧k → A›
proof -
  obtain A' where
    A: ‹pure A' = A›
    using assms pure_the_pure by auto
  then have A': ‹the_pure A = A'›
    by auto
  have [simp]: ‹the_pure (λa c. ↑ ((c, a) ∈ A')) = A'›
    unfolding pure_def[symmetric] by auto
  show ?thesis
    by sepref_to_hoare
      (sep_auto simp: uint32_nat_rel_def br_def ex_assn_up_eq2 array_assn_def is_array_def
        hr_comp_def list_rel_pres_length list_rel_update param_nth arl_assn_def
        A' A[symmetric] pure_def)
qed

(* uint32 addition implements natural addition when the sum stays ≤ uint32_max,
   i.e. when no overflow occurs. *)
lemma uint32_nat_assn_plus[sepref_fr_rules]:
  ‹(uncurry (return oo (+)), uncurry (RETURN oo (+))) ∈ [λ(m, n). m + n ≤ uint32_max]⇩a
     uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k → uint32_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def nat_of_uint32_add br_def)


(* The natural constant 1 as a uint32 (not registered as a [sepref_fr_rules] rule). *)
lemma uint32_nat_assn_one:
  ‹(uncurry0 (return 1), uncurry0 (RETURN 1)) ∈ unit_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* The natural constant 0 as a uint32 (not registered as a [sepref_fr_rules] rule). *)
lemma uint32_nat_assn_zero:
  ‹(uncurry0 (return 0), uncurry0 (RETURN 0)) ∈ unit_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* The identity on a uint32 implements nat_of_uint32 at the uint32_nat_assn level.
   Despite "int32" in the name, this lemma is about uint32. *)
lemma nat_of_uint32_int32_assn:
  ‹(return o id, RETURN o nat_of_uint32) ∈ uint32_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)


(* The constant zero_uint32_nat is implemented by the uint32 literal 0. *)
lemma uint32_nat_assn_zero_uint32_nat[sepref_fr_rules]:
  ‹(uncurry0 (return 0), uncurry0 (RETURN zero_uint32_nat)) ∈ unit_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* The natural constant 0 implemented by itself (not registered as a [sepref_fr_rules] rule). *)
lemma nat_assn_zero:
  ‹(uncurry0 (return 0), uncurry0 (RETURN 0)) ∈ unit_assn⇧k →⇩a nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* The constant one_uint32_nat is implemented by the uint32 literal 1. *)
lemma one_uint32_nat[sepref_fr_rules]:
  ‹(uncurry0 (return 1), uncurry0 (RETURN one_uint32_nat)) ∈ unit_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def)

(* Strict comparison on uint32 implements < on the related naturals. *)
lemma uint32_nat_assn_less[sepref_fr_rules]:
  ‹(uncurry (return oo (<)), uncurry (RETURN oo (<))) ∈
    uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def max_def
      nat_of_uint32_less_iff)

(* The constant two_uint32_nat is implemented by two_uint32. *)
lemma uint32_2_hnr[sepref_fr_rules]: ‹(uncurry0 (return two_uint32), uncurry0 (RETURN two_uint32_nat)) ∈ unit_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def two_uint32_nat_def)


text ‹Do NOT declare this theorem as ‹sepref_fr_rules›: doing so would cause unwanted, unexpected conversions.›
(* Mixed strict comparison: uint32 on the left, nat on the right.
   Despite "le" in the name, the relation is <. *)
lemma le_uint32_nat_hnr:
  ‹(uncurry (return oo (λa b. nat_of_uint32 a < b)), uncurry (RETURN oo (<))) ∈
   uint32_nat_assn⇧k *⇩a nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* Mixed strict comparison: nat on the left, uint32 on the right (also not a
   [sepref_fr_rules] rule). Despite "le" in the name, the relation is <. *)
lemma le_nat_uint32_hnr:
  ‹(uncurry (return oo (λa b. a < nat_of_uint32 b)), uncurry (RETURN oo (<))) ∈
   nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* SML_imp printing of fast_minus_nat': the difference of the underlying integers, wrapped in Nat.
   No truncation at 0 is performed, so callers must guarantee the first argument is ≥ the second.
   NOTE(review): the lemma below is about fast_minus_nat, not fast_minus_nat'; confirm how the
   two are related in WB_Word. *)
code_printing constant fast_minus_nat'  (SML_imp) "(Nat(integer'_of'_nat/ (_)/ -/ integer'_of'_nat/ (_)))"

(* fast_minus_nat implements fast_minus on nat when no underflow happens (m ≥ n). *)
lemma fast_minus_nat[sepref_fr_rules]:
  ‹(uncurry (return oo fast_minus_nat), uncurry (RETURN oo fast_minus)) ∈
     [λ(m, n). m ≥ n]⇩a nat_assn⇧k *⇩a nat_assn⇧k → nat_assn›
  by sepref_to_hoare
   (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_le_minus
      nat_of_uint32_notle_minus nat_of_uint32_le_iff)

(* fast_minus specialised to uint32. *)
definition fast_minus_uint32 :: ‹uint32 ⇒ uint32 ⇒ uint32› where
  [simp]: ‹fast_minus_uint32 = fast_minus›

(* uint32 fast subtraction implements natural fast_minus when m ≥ n. *)
lemma fast_minus_uint32[sepref_fr_rules]:
  ‹(uncurry (return oo fast_minus_uint32), uncurry (RETURN oo fast_minus)) ∈
     [λ(m, n). m ≥ n]⇩a uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k → uint32_nat_assn›
  by sepref_to_hoare
   (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_le_minus
      nat_of_uint32_notle_minus nat_of_uint32_le_iff)

(* A uint32 represents the natural 0 exactly when it is the word 0. *)
lemma uint32_nat_assn_0_eq: ‹uint32_nat_assn 0 a = ↑ (a = 0)›
  by (auto simp: uint32_nat_rel_def br_def pure_def nat_of_uint32_0_iff)

(* uint32_nat_assn expressed through nat_assn applied to the converted value. *)
lemma uint32_nat_assn_nat_assn_nat_of_uint32:
   ‹uint32_nat_assn aa a = nat_assn aa (nat_of_uint32 a)›
  by (auto simp: pure_def uint32_nat_rel_def br_def)

(* Wrapping uint32 addition implements sum_mod_uint32_max (see sum_mod_uint32_max_def,
   presumably addition modulo uint32_max + 1 — confirm). No precondition is needed. *)
lemma sum_mod_uint32_max: ‹(uncurry (return oo (+)), uncurry (RETURN oo sum_mod_uint32_max)) ∈
  uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a
  uint32_nat_assn›
  by sepref_to_hoare
     (sep_auto simp: sum_mod_uint32_max_def uint32_nat_rel_def br_def nat_of_uint32_plus)

(* ≤ on uint32 implements ≤ on the related naturals. *)
lemma le_uint32_nat_rel_hnr[sepref_fr_rules]:
  ‹(uncurry (return oo (≤)), uncurry (RETURN oo (≤))) ∈
   uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_le_iff)

(* The constant one_uint32 is implemented by the literal 1. *)
lemma one_uint32_hnr[sepref_fr_rules]:
  ‹(uncurry0 (return 1), uncurry0 (RETURN one_uint32)) ∈ unit_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare (sep_auto simp: one_uint32_def)

(* Addition on plain uint32 values, refined by itself. *)
lemma sum_uint32_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (+)), uncurry (RETURN oo (+))) ∈ uint32_assn⇧k *⇩a uint32_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare sep_auto

(* Suc implemented as uint32 "+ 1". The precondition n < uint32_max rules out overflow. *)
lemma Suc_uint32_nat_assn_hnr:
  ‹(return o (λn. n + 1), RETURN o Suc) ∈ [λn. n < uint32_max]⇩a uint32_nat_assn⇧k → uint32_nat_assn›
  by sepref_to_hoare (sep_auto simp: br_def uint32_nat_rel_def nat_of_uint32_add)

(* Subtraction on plain uint32 values, refined by itself (not a [sepref_fr_rules] rule). *)
lemma minus_uint32_assn:
 ‹(uncurry (return oo (-)), uncurry (RETURN oo (-))) ∈ uint32_assn⇧k *⇩a uint32_assn⇧k →⇩a uint32_assn›
 by sepref_to_hoare sep_auto

(* Bitwise AND on uint32 implements AND on the related naturals. *)
lemma bitAND_uint32_nat_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (AND)), uncurry (RETURN oo (AND))) ∈
    uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_ao)

(* Bitwise AND on plain uint32 values. *)
lemma bitAND_uint32_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (AND)), uncurry (RETURN oo (AND))) ∈
    uint32_assn⇧k *⇩a uint32_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_ao)

(* Bitwise OR on uint32 implements OR on the related naturals. *)
lemma bitOR_uint32_nat_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (OR)), uncurry (RETURN oo (OR))) ∈
    uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_ao)

(* Bitwise OR on plain uint32 values. *)
lemma bitOR_uint32_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (OR)), uncurry (RETURN oo (OR))) ∈
    uint32_assn⇧k *⇩a uint32_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_ao)

(* uint32 multiplication implements natural multiplication when the product fits
   (≤ uint32_max). Not registered as a [sepref_fr_rules] rule. *)
lemma uint32_nat_assn_mult:
  ‹(uncurry (return oo ((*))), uncurry (RETURN oo ((*)))) ∈ [λ(a, b). a * b ≤ uint32_max]⇩a
      uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k → uint32_nat_assn›
  by sepref_to_hoare
     (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_mult_le)


(* uint32 division implements natural division, without precondition.
   Named so that it can be referenced explicitly. *)
lemma uint32_nat_assn_div[sepref_fr_rules]:
  ‹(uncurry (return oo (div)), uncurry (RETURN oo (div))) ∈
     uint32_nat_assn⇧k *⇩a uint32_nat_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare
   (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_div)


subsubsection ‹64-bits›

(* Register the constants 0 and 1 at type uint64 with sepref's identification rules. *)
lemmas [id_rules] =
  itypeI[Pure.of 0 "TYPE (uint64)"]
  itypeI[Pure.of 1 "TYPE (uint64)"]

(* The uint64 constants 0 and 1 are related to themselves by Id, so sepref can import them. *)
lemma param_uint64[param, sepref_import_param]:
  "(0, 0::uint64) ∈ Id"
  "(1, 1::uint64) ∈ Id"
  by (rule IdI)+

(* Assertion relating a natural number to a uint64 through uint64_nat_rel. *)
abbreviation uint64_nat_assn :: "nat ⇒ uint64 ⇒ assn" where
  ‹uint64_nat_assn ≡ pure uint64_nat_rel›


(* Identity assertion: a uint64 value is represented by itself. *)
abbreviation uint64_assn :: ‹uint64 ⇒ uint64 ⇒ assn› where
  ‹uint64_assn ≡ id_assn›

(* Equality on plain uint64 values (not a [sepref_fr_rules] rule). *)
lemma op_eq_uint64:
  ‹(uncurry (return oo ((=) :: uint64 ⇒ _)), uncurry (RETURN oo (=))) ∈
    uint64_assn⇧k *⇩a uint64_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare sep_auto

(* Equality on uint64 implements equality of the related naturals. *)
lemma op_eq_uint64_nat[sepref_fr_rules]:
  ‹(uncurry (return oo ((=) :: uint64 ⇒ _)), uncurry (RETURN oo (=))) ∈
    uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def)

(* The constant zero_uint64_nat is implemented by the uint64 literal 0. *)
lemma uint64_nat_assn_zero_uint64_nat[sepref_fr_rules]:
  ‹(uncurry0 (return 0), uncurry0 (RETURN zero_uint64_nat)) ∈ unit_assn⇧k →⇩a uint64_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def)

(* uint64 addition implements natural addition when the sum stays ≤ uint64_max. *)
lemma uint64_nat_assn_plus[sepref_fr_rules]:
  ‹(uncurry (return oo (+)), uncurry (RETURN oo (+))) ∈ [λ(m, n). m + n ≤ uint64_max]⇩a
     uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k → uint64_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def nat_of_uint64_add br_def)


(* The constant one_uint64_nat is implemented by the uint64 literal 1. *)
lemma one_uint64_nat[sepref_fr_rules]:
  ‹(uncurry0 (return 1), uncurry0 (RETURN one_uint64_nat)) ∈ unit_assn⇧k →⇩a uint64_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint64_nat_rel_def br_def)


(* Strict comparison on uint64 implements < on the related naturals. *)
lemma uint64_nat_assn_less[sepref_fr_rules]:
  ‹(uncurry (return oo (<)), uncurry (RETURN oo (<))) ∈
    uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def max_def
      nat_of_uint64_less_iff)

(* Multiplication on plain uint64 values, refined by itself. *)
lemma mult_uint64[sepref_fr_rules]:
 ‹(uncurry (return oo (*) ), uncurry (RETURN oo (*)))
  ∈  uint64_assn⇧k *⇩a uint64_assn⇧k →⇩a uint64_assn›
  by sepref_to_hoare sep_auto

(* Right shift of a uint64 by a natural number of positions. *)
lemma shiftr_uint64[sepref_fr_rules]:
 ‹(uncurry (return oo (>>) ), uncurry (RETURN oo (>>)))
    ∈ uint64_assn⇧k *⇩a nat_assn⇧k →⇩a uint64_assn›
  by sepref_to_hoare sep_auto

text ‹
  Taken, like the rest of this setup, from theory @{theory Native_Word.Uint64}. Unlike the default
  setup there, we represent \<^typ>‹uint64› by the native SML type Word64 instead of by an
  unbounded integer.
›
(* (Re)declares the SML code module "Uint64". The type uint64 is Word64.word, and every operation
   delegates to the corresponding Word64 function. Shift amounts and bit indices arrive as
   IntInf.int and are converted to Word via Word.fromLargeInt. The SML text inside the cartouche
   is emitted verbatim into generated code. *)
code_printing code_module "Uint64"  (SML) ‹(* Test that words can handle numbers between 0 and 63 *)
val _ = if 6 <= Word.wordSize then () else raise (Fail ("wordSize less than 6"));

structure Uint64 : sig
  eqtype uint64;
  val zero : uint64;
  val one : uint64;
  val fromInt : IntInf.int -> uint64;
  val toInt : uint64 -> IntInf.int;
  val toFixedInt : uint64 -> Int.int;
  val toLarge : uint64 -> LargeWord.word;
  val fromLarge : LargeWord.word -> uint64
  val fromFixedInt : Int.int -> uint64
  val plus : uint64 -> uint64 -> uint64;
  val minus : uint64 -> uint64 -> uint64;
  val times : uint64 -> uint64 -> uint64;
  val divide : uint64 -> uint64 -> uint64;
  val modulus : uint64 -> uint64 -> uint64;
  val negate : uint64 -> uint64;
  val less_eq : uint64 -> uint64 -> bool;
  val less : uint64 -> uint64 -> bool;
  val notb : uint64 -> uint64;
  val andb : uint64 -> uint64 -> uint64;
  val orb : uint64 -> uint64 -> uint64;
  val xorb : uint64 -> uint64 -> uint64;
  val shiftl : uint64 -> IntInf.int -> uint64;
  val shiftr : uint64 -> IntInf.int -> uint64;
  val shiftr_signed : uint64 -> IntInf.int -> uint64;
  val set_bit : uint64 -> IntInf.int -> bool -> uint64;
  val test_bit : uint64 -> IntInf.int -> bool;
end = struct

type uint64 = Word64.word;

val zero = (0wx0 : uint64);

val one = (0wx1 : uint64);

fun fromInt x = Word64.fromLargeInt (IntInf.toLarge x);

fun toInt x = IntInf.fromLarge (Word64.toLargeInt x);

fun toFixedInt x = Word64.toInt x;

fun fromLarge x = Word64.fromLarge x;

fun fromFixedInt x = Word64.fromInt x;

fun toLarge x = Word64.toLarge x;

fun plus x y = Word64.+(x, y);

fun minus x y = Word64.-(x, y);

fun negate x = Word64.~(x);

fun times x y = Word64.*(x, y);

fun divide x y = Word64.div(x, y);

fun modulus x y = Word64.mod(x, y);

fun less_eq x y = Word64.<=(x, y);

fun less x y = Word64.<(x, y);

fun set_bit x n b =
  let val mask = Word64.<< (0wx1, Word.fromLargeInt (IntInf.toLarge n))
  in if b then Word64.orb (x, mask)
     else Word64.andb (x, Word64.notb mask)
  end

fun shiftl x n =
  Word64.<< (x, Word.fromLargeInt (IntInf.toLarge n))

fun shiftr x n =
  Word64.>> (x, Word.fromLargeInt (IntInf.toLarge n))

fun shiftr_signed x n =
  Word64.~>> (x, Word.fromLargeInt (IntInf.toLarge n))

fun test_bit x n =
  Word64.andb (x, Word64.<< (0wx1, Word.fromLargeInt (IntInf.toLarge n))) <> Word64.fromInt 0

val notb = Word64.notb

fun andb x y = Word64.andb(x, y);

fun orb x y = Word64.orb(x, y);

fun xorb x y = Word64.xorb(x, y);

end (*struct Uint64*)
›

(* Bitwise AND on uint64 implements AND on naturals bounded by uint64_max.
   NOTE(review): bitAND_uint64_nat_assn states the same refinement without a precondition. *)
lemma bitAND_uint64_max_hnr[sepref_fr_rules]:
  ‹(uncurry (return oo  (AND)), uncurry (RETURN oo (AND)))
   ∈ [λ(a, b). a ≤ uint64_max ∧ b ≤ uint64_max]⇩a
     uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k → uint64_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_plus
      nat_of_uint64_and)


(* The constant two_uint64_nat is implemented by the uint64 literal 2. *)
lemma two_uint64_nat[sepref_fr_rules]:
  ‹(uncurry0 (return 2), uncurry0 (RETURN two_uint64_nat))
   ∈  unit_assn⇧k →⇩a uint64_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def)

(* Bitwise OR on uint64 implements OR on naturals bounded by uint64_max. *)
lemma bitOR_uint64_max_hnr[sepref_fr_rules]:
  ‹(uncurry (return oo  (OR)), uncurry (RETURN oo (OR)))
   ∈ [λ(a, b). a ≤ uint64_max ∧ b ≤ uint64_max]⇩a
     uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k → uint64_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_plus
      nat_of_uint64_or)

(* uint64 fast_minus implements natural fast_minus when no underflow happens (a ≥ b). *)
lemma fast_minus_uint64_nat[sepref_fr_rules]:
  ‹(uncurry (return oo fast_minus), uncurry (RETURN oo fast_minus))
   ∈ [λ(a, b). a ≥ b]⇩a uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k → uint64_nat_assn›
  by (sepref_to_hoare)
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_notle_minus
      nat_of_uint64_less_iff nat_of_uint64_le_iff)

(* fast_minus on plain uint64 values, guarded by a ≥ b. *)
lemma fast_minus_uint64[sepref_fr_rules]:
  ‹(uncurry (return oo fast_minus), uncurry (RETURN oo fast_minus))
   ∈ [λ(a, b). a ≥ b]⇩a uint64_assn⇧k *⇩a uint64_assn⇧k → uint64_assn›
  by (sepref_to_hoare)
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_notle_minus
      nat_of_uint64_less_iff nat_of_uint64_le_iff)

(* uint64 subtraction implements natural subtraction when a ≥ b (no wrap-around). *)
lemma minus_uint64_nat_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (-)), uncurry (RETURN oo (-))) ∈
    [λ(a, b). a ≥ b]⇩a uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k → uint64_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_ge_minus
   nat_of_uint64_le_iff)

(* ≤ on uint64 implements ≤ on the related naturals. *)
lemma le_uint64_nat_assn_hnr[sepref_fr_rules]:
  ‹(uncurry (return oo (≤)), uncurry (RETURN oo (≤))) ∈ uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k →⇩a bool_assn›
  by sepref_to_hoare
   (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_le_iff)

(* Wrapping uint64 addition implements sum_mod_uint64_max (see sum_mod_uint64_max_def). *)
lemma sum_mod_uint64_max_hnr[sepref_fr_rules]:
  ‹(uncurry (return oo  (+)), uncurry (RETURN oo sum_mod_uint64_max))
   ∈ uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k →⇩a uint64_nat_assn›
  apply sepref_to_hoare
  apply (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_plus
      sum_mod_uint64_max_def)
  done

(* The constant zero_uint64 is implemented by the literal 0. *)
lemma zero_uint64_hnr[sepref_fr_rules]:
  ‹(uncurry0 (return 0), uncurry0 (RETURN zero_uint64)) ∈ unit_assn⇧k →⇩a uint64_assn›
  by sepref_to_hoare (sep_auto simp: zero_uint64_def)


(* The constant zero_uint32 is implemented by the literal 0 (a 32-bit rule in the 64-bit section). *)
lemma zero_uint32_hnr[sepref_fr_rules]:
  ‹(uncurry0 (return 0), uncurry0 (RETURN zero_uint32)) ∈ unit_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare (sep_auto simp: zero_uint32_def)

(* The uint64 literal 0 implemented by itself (not a [sepref_fr_rules] rule; "uin64" sic). *)
lemma zero_uin64_hnr: ‹(uncurry0 (return 0), uncurry0 (RETURN 0)) ∈ unit_assn⇧k →⇩a uint64_assn›
  by sepref_to_hoare sep_auto

(* The constant two_uint64 is implemented by the literal 2 ("uin64" sic). *)
lemma two_uin64_hnr[sepref_fr_rules]:
  ‹(uncurry0 (return 2), uncurry0 (RETURN two_uint64)) ∈ unit_assn⇧k →⇩a uint64_assn›
  by sepref_to_hoare (sep_auto simp: two_uint64_def)

(* The constant two_uint32 is implemented by the literal 2. *)
lemma two_uint32_hnr[sepref_fr_rules]:
  ‹(uncurry0 (return 2), uncurry0 (RETURN two_uint32)) ∈ unit_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare sep_auto

(* Addition on plain uint64 values (unlike sum_uint32_assn, not a [sepref_fr_rules] rule). *)
lemma sum_uint64_assn:
  ‹(uncurry (return oo (+)), uncurry (RETURN oo (+))) ∈ uint64_assn⇧k *⇩a uint64_assn⇧k →⇩a uint64_assn›
  by (sepref_to_hoare) sep_auto

(* Bitwise AND on uint64 implements AND on the related naturals. *)
lemma bitAND_uint64_nat_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (AND)), uncurry (RETURN oo (AND))) ∈
    uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k →⇩a uint64_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_ao)

(* Bitwise AND on plain uint64 values. *)
lemma bitAND_uint64_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (AND)), uncurry (RETURN oo (AND))) ∈
    uint64_assn⇧k *⇩a uint64_assn⇧k →⇩a uint64_assn›
  by sepref_to_hoare
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_ao)

(* Bitwise OR on uint64 implements OR on the related naturals. *)
lemma bitOR_uint64_nat_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (OR)), uncurry (RETURN oo (OR))) ∈
    uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k →⇩a uint64_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_ao)

(* Bitwise OR on plain uint64 values. *)
lemma bitOR_uint64_assn[sepref_fr_rules]:
  ‹(uncurry (return oo (OR)), uncurry (RETURN oo (OR))) ∈
    uint64_assn⇧k *⇩a uint64_assn⇧k →⇩a uint64_assn›
  by sepref_to_hoare
    (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_ao)

(* If the natural product of the two values is at most uint64_max, then converting the
   uint64 product with nat_of_uint64 yields exactly that natural product. *)
lemma nat_of_uint64_mult_le:
   ‹nat_of_uint64 ai * nat_of_uint64 bi ≤ uint64_max ⟹
       nat_of_uint64 (ai * bi) = nat_of_uint64 ai * nat_of_uint64 bi›
  apply transfer
  by (auto simp: unat_word_ariths uint64_max_def)

(* uint64 multiplication implements natural multiplication when the product fits
   (≤ uint64_max). Not registered as a [sepref_fr_rules] rule. *)
lemma uint64_nat_assn_mult:
  ‹(uncurry (return oo ((*))), uncurry (RETURN oo ((*)))) ∈ [λ(a, b). a * b ≤ uint64_max]⇩a
      uint64_nat_assn⇧k *⇩a uint64_nat_assn⇧k → uint64_nat_assn›
  by sepref_to_hoare
     (sep_auto simp: uint64_nat_rel_def br_def nat_of_uint64_mult_le)

(* uint64_max (18446744073709551615 = 2^64 - 1) as a uint64 literal. *)
lemma uint64_max_uint64_nat_assn:
 ‹(uncurry0 (return 18446744073709551615), uncurry0 (RETURN uint64_max)) ∈
  unit_assn⇧k →⇩a uint64_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def uint64_max_def)

(* uint64_max (2^64 - 1) as a plain natural-number literal. *)
lemma uint64_max_nat_assn[sepref_fr_rules]:
 ‹(uncurry0 (return 18446744073709551615), uncurry0 (RETURN uint64_max)) ∈
  unit_assn⇧k →⇩a nat_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def uint64_max_def)


subsubsection ‹Conversions›

paragraph ‹From nat to 64 bits›

(* Conversion nat → uint64 (via uint64_of_nat), valid for n ≤ uint64_max. *)
lemma uint64_of_nat_conv_hnr[sepref_fr_rules]:
  ‹(return o uint64_of_nat, RETURN o uint64_of_nat_conv) ∈
    [λn. n ≤ uint64_max]⇩a nat_assn⇧k → uint64_nat_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def uint64_of_nat_conv_def
      nat_of_uint64_uint64_of_nat_id)


paragraph ‹From nat to 32 bits›

(* Conversion nat → uint32 (via uint32_of_nat), valid for n ≤ uint32_max. *)
lemma nat_of_uint32_spec_hnr[sepref_fr_rules]:
  ‹(return o uint32_of_nat, RETURN o nat_of_uint32_spec) ∈
     [λn. n ≤ uint32_max]⇩a nat_assn⇧k → uint32_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_uint32_of_nat_id)


paragraph ‹From 64 bits to nat›

(* Conversion uint64 → nat for values related by uint64_nat_rel. *)
lemma nat_of_uint64_conv_hnr[sepref_fr_rules]:
  ‹(return o nat_of_uint64, RETURN o nat_of_uint64_conv) ∈ uint64_nat_assn⇧k →⇩a nat_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def)

(* Conversion of a plain uint64 to nat. *)
lemma nat_of_uint64[sepref_fr_rules]:
  ‹(return o nat_of_uint64, RETURN o nat_of_uint64) ∈
    (uint64_assn)⇧k →⇩a nat_assn›
  by sepref_to_hoare (sep_auto simp: uint64_nat_rel_def br_def
     nat_of_uint64_def
    split: option.splits)

paragraph ‹From 32 bits to nat›

(* Conversion uint32 → nat for values related by uint32_nat_rel. *)
lemma nat_of_uint32_conv_hnr[sepref_fr_rules]:
  ‹(return o nat_of_uint32, RETURN o nat_of_uint32_conv) ∈ uint32_nat_assn⇧k →⇩a nat_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def nat_of_uint32_conv_def)

(* convert_to_uint32 implemented by uint32_of_nat, valid for n ≤ uint32_max. *)
lemma convert_to_uint32_hnr[sepref_fr_rules]:
  ‹(return o uint32_of_nat, RETURN o convert_to_uint32)
    ∈ [λn. n ≤ uint32_max]⇩a nat_assn⇧k → uint32_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def uint32_max_def nat_of_uint32_uint32_of_nat_id)


paragraph ‹From 32 to 64 bits›

(* Widening a plain uint32 to uint64. *)
lemma uint64_of_uint32_hnr[sepref_fr_rules]:
  ‹(return o uint64_of_uint32, RETURN o uint64_of_uint32) ∈ uint32_assn⇧k →⇩a uint64_assn›
  by sepref_to_hoare (sep_auto simp: br_def)

(* Widening 32 → 64 bits at the natural-number level (no precondition needed). *)
lemma uint64_of_uint32_conv_hnr[sepref_fr_rules]:
  ‹(return o uint64_of_uint32, RETURN o uint64_of_uint32_conv) ∈
    uint32_nat_assn⇧k →⇩a uint64_nat_assn›
  by sepref_to_hoare (sep_auto simp: br_def uint32_nat_rel_def uint64_nat_rel_def
      nat_of_uint32_code nat_of_uint64_uint64_of_uint32)


paragraph ‹From 64 to 32 bits›

(* Narrowing 64 → 32 bits, valid only when the value is at most uint32_max. *)
lemma uint32_of_uint64_conv_hnr[sepref_fr_rules]:
  ‹(return o uint32_of_uint64, RETURN o uint32_of_uint64_conv) ∈
     [λa. a ≤ uint32_max]⇩a uint64_nat_assn⇧k → uint32_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_of_uint64_def uint32_nat_rel_def br_def nat_of_uint64_le_iff
      nat_of_uint32_uint32_of_nat_id uint64_nat_rel_def)


paragraph ‹From nat to 32 bits (plain \<^typ>‹uint32› result)›


(* uint32_of_nat on a nat ≤ uint32_max, producing a plain uint32. *)
lemma (in -) uint32_of_nat[sepref_fr_rules]:
  ‹(return o uint32_of_nat, RETURN o uint32_of_nat) ∈ [λn. n ≤ uint32_max]⇩a nat_assn⇧k → uint32_assn›
  by sepref_to_hoare sep_auto

paragraph ‹Setup for numerals›

text ‹The refinement framework still defaults to \<^typ>‹nat›, which keeps constants
like \<^term>‹two_uint32_nat› useful, but they can be omitted in some cases. For example, in
\<^term>‹2 + n›, \<^term>‹2 :: nat› will be refined to \<^typ>‹nat› (independently of \<^term>‹n›). However,
if the expression is \<^term>‹n + 2› and \<^term>‹n› is refined to \<^typ>‹uint32›, then everything
works as one might expect.
›

(* Let numerals be identified at types uint32 and uint64. *)
lemmas [id_rules] =
  itypeI[Pure.of numeral "TYPE (num ⇒ uint32)"]
  itypeI[Pure.of numeral "TYPE (num ⇒ uint64)"]

(* A uint32/uint64 constant wrapped in PR_CONST is identified at its own type. *)
lemma id_uint32_const[id_rules]: "(PR_CONST (a::uint32)) ::⇩i TYPE(uint32)" by simp
lemma id_uint64_const[id_rules]: "(PR_CONST (a::uint64)) ::⇩i TYPE(uint64)" by simp

(* uint32 numerals are related to themselves, so sepref can import them. *)
lemma param_uint32_numeral[sepref_import_param]:
  ‹(numeral n, numeral n) ∈ uint32_rel›
  by auto

(* uint64 numerals are related to themselves, so sepref can import them. *)
lemma param_uint64_numeral[sepref_import_param]:
  ‹(numeral n, numeral n) ∈ uint64_rel›
  by auto


(* TODO Move + is there a way to generate these constants on the fly? *)
(* For a fixed numeral n ≤ uint64_max, the locale provides the natural constant
   nat_of_uint64_numeral and its uint64 counterpart. Inside the locale, nat_of_uint64 names
   that uint64 constant and shadows the global conversion function of the same name. *)
locale nat_of_uint64_loc =
  fixes n :: num
  assumes le_uint64_max: ‹numeral n ≤ uint64_max›
begin

definition nat_of_uint64_numeral :: nat where
  [simp]: ‹nat_of_uint64_numeral = (numeral n)›

definition nat_of_uint64 :: uint64 where
 [simp]: ‹nat_of_uint64 = (numeral n)›

(* The natural numeral is implemented by the uint64 numeral. *)
lemma nat_of_uint64_numeral_hnr:
  ‹(uncurry0 (return nat_of_uint64), uncurry0 (PR_CONST (RETURN nat_of_uint64_numeral)))
      ∈ unit_assn⇧k →⇩a uint64_nat_assn›
  using le_uint64_max
  by (sepref_to_hoare; sep_auto simp: uint64_nat_rel_def br_def uint64_max_def)
sepref_register nat_of_uint64_numeral
end

(* Global version of nat_of_uint64_numeral_hnr, with the locale assumption turned into a
   CONSTRAINT premise.
   TODO a solution based on that, potentially with a simproc, would work wonders! *)
lemma (in -) [sepref_fr_rules]:
  ‹CONSTRAINT (λn. numeral n ≤ uint64_max) n ⟹
(uncurry0 (return (nat_of_uint64_loc.nat_of_uint64 n)),
     uncurry0 (RETURN (PR_CONST (nat_of_uint64_loc.nat_of_uint64_numeral n))))
   ∈  unit_assn⇧k →⇩a uint64_nat_assn›
  using nat_of_uint64_loc.nat_of_uint64_numeral_hnr[of n]
  by (auto simp: nat_of_uint64_loc_def)

(* uint32_max (4294967295 = 2^32 - 1) as a uint32 literal. *)
lemma uint32_max_uint32_nat_assn:
  ‹(uncurry0 (return 4294967295), uncurry0 (RETURN uint32_max)) ∈ unit_assn⇧k →⇩a uint32_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_max_def uint32_nat_rel_def br_def)

(* Subtraction on plain uint64 values, refined by itself (not a [sepref_fr_rules] rule). *)
lemma minus_uint64_assn:
 ‹(uncurry (return oo (-)), uncurry (RETURN oo (-))) ∈ uint64_assn⇧k *⇩a uint64_assn⇧k →⇩a uint64_assn›
 by sepref_to_hoare sep_auto

(* For a uint32 already related to n, the identity implements uint32_of_nat n. *)
lemma uint32_of_nat_uint32_nat_assn[sepref_fr_rules]:
  ‹(return o id, RETURN o uint32_of_nat) ∈  uint32_nat_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def)

(* uint32_of_nat on a natural stored as uint64, implemented by narrowing with
   uint32_of_uint64; valid for n ≤ uint32_max. *)
lemma uint32_of_nat2[sepref_fr_rules]:
  ‹(return o uint32_of_uint64, RETURN o uint32_of_nat) ∈
    [λn. n ≤ uint32_max]⇩a uint64_nat_assn⇧k → uint32_assn›
  by sepref_to_hoare
    (sep_auto simp: uint32_nat_rel_def br_def uint64_nat_rel_def uint32_of_uint64_def)

(* The constant three_uint32 is implemented by the literal 3. *)
lemma three_uint32_hnr:
  ‹(uncurry0 (return 3), uncurry0 (RETURN (three_uint32 :: uint32)) ) ∈ unit_assn⇧k →⇩a uint32_assn›
  by sepref_to_hoare (sep_auto simp: uint32_nat_rel_def br_def three_uint32_def)


(* The identity on a plain uint64 implements nat_of_uint64_id_conv at the
   uint64_nat_assn level. *)
lemma nat_of_uint64_id_conv_hnr[sepref_fr_rules]:
  ‹(return o id, RETURN o nat_of_uint64_id_conv) ∈ uint64_assn⇧k →⇩a uint64_nat_assn›
  by sepref_to_hoare
    (sep_auto simp: nat_of_uint64_id_conv_def uint64_nat_rel_def br_def)


end