(* Theory IsaSAT_Backtrack_LLVM *)

theory IsaSAT_Backtrack_LLVM
  imports IsaSAT_Backtrack IsaSAT_VMTF_LLVM IsaSAT_Lookup_Conflict_LLVM
    IsaSAT_Rephase_LLVM
begin

(* Alternative characterisation of isa_empty_conflict_and_extract_clause_heur
   as an explicit WHILET loop over the indices of outl, starting at index 1.

   Each iteration
     - removes outl ! i from the lookup conflict D,
     - copies outl ! i into position i of the clause C (C is initialised with
       length outl copies of outl ! 0),
     - swaps positions 1 and i of C when the level of C ! i is strictly
       greater than the level of C ! 1.

   The result is ((True, D), C, n), where n is 0 if outl has length 1 and the
   level of C ! 1 otherwise.  The intermediate ASSERTs state the bounds and
   preconditions of the list accesses and level lookups.

   This is the form that is unfolded when synthesising
   empty_conflict_and_extract_clause_heur_fast_code.  The proof is by
   unfolding the original definition followed by auto. *)
lemma isa_empty_conflict_and_extract_clause_heur_alt_def:
    ‹isa_empty_conflict_and_extract_clause_heur M D outl = do {
     let C = replicate (length outl) (outl!0);
     (D, C, _) ← WHILET
         (λ(D, C, i). i < length_uint32_nat outl)
         (λ(D, C, i). do {
           ASSERT(i < length outl);
           ASSERT(i < length C);
           ASSERT(lookup_conflict_remove1_pre (outl ! i, D));
           let D = lookup_conflict_remove1 (outl ! i) D;
           let C = C[i := outl ! i];
	   ASSERT(get_level_pol_pre (M, C!i));
	   ASSERT(get_level_pol_pre (M, C!1));
	   ASSERT(1 < length C);
           let L1 = C!i;
           let L2 = C!1;
           let C = (if get_level_pol M L1 > get_level_pol M L2 then swap C 1 i else C);
           ASSERT(i+1 ≤ uint32_max);
           RETURN (D, C, i+1)
         })
        (D, C, 1);
     ASSERT(length outl ≠ 1 ⟶ length C > 1);
     ASSERT(length outl ≠ 1 ⟶  get_level_pol_pre (M, C!1));
     RETURN ((True, D), C, if length outl = 1 then 0 else get_level_pol M (C!1))
  }›
  unfolding isa_empty_conflict_and_extract_clause_heur_def (*WB_More_Refinement_List.swap_def
    swap_def[symmetric]*)
  by auto

(* Synthesise LLVM code for isa_empty_conflict_and_extract_clause_heur.

   Precondition: outl is non-empty and its length is at most uint32_max.
   Arguments: the trail (kept), the lookup conflict (destroyed) and the
   learned-literal buffer outl (kept).
   Result: a conflict option, the extracted clause and a 32-bit level. *)
sepref_def empty_conflict_and_extract_clause_heur_fast_code
  is ‹uncurry2 (isa_empty_conflict_and_extract_clause_heur)›
  :: ‹[λ((M, D), outl). outl ≠ [] ∧ length outl ≤ uint32_max]⇩a
      trail_pol_fast_assn⇧k *⇩a lookup_clause_rel_assn⇧d *⇩a out_learned_assn⇧k →
       (conflict_option_rel_assn) ×⇩a clause_ll_assn ×⇩a uint32_nat_assn›
  supply [[goals_limit=1]] image_image[simp]
  supply [simp] = max_snat_def uint32_max_def
  (* Work on the explicit-loop form of the function and expose the concrete
     array/length operations and the components of the result assertion. *)
  unfolding isa_empty_conflict_and_extract_clause_heur_alt_def
    larray_fold_custom_replicate length_uint32_nat_def conflict_option_rel_assn_def
  (* Constants used as list indices, as the swap position, as the loop
     counter (initial value and increment) and in the length test are
     annotated as 64-bit signed naturals ... *)
  apply (rewrite at ‹⌑› in ‹_ !1› snat_const_fold[where 'a=64])+
  apply (rewrite at ‹⌑› in ‹_ !0› snat_const_fold[where 'a=64])
  apply (rewrite at ‹swap _ ⌑ _› snat_const_fold[where 'a=64])
  apply (rewrite at ‹⌑› in ‹(_, _, _ + 1)› snat_const_fold[where 'a=64])
  apply (rewrite at ‹⌑› in ‹(_, _, 1)› snat_const_fold[where 'a=64])
  apply (rewrite at ‹⌑› in ‹If (length _ = ⌑)› snat_const_fold[where 'a=64])
  (* ... and every remaining numeric constant as a 32-bit unsigned natural. *)
  apply (annot_unat_const "TYPE(32)")
  unfolding gen_swap convert_swap
  by sepref


(* emptied_list expressed through take 0.  This equation is used as an
   unfolding rule in the synthesis of empty_cach_code.
   Proved by auto from emptied_list_def. *)
lemma emptied_list_alt_def: ‹emptied_list xs = take 0 xs›
  by (auto simp: emptied_list_def)

(* Synthesise LLVM code for empty_cach_ref_set.
   The cache is consumed (destroyed) and a cache is returned. *)
sepref_def empty_cach_code
  is ‹empty_cach_ref_set›
  :: ‹cach_refinement_l_assn⇧d →⇩a cach_refinement_l_assn›
  supply [[goals_limit=1]]
  unfolding empty_cach_ref_set_def comp_def cach_refinement_l_assn_def emptied_list_alt_def
  (* Numeric constants become 64-bit signed naturals. *)
  apply (annot_snat_const "TYPE(64)")
  (* Rewrite the index of the SEEN_UNKNOWN update, first with value_of_atm_def
     and then with index_of_atm_def, both applied right-to-left. *)
  apply (rewrite at ‹_[⌑ := SEEN_UNKNOWN]› value_of_atm_def[symmetric])
  apply (rewrite at ‹_[⌑ := SEEN_UNKNOWN]› index_of_atm_def[symmetric])
  by sepref



(* empty_cach_code also implements the pure function empty_cach_ref under the
   precondition empty_cach_ref_pre.  Registered as a Sepref refinement rule.

   Proof outline:
     H   - compose the synthesised refinement (empty_cach_code.refine) with
           empty_cach_ref_set_empty_cach_ref via hfref_compI_PRE;
     pre - empty_cach_ref_pre implies the composed precondition;
     im, f - the composed argument/result assertions (composition with Id)
           coincide with the ones in the statement;
     finally weaken the precondition of H with hfref_weaken_pre. *)
theorem empty_cach_code_empty_cach_ref[sepref_fr_rules]:
  ‹(empty_cach_code, RETURN ∘ empty_cach_ref)
    ∈ [empty_cach_ref_pre]⇩a
    cach_refinement_l_assn⇧d → cach_refinement_l_assn›
  (is ‹?c ∈ [?pre]⇩a ?im → ?f›)
proof -
  have H: ‹?c
    ∈[comp_PRE Id
     (λ(cach, supp).
         (∀L∈set supp. L < length cach) ∧
         length supp ≤ Suc (uint32_max div 2) ∧
         (∀L<length cach. cach ! L ≠ SEEN_UNKNOWN ⟶ L ∈ set supp))
     (λx y. True)
     (λx. nofail ((RETURN ∘ empty_cach_ref) x))]⇩a
      hrp_comp (cach_refinement_l_assn⇧d)
                     Id → hr_comp cach_refinement_l_assn Id›
    (is ‹_ ∈ [?pre']⇩a ?im' → ?f'›)
    using hfref_compI_PRE[OF empty_cach_code.refine[unfolded PR_CONST_def convert_fref]
        empty_cach_ref_set_empty_cach_ref[unfolded convert_fref]] by simp
  have pre: ‹?pre' h x› if ‹?pre x› for x h
    using that by (auto simp: comp_PRE_def trail_pol_def
        ann_lits_split_reasons_def empty_cach_ref_pre_def)
  have im: ‹?im' = ?im›
    by simp
  have f: ‹?f' = ?f›
    by auto
  show ?thesis
    (* The precondition goal is deferred so that the refinement goal is
       closed from H first; the precondition goal is then closed with pre. *)
    apply (rule hfref_weaken_pre[OF ])
     defer
    using H unfolding im f apply assumption
    using pre ..
qed

(* Make fm_add_new_fast known to Sepref as an operation. *)
sepref_register fm_add_new_fast

(* Destruction rule: in an isasat_fast state, the length of the clause arena
   plus one is strictly below max_snat 64.  Supplied as a dest/intro rule to
   the syntheses of propagate_bt_wl_D_fast_codeXX and backtrack_wl_D_fast_code.
   Proved by case analysis on the state and unfolding the bounds. *)
lemma isasat_fast_length_leD: ‹isasat_fast S ⟹ Suc (length (get_clauses_wl_heur S)) < max_snat 64›
  by (cases S) (auto simp: isasat_fast_def max_snat_def sint64_max_def)

(* Register update_heuristics and synthesise its implementation.
   Arguments: a 32-bit natural (kept) and the heuristics (destroyed);
   the updated heuristics are returned.  The generated code is marked
   llvm_inline and the refinement theorem is added to sepref_fr_rules. *)
sepref_register update_heuristics
sepref_def update_heuristics_impl
  is [llvm_inline,sepref_fr_rules] ‹uncurry (RETURN oo update_heuristics)›
  :: ‹uint32_nat_assn⇧k *⇩a heuristic_assn⇧d →⇩a heuristic_assn›
  unfolding update_heuristics_def heuristic_assn_def
  by sepref

(* Register cons_trail_Propagated_tr and synthesise LLVM code for
   propagate_unit_bt_wl_D_int.
   Arguments: a literal (kept) and the solver state (destroyed);
   the new solver state is returned. *)
sepref_register cons_trail_Propagated_tr
sepref_def propagate_unit_bt_wl_D_fast_code
  is ‹uncurry propagate_unit_bt_wl_D_int›
  :: ‹unat_lit_assn⇧k *⇩a isasat_bounded_assn⇧d →⇩a isasat_bounded_assn›
  supply [[goals_limit = 1]] vmtf_flush_def[simp] image_image[simp] uminus_𝒜⇩i⇩n_iff[simp]
  unfolding propagate_unit_bt_wl_D_int_def isasat_bounded_assn_def
    PR_CONST_def
  unfolding fold_tuple_optimizations
  (* Numeric constants become 64-bit signed naturals. *)
  apply (annot_snat_const "TYPE(64)")
  by sepref


(* Synthesise LLVM code for propagate_bt_wl_D_heur.
   Precondition: the solver state satisfies isasat_fast.
   Arguments: a literal (kept), the learned clause (kept) and the solver
   state (destroyed); the new solver state is returned. *)
sepref_def propagate_bt_wl_D_fast_codeXX
  is ‹uncurry2 propagate_bt_wl_D_heur›
  :: ‹[λ((L, C), S). isasat_fast S]⇩a
      unat_lit_assn⇧k *⇩a clause_ll_assn⇧k *⇩a isasat_bounded_assn⇧d → isasat_bounded_assn›

  supply [[goals_limit = 1]] append_ll_def[simp] isasat_fast_length_leD[dest]
    propagate_bt_wl_D_fast_code_isasat_fastI2[intro] length_ll_def[simp]
    propagate_bt_wl_D_fast_code_isasat_fastI3[intro]
  unfolding propagate_bt_wl_D_heur_alt_def
    isasat_bounded_assn_def
  (* Fold list manipulations back into the named operations
     (delete_index_and_swap_update, append_update, append_ll). *)
  unfolding delete_index_and_swap_update_def[symmetric] append_update_def[symmetric]
    append_ll_def[symmetric]
    PR_CONST_def save_phase_def
  (* The argument of add_lbd is upcast to 64 bits; the constant added in the
     first component of the pair is a 64-bit unsigned natural; the constant
     in the seventh component of the returned tuple is a 32-bit unsigned
     natural. *)
  apply (rewrite in ‹add_lbd (of_nat ⌑) _› annot_unat_unat_upcast[where 'l=64])
  apply (rewrite in ‹(_ + ⌑, _)› unat_const_fold[where 'a=64])
  apply (rewrite at ‹RETURN (_, _, _, _, _, _, ⌑, _)› unat_const_fold[where 'a=32])
  (* All remaining numeric constants become 64-bit signed naturals. *)
  apply (annot_snat_const "TYPE(64)")
  unfolding fold_tuple_optimizations
  (* Inside the argument of isasat_fast, undo fold_tuple_optimizations again. *)
  apply (rewrite in ‹isasat_fast ⌑› fold_tuple_optimizations[symmetric])+
  by sepref

(* Alternative characterisation of extract_shorter_conflict_list_heur_st on
   the unpacked state tuple.  The steps spelled out in the monadic program:

     1. take the lookup conflict D out of bD (the_lookup_conflict);
     2. let K be lit_of_last_trail_pol M, remove -K from D and store -K at
        position 0 of outl;
     3. isa_vmtf_mark_to_rescore_also_reasons on outl;
     4. isa_minimize_and_extract_highest_lookup_conflict;
     5. reset the cache with empty_cach_ref;
     6. isa_empty_conflict_and_extract_clause_heur, yielding the clause C
        and the number n.

   The returned state keeps only the first element of outl (take 1 outl).
   This is the form that is unfolded when synthesising
   extract_shorter_conflict_list_heur_st_fast.
   Proved by unfolding the definition; auto with the_lookup_conflict_def and
   Let_def, using extensionality. *)
lemma extract_shorter_conflict_list_heur_st_alt_def:
    ‹extract_shorter_conflict_list_heur_st = (λ(M, N, (bD), Q', W', vm, clvls, cach, lbd, outl,
       stats, ccont, vdom). do {
     let D =  the_lookup_conflict bD;
     ASSERT(fst M ≠ []);
     let K = lit_of_last_trail_pol M;
     ASSERT(0 < length outl);
     ASSERT(lookup_conflict_remove1_pre (-K, D));
     let D = lookup_conflict_remove1 (-K) D;
     let outl = outl[0 := -K];
     vm ← isa_vmtf_mark_to_rescore_also_reasons M N outl vm;
     (D, cach, outl) ← isa_minimize_and_extract_highest_lookup_conflict M N D cach lbd outl;
     ASSERT(empty_cach_ref_pre cach);
     let cach = empty_cach_ref cach;
     ASSERT(outl ≠ [] ∧ length outl ≤ uint32_max);
     (D, C, n) ← isa_empty_conflict_and_extract_clause_heur M D outl;
     RETURN ((M, N, D, Q', W', vm, clvls, cach, lbd, take 1 outl, stats, ccont, vdom), n, C)
  })›
  unfolding extract_shorter_conflict_list_heur_st_def
  by (auto simp: the_lookup_conflict_def Let_def intro!: ext)

(* Register the operations used by the conflict extraction. *)
sepref_register isa_minimize_and_extract_highest_lookup_conflict
  empty_conflict_and_extract_clause_heur

(* Synthesise LLVM code for extract_shorter_conflict_list_heur_st.
   Precondition: the clause arena has length at most sint64_max.
   Argument: the solver state (destroyed).
   Result: the new solver state, a 32-bit natural and the extracted clause. *)
sepref_def extract_shorter_conflict_list_heur_st_fast
  is ‹extract_shorter_conflict_list_heur_st›
  :: ‹[λS. length (get_clauses_wl_heur S) ≤ sint64_max]⇩a
        isasat_bounded_assn⇧d → isasat_bounded_assn ×⇩a uint32_nat_assn ×⇩a clause_ll_assn›
  supply [[goals_limit=1]] empty_conflict_and_extract_clause_pre_def[simp]
  unfolding extract_shorter_conflict_list_heur_st_alt_def PR_CONST_def isasat_bounded_assn_def
  unfolding delete_index_and_swap_update_def[symmetric] append_update_def[symmetric]
    fold_tuple_optimizations
  (* Numeric constants become 64-bit signed naturals. *)
  apply (annot_snat_const "TYPE(64)")
  by sepref


(* Make the operations involved in backtracking known to Sepref. *)
sepref_register find_lit_of_max_level_wl
  extract_shorter_conflict_list_heur_st lit_of_hd_trail_st_heur propagate_bt_wl_D_heur
  propagate_unit_bt_wl_D_int
sepref_register backtrack_wl

(* Synthesise LLVM code for lit_of_hd_trail_st_heur.
   No precondition; the solver state is kept and a literal is returned. *)
sepref_def lit_of_hd_trail_st_heur_fast_code
  is ‹lit_of_hd_trail_st_heur›
  :: ‹[λS. True]⇩a isasat_bounded_assn⇧k → unat_lit_assn›
  unfolding lit_of_hd_trail_st_heur_alt_def isasat_bounded_assn_def
  by sepref

(* Register save_phase_st and synthesise LLVM code for the backtrack step
   backtrack_wl_D_nlit_heur.
   Precondition: the solver state satisfies isasat_fast.
   The solver state is destroyed and the new state is returned. *)
sepref_register save_phase_st
sepref_def backtrack_wl_D_fast_code
  is ‹backtrack_wl_D_nlit_heur›
  :: ‹[isasat_fast]⇩a isasat_bounded_assn⇧d → isasat_bounded_assn›
  supply [[goals_limit=1]]
    size_conflict_wl_def[simp] isasat_fast_length_leD[intro] isasat_fast_def[simp]
  unfolding backtrack_wl_D_nlit_heur_def PR_CONST_def
  (* Fold list manipulations and the conflict size back into the named
     operations. *)
  unfolding delete_index_and_swap_update_def[symmetric] append_update_def[symmetric]
    append_ll_def[symmetric]
    size_conflict_wl_def[symmetric]
  (* Numeric constants become 64-bit signed naturals. *)
  apply (annot_snat_const "TYPE(64)")
  by sepref

(* TODO: Move *)
(* Declare add_lbd_def with the llvm_inline attribute. *)
lemmas [llvm_inline] = add_lbd_def

(* Trial export, inside an experiment context, of the LLVM code for the
   definitions synthesised in this theory.  No output file is given. *)
experiment
begin

  export_llvm
    empty_conflict_and_extract_clause_heur_fast_code
    empty_cach_code
    propagate_bt_wl_D_fast_codeXX
    propagate_unit_bt_wl_D_fast_code
    extract_shorter_conflict_list_heur_st_fast
    lit_of_hd_trail_st_heur_fast_code
    backtrack_wl_D_fast_code

end


end