Theory Watched_Literals_Watch_List_Enumeration

theory Watched_Literals_Watch_List_Enumeration
imports Watched_Literals_List_Enumeration Watched_Literals_Watch_List
theory Watched_Literals_Watch_List_Enumeration
  imports Watched_Literals_List_Enumeration Watched_Literals.Watched_Literals_Watch_List
begin

(* find_decomp_target_wl i S nondeterministically chooses a pair (T, K) such that
   - T agrees with S on every component except the trail (equality_except_trail_wl S T);
   - get_all_ann_decomposition (get_trail_wl S) contains a pair (Decided K # M1, M2)
     and the trail of T is exactly M1, the part that follows the decision Decided K;
   - K has level i in the trail of S.
   This is a pure specification (SPEC); concrete choices are left to refinements. *)
definition find_decomp_target_wl :: ‹nat ⇒ 'v twl_st_wl ⇒ ('v twl_st_wl × 'v literal) nres› where
  ‹find_decomp_target_wl =  (λi S.
    SPEC(λ(T, K). ∃M2 M1. equality_except_trail_wl S T ∧ get_trail_wl T = M1 ∧
       (Decided K # M1, M2) ∈ set (get_all_ann_decomposition (get_trail_wl S)) ∧
          get_level (get_trail_wl S) K = i))›

(* Unit backjump step on a watch-list state (M, N, D, NE, UE, Q, W).
   It prepends the propagation of -K, with reason index 0, to the trail M.
   It clears the conflict D, adds the unit clause {#-K#} to NE (the
   get_unit_init_clss_wl component, see twl_st_wl_splitD), and replaces the
   literals to update Q by {#K#}.
   The clauses N, the learned units UE and the watch lists W are unchanged.
   NOTE(review): 0 is presumably a placeholder reason, since the new clause is
   stored as a unit in NE and not in N; confirm against propagate_unit_and_add_l. *)
fun propagate_unit_and_add_wl :: ‹'v literal ⇒ 'v twl_st_wl ⇒ 'v twl_st_wl› where
  ‹propagate_unit_and_add_wl K (M, N, D, NE, UE, Q, W) =
      (Propagated (-K) 0 # M, N, None, add_mset {#-K#} NE, UE, {#K#}, W)›

(* Watch-list version of negate_mode_bj_unit_l (see the refinement lemma
   negate_mode_bj_unit_wl_negate_mode_bj_unit_l).
   It backjumps with find_decomp_target_wl 1, so K is a decision of level 1.
   It then propagates -K as a unit with propagate_unit_and_add_wl.
   The ASSERT requires that K occurs among the literals of the clauses of N
   together with the unit clauses. The refinement relation
   find_decomp_target_wl_ref provides exactly this fact. *)
definition negate_mode_bj_unit_wl   :: ‹'v twl_st_wl ⇒ 'v twl_st_wl nres› where
‹negate_mode_bj_unit_wl = (λS. do {
    (S, K) ← find_decomp_target_wl 1 S;
    ASSERT(K ∈# all_lits_of_mm (clause `# twl_clause_of `# ran_mf (get_clauses_wl S) +
           get_unit_clauses_wl S));
    RETURN (propagate_unit_and_add_wl K S)
  })›

(* Refinement relation for the result of find_decomp_target_wl, relative to the
   original watch-list state S. A watch-list result (T, K) is related to a
   list-based result (T', K') when all of the following hold:
   - T refines T' via state_wl_l None and has correct watch lists;
   - K = K';
   - K occurs in the literals of the clauses of T together with the unit clauses,
     and also together with the initial unit clauses;
   - T agrees with S except for the trail;
   - the atoms of DECO_clause of the trail of S occur in the clauses of T plus the
     initial unit clauses, and that decision clause has no duplicates.
   NOTE(review): correct_watching T appears twice, once inside the first conjunct
   and once as the last conjunct. The duplicate is redundant but harmless. *)
abbreviation find_decomp_target_wl_ref where
  ‹find_decomp_target_wl_ref S ≡
     {((T, K), (T', K')). (T, T') ∈ {(T, T'). (T, T') ∈ state_wl_l None ∧ correct_watching T} ∧
        (K , K') ∈ Id ∧
        K ∈# all_lits_of_mm (clause `# twl_clause_of `# ran_mf (get_clauses_wl T) +
           get_unit_clauses_wl T) ∧
        K ∈# all_lits_of_mm (clause `# twl_clause_of `# ran_mf (get_clauses_wl T) +
           get_unit_init_clss_wl T) ∧ equality_except_trail_wl S T ∧
        atms_of (DECO_clause (get_trail_wl S)) ⊆ atms_of_mm (clause `# twl_clause_of `# ran_mf (get_clauses_wl T) +
          get_unit_init_clss_wl T) ∧ distinct_mset (DECO_clause (get_trail_wl S)) ∧
        correct_watching T}›

(* The decision clause of the empty trail is the empty multiset. *)
lemma DECO_clause_nil[simp]: ‹DECO_clause [] = {#}›
  by (auto simp: DECO_clause_def)

(* Every literal of DECO_clause M is the negation of a literal of the trail M. *)
lemma in_DECO_clauseD: ‹x ∈# DECO_clause M ⟹ -x ∈ lits_of_l M›
  by (auto simp: DECO_clause_def lits_of_def)

(* Atom-level counterpart of in_DECO_clauseD: every atom of DECO_clause M is the
   atom of some literal of the trail M. *)
lemma in_atms_of_DECO_clauseD: ‹x ∈ atms_of (DECO_clause M) ⟹ x ∈ atm_of ` (lits_of_l M)›
  by (auto simp: DECO_clause_def lits_of_def atms_of_def)

(* A trail without duplicates yields a duplicate-free decision clause. *)
lemma no_dup_distinct_mset_DECO_clause:
  assumes ‹no_dup M›
  shows ‹distinct_mset (DECO_clause M)›
proof -
  (* The literals of the decided entries of a duplicate-free trail are distinct. *)
  have ‹distinct (map lit_of (filter is_decided M))›
    using no_dup_map_lit_of[OF assms] distinct_map_filter by blast
  (* Reduce distinctness of the multiset to distinctness of that list. The image
     can be pulled out because the mapped function is injective (inj_on_def). *)
  moreover have ‹?thesis ⟷ distinct (map lit_of (filter is_decided M))›
    unfolding DECO_clause_def image_mset.compositionality[symmetric]
    apply (subst distinct_image_mset_inj)
    subgoal by (auto simp: inj_on_def)
    subgoal by (auto simp flip: mset_filter
      distinct_mset_mset_distinct simp del: mset_filter)
    done
  ultimately show ?thesis by blast
qed

(* find_decomp_target_wl refines the list-based find_decomp_target with respect to
   find_decomp_target_wl_ref S. This holds when S refines S' with correct watch
   lists and S' corresponds to a state satisfying twl_struct_invs.
   The structural invariants give three facts:
   - every trail literal occurs in the clauses plus the initial unit clauses;
   - the atoms of the decision clause occur there too;
   - the decision clause has no duplicates. *)
lemma find_decomp_target_wl_find_decomp_target_l:
  assumes
    SS': ‹(S, S') ∈ {(S, S''). (S, S'') ∈ state_wl_l None ∧ correct_watching S}› and
    inv: ‹∃S'' b. (S', S'') ∈ twl_st_l b ∧ twl_struct_invs S''› and
    [simp]: ‹a = a'›
  shows ‹find_decomp_target_wl a S ≤
     ⇓ (find_decomp_target_wl_ref S) (find_decomp_target a' S')›
    (is ‹_ ≤ ⇓ ?negate _›)
proof -
  (* ?y0 S S' replaces the first component (the trail) of S' by the trail of S.
     It builds the list-based witness for a chosen watch-list result. *)
  let ?y0 = ‹λS S'. (λ(M, Oth). (get_trail_wl S, Oth)) S'›
  have K: ‹⋀K. K ∈ lits_of_l (get_trail_wl S) ⟹
     K ∈# all_lits_of_mm (clause `# twl_clause_of `# ran_mf (get_clauses_wl S) +
          get_unit_init_clss_wl S)› (is ‹⋀K. ?HK K ⟹ ?K K›) and
    DECO:
      ‹atms_of (DECO_clause (get_trail_wl S)) ⊆ atms_of_mm (clause `# twl_clause_of `# ran_mf (get_clauses_wl S) +
          get_unit_init_clss_wl S)› (is ?DECO) and
    distinct_DECO:
      ‹distinct_mset (DECO_clause (get_trail_wl S))› (is ?dist_DECO)
  proof -
    obtain b S'' where
      S'_S'': ‹(S', S'') ∈ twl_st_l b› and
      struct: ‹twl_struct_invs S''›
      using inv unfolding negate_mode_bj_unit_l_inv_def by blast
    then have all_struct: ‹cdcl⇩W_restart_mset.cdcl⇩W_all_struct_inv (state⇩W_of S'')›
      using struct unfolding twl_struct_invs_def by fast
    (* Extract the two parts of the structural invariant used below. *)
    then have no_alien: ‹cdcl⇩W_restart_mset.no_strange_atm (state⇩W_of S'')› and
      M_lev: ‹cdcl⇩W_restart_mset.cdcl⇩W_M_level_inv (state⇩W_of S'')›
      unfolding cdcl⇩W_restart_mset.cdcl⇩W_all_struct_inv_def by fast+
    (* Express the atoms of the initial clauses of S'' through the clauses and the
       initial unit clauses of the watch-list state S. *)
    moreover have ‹atms_of_mm (get_all_init_clss S'') =
          atms_of_mm (mset `# (ran_mf (get_clauses_wl S)) + get_unit_init_clss_wl S)›
      apply (subst all_clss_lf_ran_m[symmetric])
      using no_alien
      using S'_S'' SS' unfolding cdcl⇩W_restart_mset.no_strange_atm_def
      by (cases S; cases S'; cases b)
        (auto simp: mset_take_mset_drop_mset' cdcl⇩W_restart_mset_state
        in_all_lits_of_mm_ain_atms_of_iff twl_st_l_def state_wl_l_def)
    ultimately show ‹⋀K. ?HK K ⟹ ?K K›
      using S'_S'' SS' unfolding cdcl⇩W_restart_mset.no_strange_atm_def
      by (auto 5 5 simp: twl_st_l twl_st mset_take_mset_drop_mset'
        in_all_lits_of_mm_ain_atms_of_iff get_unit_clauses_wl_alt_def)
    (* Atoms of the decision clause are atoms of trail literals (in_atms_of_DECO_clauseD). *)
    then show ?DECO
      using S'_S'' SS' unfolding cdcl⇩W_restart_mset.no_strange_atm_def
      by (auto simp: twl_st_l twl_st mset_take_mset_drop_mset'
        in_all_lits_of_mm_ain_atms_of_iff get_unit_clauses_wl_alt_def
        dest: in_atms_of_DECO_clauseD)

    (* no_dup of the trail is obtained from cdcl⇩W_M_level_inv. *)
    show ?dist_DECO
      by (rule no_dup_distinct_mset_DECO_clause)
       (use M_lev S'_S'' SS' in ‹auto simp: cdcl⇩W_restart_mset.cdcl⇩W_M_level_inv_def twl_st›)
  qed

  (* Every result of the watch-list specification is matched by the list-based
     result obtained by putting the new trail into S' (via ?y0). *)
  show ?thesis
    using SS'
    unfolding find_decomp_target_wl_def find_decomp_target_def apply -
    apply (rule RES_refine)
    apply (rule_tac x=‹(?y0 (fst s) S', snd s)› in bexI)
    subgoal
      using K DECO distinct_DECO
      by (cases S; cases S')
       (force simp: state_wl_l_def correct_watching.simps clause_to_update_def
          mset_take_mset_drop_mset' all_lits_of_mm_union
          dest!: get_all_ann_decomposition_exists_prepend)+
    subgoal
      by (cases S; cases S')
        (auto simp: state_wl_l_def correct_watching.simps clause_to_update_def)
    done
qed

(* negate_mode_bj_unit_wl refines negate_mode_bj_unit_l for states related by
   state_wl_l None with correct watch lists. Correctness of the watch lists is
   preserved by the step.
   NOTE(review): the unnamed count_decided assumption is not referenced in the
   proof text below; confirm whether it is still needed by callers. *)
lemma negate_mode_bj_unit_wl_negate_mode_bj_unit_l:
  fixes S :: ‹'v twl_st_wl› and S' :: ‹'v twl_st_l›
  assumes ‹count_decided (get_trail_wl S) = 1› and
    SS': ‹(S, S') ∈ {(S, S'). (S, S') ∈ state_wl_l None ∧ correct_watching S}›
  shows
    ‹negate_mode_bj_unit_wl S ≤ ⇓{(S, S'). (S, S') ∈ state_wl_l None ∧ correct_watching S}
       (negate_mode_bj_unit_l S')›
       (is ‹_ ≤ ⇓ ?R _›)
proof -
  (* The final unit-propagation step maps related backjump results to related
     states with correct watch lists. *)
  have 2: ‹(propagate_unit_and_add_wl x2a x1a, propagate_unit_and_add_l x2 x1)
        ∈ {(S, S''). (S, S'') ∈ state_wl_l None ∧ correct_watching S}›
    if
      ‹(x, x') ∈ find_decomp_target_wl_ref S› and
      ‹x' = (x1, x2)› and
      ‹x = (x1a, x2a)›
    for x2a x1a x2 x1 and x :: ‹'v twl_st_wl × 'v literal› and x' :: ‹'v twl_st_l × 'v literal›
  proof -
    show ?thesis
      using that
      by (cases x1a; cases x1)
        (auto, auto simp: state_wl_l_def correct_watching.simps clause_to_update_def
          all_lits_of_mm_add_mset
          all_lits_of_m_add_mset all_lits_of_mm_union mset_take_mset_drop_mset'
          dest: in_all_lits_of_mm_uminusD)
  qed

  (* Refine step by step: the backjump with find_decomp_target_wl_find_decomp_target_l,
     then the propagation with fact 2. *)
  show ?thesis
    using SS' unfolding negate_mode_bj_unit_wl_def negate_mode_bj_unit_l_def
    apply (refine_rcg find_decomp_target_wl_find_decomp_target_l 2)
    subgoal unfolding negate_mode_bj_unit_l_inv_def by blast
    subgoal unfolding negate_mode_bj_unit_l_inv_def by blast
    subgoal by blast
    apply assumption+
    done
qed

(* Precondition of propagate_nonunit_and_add_wl K C i S:
   - the new clause C has at least two literals;
   - the index i is positive and not yet in the clause domain of S;
   - every atom of C occurs in the clauses of S or in its initial unit clauses.
   The literal K is not constrained here. *)
definition propagate_nonunit_and_add_wl_pre
  :: ‹'v literal ⇒ 'v clause_l ⇒ nat ⇒ 'v twl_st_wl ⇒ bool› where
  ‹propagate_nonunit_and_add_wl_pre K C i S ⟷
     length C ≥ 2 ∧ i > 0 ∧ i ∉# dom_m (get_clauses_wl S) ∧
     atms_of (mset C) ⊆ atms_of_mm (clause `# twl_clause_of `# ran_mf (get_clauses_wl S) +
          get_unit_init_clss_wl S)›

(* Non-unit backjump step on a watch-list state.
   After asserting the precondition, it stores C under the index i
   (fmupd i (C, True)) and watches C by its first two literals C!0 and C!1.
   The watch entry appended for one of these literals is (i, the other literal, b),
   where b records whether C has exactly two literals.
   It then propagates -K with reason i, clears the conflict, and sets the
   literals to update to {#K#}. NE and UE are unchanged. *)
fun propagate_nonunit_and_add_wl
  :: ‹'v literal ⇒ 'v clause_l ⇒ nat ⇒ 'v twl_st_wl ⇒ 'v twl_st_wl nres›
where
  ‹propagate_nonunit_and_add_wl K C i (M, N, D, NE, UE, Q, W) = do {
      ASSERT(propagate_nonunit_and_add_wl_pre K C i (M, N, D, NE, UE, Q, W));
      let b = (length C = 2);
      let W = W(C!0 := W (C!0) @ [(i, C!1, b)]);
      let W = W(C!1 := W (C!1) @ [(i, C!0, b)]);
      RETURN (Propagated (-K) i # M, fmupd i (C, True) N, None,
      NE, UE, {#K#}, W)
    }›

(* Lift an equation stated on an explicit 7-tuple list-based state to an arbitrary
   state S, expressed through its accessor functions. It is used to derive
   definition-style equations from fun simp rules, as in propagate_nonunit_and_add_l_def. *)
lemma twl_st_l_splitD:
  ‹(⋀M N D NE UE Q W. f (M, N, D, NE, UE, Q, W) = P M N D NE UE Q W) ⟹
   f S = P (get_trail_l S) (get_clauses_l S) (get_conflict_l S) (get_unit_init_clauses_l S)
    (get_unit_learned_clauses_l S) (clauses_to_update_l S) (literals_to_update_l S)›
  by (cases S) auto

(* Watch-list counterpart of twl_st_l_splitD. The tuple components
   (M, N, D, NE, UE, Q, W) are, in order, the trail, clauses, conflict, initial
   units, learned units, literals to update and watch lists of S. *)
lemma twl_st_wl_splitD:
  ‹(⋀M N D NE UE Q W. f (M, N, D, NE, UE, Q, W) = P M N D NE UE Q W) ⟹
   f S = P (get_trail_wl S) (get_clauses_wl S) (get_conflict_wl S) (get_unit_init_clss_wl S)
    (get_unit_learned_clss_wl S) (literals_to_update_wl S) (get_watched_wl S)›
  by (cases S) auto

(* Invariant for negate_mode_bj_nonunit_wl: for some b, S refines via state_wl_l b a
   list-based state satisfying negate_mode_bj_nonunit_l_inv, and S has correct
   watch lists. *)
definition negate_mode_bj_nonunit_wl_inv where
‹negate_mode_bj_nonunit_wl_inv S ⟷
   (∃S'' b. (S, S'') ∈ state_wl_l b ∧ negate_mode_bj_nonunit_l_inv S'' ∧ correct_watching S)›

(* Non-unit backjump that learns the decision clause.
   1. Compute C = DECO_clause_l of the current trail, before backjumping.
   2. Backjump with find_decomp_target_wl at level count_decided of the trail,
      so K is a decision of the highest level.
   3. Pick a fresh clause index with get_fresh_index_wl.
   4. Add C while propagating -K (propagate_nonunit_and_add_wl). *)
definition negate_mode_bj_nonunit_wl :: ‹'v twl_st_wl ⇒ 'v twl_st_wl nres› where
‹negate_mode_bj_nonunit_wl = (λS. do {
    ASSERT(negate_mode_bj_nonunit_wl_inv S);
    let C = DECO_clause_l (get_trail_wl S);
    (S, K) ← find_decomp_target_wl (count_decided (get_trail_wl S)) S;
    i ← get_fresh_index_wl (get_clauses_wl S) (get_unit_clauses_wl S) (get_watched_wl S);
    propagate_nonunit_and_add_wl K C i S
  })›


(* Equation for propagate_nonunit_and_add_wl on an arbitrary state, obtained from
   its fun simp rule through twl_st_wl_splitD. *)
lemmas propagate_nonunit_and_add_wl_def =
   twl_st_wl_splitD[of ‹propagate_nonunit_and_add_wl _ _ _›, OF propagate_nonunit_and_add_wl.simps]

(* Equation for the list-based propagate_nonunit_and_add_l on an arbitrary state,
   obtained from its fun simp rule through twl_st_l_splitD. *)
lemmas propagate_nonunit_and_add_l_def =
   twl_st_l_splitD[of ‹propagate_nonunit_and_add_l _ _ _›, OF propagate_nonunit_and_add_l.simps,
  rule_format]

(* If all atoms of C lie in atms_of_ms N, then so does the atom of any literal of C. *)
lemma atms_of_subset_in_atms_ofI:
  ‹atms_of C ⊆ atms_of_ms N ⟹ L ∈# C ⟹ atm_of L ∈ atms_of_ms N›
  by (auto dest!: multi_member_split)

(* Membership in the list decision clause and in the multiset decision clause coincide. *)
lemma in_DECO_clause_l_in_DECO_clause_iff:
  ‹x ∈ set (DECO_clause_l M) ⟷ x ∈# (DECO_clause M)›
  by (metis DECO_clause_l_DECO_clause set_mset_mset)

(* The list decision clause of a duplicate-free trail has no repeated literals. *)
lemma distinct_DECO_clause_l:
  ‹no_dup M ⟹ distinct (DECO_clause_l M)›
  by (auto simp: DECO_clause_l_def distinct_map inj_on_def
      dest!: no_dup_map_lit_of)

(* Refinement of the last step of the non-unit backjump. Adding the decision clause
   under a fresh index i at watch-list level corresponds to the list-based
   propagate_nonunit_and_add_l with the same index j = i, and the result still has
   correct watch lists. *)
lemma propagate_nonunit_and_add_wl_propagate_nonunit_and_add_l:
  assumes
    SS': ‹(S, S') ∈ state_wl_l None› and
    inv: ‹negate_mode_bj_nonunit_wl_inv S› and
    TK: ‹(TK, TK') ∈ find_decomp_target_wl_ref S› and
    [simp]: ‹TK' = (T, K)› and
    [simp]: ‹TK = (T', K')› and
    ij: ‹(i, j) ∈ {(i, j). i = j ∧ i ∉# dom_m (get_clauses_wl T') ∧ i > 0 ∧
       (∀L ∈# all_lits_of_mm (mset `# ran_mf (get_clauses_wl T') + get_unit_clauses_wl T') .
          i ∉ fst ` set (watched_by T' L))}›
  shows ‹propagate_nonunit_and_add_wl K' (DECO_clause_l (get_trail_wl S)) i T'
         ≤ SPEC (λc. (c, propagate_nonunit_and_add_l K
                          (DECO_clause_l (get_trail_l S')) j T)
                     ∈ {(S, S'').
                        (S, S'') ∈ state_wl_l None ∧ correct_watching S})›
proof -
  have [simp]: ‹i = j› and j: ‹j ∉# dom_m (get_clauses_wl T')›
    using ij by auto
  have [simp]: ‹DECO_clause_l (get_trail_l S') = DECO_clause_l (get_trail_wl S)›
    using SS' by auto
  (* NOTE(review): the T obtained here, a list-based state refined by S, shadows the
     T of the lemma statement, which is the list-based backjump result. The later
     case splits on T refer to this obtained T. *)
  obtain T U b b' where
    ST: ‹(S, T) ∈ state_wl_l b› and
    corr: ‹correct_watching S› and
    TU: ‹(T, U) ∈ twl_st_l b'› and
    ‹twl_list_invs T› and
    ge1: ‹1 < count_decided (get_trail_l T)› and
    st: ‹twl_struct_invs U› and
    ‹twl_stgy_invs U› and
    ‹get_conflict_l T = None›
    using inv unfolding negate_mode_bj_nonunit_wl_inv_def negate_mode_bj_nonunit_l_inv_def apply -
    by blast
  (* With at least two decisions (ge1), the decision clause has at least two
     literals and can be written C!0 # C!1 # drop 2 C (fact 1). *)
  have ‹length (DECO_clause_l (get_trail_wl S)) > 1›
    using ST ge1 by auto
  then have 1: ‹DECO_clause_l (get_trail_wl S) =
        DECO_clause_l (get_trail_wl S) ! 0 #
           DECO_clause_l (get_trail_wl S) ! Suc 0 # drop 2 (DECO_clause_l (get_trail_wl S))›
    by (cases ‹DECO_clause_l (get_trail_wl S)›; cases ‹tl (DECO_clause_l (get_trail_wl S))›)
      auto
  (* The two watched literals differ (fact neq) because the trail has no duplicates,
     which makes the decision clause distinct (distinct_DECO_clause_l). *)
  have ‹no_dup (trail (state⇩W_of U))›
    using st unfolding twl_struct_invs_def cdcl⇩W_restart_mset.cdcl⇩W_all_struct_inv_def
      cdcl⇩W_restart_mset.cdcl⇩W_M_level_inv_def
    by fast
  then have neq: False if ‹DECO_clause_l (get_trail_wl S) ! 0 = DECO_clause_l (get_trail_wl S) ! Suc 0›
    using that
    apply (subst (asm) nth_eq_iff_index_eq)
    using ge1 ST TU by (auto simp: twl_st twl_st_l twl_st_wl distinct_DECO_clause_l)

  (* Unfold both step functions and discharge the assertion. Then prove the state
     relation and correct_watching of the extended watch lists. *)
  show ?thesis
    using TK j corr ge1 ST
    apply (simp only: propagate_nonunit_and_add_wl_def
       propagate_nonunit_and_add_l_def Let_def
       assert_bind_spec_conv)
    apply (intro conjI)
    subgoal using j ij TK unfolding propagate_nonunit_and_add_wl_pre_def by auto
    subgoal
      unfolding RETURN_def less_eq_nres.simps mem_Collect_eq prod.simps singleton_iff
      apply (subst subset_iff)
      unfolding RETURN_def less_eq_nres.simps mem_Collect_eq prod.simps singleton_iff
      apply (intro conjI impI allI)
      subgoal by (auto simp: state_wl_l_def)
      subgoal
        apply (simp only: )
        apply (subst 1)
        apply (subst One_nat_def[symmetric])+
        apply (subst fun_upd_other)
        subgoal
          using SS' length_DECO_clause_l[of ‹get_trail_wl S›]
          by (cases ‹DECO_clause_l (get_trail_wl S)›; cases ‹tl (DECO_clause_l (get_trail_wl S))›)
            (auto simp: DECO_clause_l_DECO_clause[symmetric] twl_st_l twl_st
            simp del: DECO_clause_l_DECO_clause)
        (* Reduce to correct_watching of T' with correct_watching_learn; the
           remaining subgoals are its side conditions. *)
        apply (rule correct_watching_learn[THEN iffD2])
        apply (rule atms_of_subset_in_atms_ofI[of ‹DECO_clause (get_trail_wl S)›])
        subgoal by (auto simp add: mset_take_mset_drop_mset' get_unit_clauses_wl_alt_def
          DECO_clause_l_DECO_clause[symmetric]
           simp del: DECO_clause_l_DECO_clause)
        subgoal by (solves ‹auto simp add: mset_take_mset_drop_mset'
          DECO_clause_l_DECO_clause[symmetric]
           simp del: DECO_clause_l_DECO_clause›)
        subgoal apply (use in ‹auto simp add: mset_take_mset_drop_mset' DECO_clause_l_DECO_clause[symmetric]
           simp del: DECO_clause_l_DECO_clause›)
          by (metis (no_types, lifting) "1" UnE add_mset_commute image_eqI mset.simps(2)
              set_mset_mset subsetCE union_single_eq_member)
        subgoal ― ‹TODO Proof›
         apply (auto simp: mset_take_mset_drop_mset' in_DECO_clause_l_in_DECO_clause_iff
           dest!: in_set_dropD)
           by (metis UnE atms_of_ms_union atms_of_subset_in_atms_ofI)
        subgoal by simp
        subgoal using corr ij
          by (cases S; cases T; cases T')
            (auto simp: equality_except_trail_wl.simps state_wl_l_def correct_watching.simps
             clause_to_update_def)
        subgoal using corr neq
          by (cases S; cases T; cases T')
           (auto simp: equality_except_trail_wl.simps state_wl_l_def correct_watching.simps
             clause_to_update_def)
        subgoal
          by (subst 1) auto
        subgoal using corr
          by (cases S; cases T; cases T')
           (auto simp: equality_except_trail_wl.simps state_wl_l_def correct_watching.simps
             clause_to_update_def)
        done
      done
    done
qed

(* watched_by T L coincides with the watch-list accessor get_watched_wl T L. *)
lemma watched_by_alt_def:
  ‹watched_by T L = get_watched_wl T L›
  by (cases T) auto

(* negate_mode_bj_nonunit_wl refines negate_mode_bj_nonunit_l for states related by
   state_wl_l None with correct watch lists, and preserves correct watch lists. *)
lemma negate_mode_bj_nonunit_wl_negate_mode_bj_nonunit_l:
  fixes S :: ‹'v twl_st_wl› and S' :: ‹'v twl_st_l›
  assumes
    SS': ‹(S, S') ∈ {(S, S''). (S, S'') ∈ state_wl_l None ∧ correct_watching S}›
  shows
    ‹negate_mode_bj_nonunit_wl S ≤ ⇓{(S, S''). (S, S'') ∈ state_wl_l None ∧ correct_watching S}
       (negate_mode_bj_nonunit_l S')›
proof -
  (* The index chosen at watch-list level equals the list-level choice. It is
     positive, unused in the clause domain of T, and absent from the watch list of
     every literal of the clauses and unit clauses of T. *)
  have fresh: ‹get_fresh_index_wl (get_clauses_wl T) (get_unit_clauses_wl T) (get_watched_wl T)
    ≤ ⇓ {(i, j). i = j ∧ i ∉# dom_m (get_clauses_wl T) ∧ i > 0 ∧
       (∀L ∈# all_lits_of_mm (mset `# ran_mf (get_clauses_wl T) + get_unit_clauses_wl T) .
          i ∉ fst ` set (watched_by T L))}
        (get_fresh_index (get_clauses_l T'))›
    if ‹(TK, TK') ∈ find_decomp_target_wl_ref S› and
      ‹TK = (T, K)› and
      ‹TK' =(T', K')›
    for T T' K K' TK TK'
    using that by (auto simp: get_fresh_index_def equality_except_trail_wl_get_clauses_wl
        get_fresh_index_wl_def watched_by_alt_def
      intro!: RES_refine)
  (* Refine step by step: backjump, fresh index, then the final clause-adding step. *)
  show ?thesis
    using SS'
    unfolding negate_mode_bj_nonunit_wl_def negate_mode_bj_nonunit_l_def
    apply (refine_rcg find_decomp_target_wl_find_decomp_target_l fresh
      propagate_nonunit_and_add_wl_propagate_nonunit_and_add_l)
    subgoal
       using SS' unfolding negate_mode_bj_unit_l_inv_def negate_mode_bj_nonunit_wl_inv_def
       by blast
    subgoal
       using SS' unfolding negate_mode_bj_nonunit_l_inv_def by blast
    subgoal using SS' by (auto simp add: twl_st_wl)
    apply assumption+
    apply (auto simp add: equality_except_trail_wl_get_clauses_wl)
    done
qed

(* Invariant for negate_mode_restart_nonunit_wl: for some b, S refines via
   state_wl_l b a list-based state satisfying negate_mode_restart_nonunit_l_inv,
   and S has correct watch lists. *)
definition negate_mode_restart_nonunit_wl_inv :: ‹'v twl_st_wl ⇒ bool› where
‹negate_mode_restart_nonunit_wl_inv S ⟷
  (∃S' b. (S, S') ∈ state_wl_l b ∧ negate_mode_restart_nonunit_l_inv S' ∧ correct_watching S)›

(* Precondition of restart_nonunit_and_add_wl C i S:
   - C has at least two literals;
   - S has correct watch lists;
   - every atom of C occurs in the clauses of S or in its initial unit clauses.
   NOTE(review): unlike propagate_nonunit_and_add_wl_pre, the index i is not
   constrained here. Freshness presumably comes from the caller's
   get_fresh_index_wl; confirm. *)
definition restart_nonunit_and_add_wl_inv where
  ‹restart_nonunit_and_add_wl_inv C i S ⟷
     length C ≥ 2 ∧ correct_watching S ∧
      atms_of (mset C) ⊆ atms_of_mm (clause `# twl_clause_of `# ran_mf (get_clauses_wl S) +
          get_unit_init_clss_wl S)›

(* Restart variant of the clause-adding step. The clause C is stored under index i
   and watched by C!0 and C!1 exactly as in propagate_nonunit_and_add_wl.
   Nothing is propagated: the trail M is kept, the conflict is cleared, and the
   literals to update become empty. *)
fun restart_nonunit_and_add_wl :: ‹'v clause_l ⇒ nat ⇒ 'v twl_st_wl ⇒ 'v twl_st_wl nres› where
  ‹restart_nonunit_and_add_wl C i (M, N, D, NE, UE, Q, W) = do {
      ASSERT(restart_nonunit_and_add_wl_inv C i (M, N, D, NE, UE, Q, W));
     let b = (length C = 2);
      let W = W(C!0 := W (C!0) @ [(i, C!1, b)]);
      let W = W(C!1 := W (C!1) @ [(i, C!0, b)]);
      RETURN (M, fmupd i (C, True) N, None, NE, UE, {#}, W)
  }›

(* Restart variant of the negation step.
   1. Compute the decision clause C of the current trail.
   2. Choose any level i below count_decided of the trail.
   3. Backjump with find_decomp_target_wl i.
   4. Pick a fresh clause index and add C without propagating
      (restart_nonunit_and_add_wl).
   Note that i is rebound: first the target level, then the fresh clause index. *)
definition negate_mode_restart_nonunit_wl :: ‹'v twl_st_wl ⇒ 'v twl_st_wl nres› where
‹negate_mode_restart_nonunit_wl = (λS. do {
    ASSERT(negate_mode_restart_nonunit_wl_inv S);
    let C = DECO_clause_l (get_trail_wl S);
    i ← SPEC(λi. i < count_decided (get_trail_wl S));
    (S, K) ← find_decomp_target_wl i S;
    i ← get_fresh_index_wl (get_clauses_wl S) (get_unit_clauses_wl S) (get_watched_wl S);
    restart_nonunit_and_add_wl C i S
  })›


(* Invariant for negate_mode_wl: for some b, S refines via state_wl_l b a
   list-based state satisfying negate_mode_l_inv, and S has correct watch lists. *)
definition negate_mode_wl_inv where
  ‹negate_mode_wl_inv S ⟷
     (∃S' b. (S, S') ∈ state_wl_l b ∧ negate_mode_l_inv S' ∧ correct_watching S)›

(* Top-level negation step at watch-list level.
   - With exactly one decision on the trail, perform the unit backjump.
   - Otherwise choose nondeterministically between the non-unit backjump and the
     restart variant. *)
definition negate_mode_wl :: ‹'v twl_st_wl ⇒ 'v twl_st_wl nres› where
  ‹negate_mode_wl S = do {
    ASSERT(negate_mode_wl_inv S);
    if count_decided (get_trail_wl S) = 1
    then negate_mode_bj_unit_wl S
    else do {
      b ← SPEC(λ_. True);
      if b then negate_mode_bj_nonunit_wl S else negate_mode_restart_nonunit_wl S
    }
  }›

(* The change considered here:
   - add the clause L1 # L2 # UW to N under index i;
   - append (i, L2, b) to W L1 and (i, L1, b) to W L2.
   This change preserves and reflects correct_watching, provided that
   - the atoms of the clause occur in N plus NE;
   - L1 and L2 differ;
   - i is not in the domain of N and not in the watch lists;
   - b records whether the clause is binary.
   The trail M is left untouched, hence no_propa. The proof reduces the statement
   to correct_watching_learn, instantiated with every assumption except
   L1 ≠ L2 and the b equation. *)
lemma correct_watching_learn_no_propa:
  assumes
    L1: ‹atm_of L1 ∈ atms_of_mm (mset `# ran_mf N + NE)› and
    L2: ‹atm_of L2 ∈ atms_of_mm (mset `# ran_mf N + NE)› and
    UW: ‹atms_of (mset UW) ⊆ atms_of_mm (mset `# ran_mf N + NE)› and
    ‹L1 ≠ L2› and
    i_dom: ‹i ∉# dom_m N› and
    ‹⋀L. L ∈# all_lits_of_mm (mset `# ran_mf N + (NE + UE)) ⟹ i ∉ fst ` set (W L)› and
    ‹b ⟷  length (L1 # L2 # UW) = 2›
  shows
  ‹correct_watching (M, fmupd i (L1 # L2 # UW, b') N,
    D, NE, UE, Q, W (L1 := W L1 @ [(i, L2, b)], L2 := W L2 @ [(i, L1, b)])) ⟷
  correct_watching (M, N, D, NE, UE, Q, W)›
  apply (subst correct_watching_learn[OF assms(1-3, 5-6), symmetric])
  unfolding correct_watching.simps clause_to_update_def
  by (auto simp: assms)

(* Refinement of the last step of the restart variant. Adding the decision clause
   under a fresh index i at watch-list level corresponds to the list-based
   restart_nonunit_and_add_l with j = i, and the result still has correct watch
   lists. *)
lemma restart_nonunit_and_add_wl_restart_nonunit_and_add_l:
  assumes
    SS': ‹(S, S') ∈ {(S, S'). (S, S') ∈ state_wl_l None ∧ correct_watching S}› and
    l_inv: ‹negate_mode_restart_nonunit_l_inv S'› and
    inv: ‹negate_mode_restart_nonunit_wl_inv S› and
    ‹(m, n) ∈ nat_rel› and
    ‹m ∈ {i. i < count_decided (get_trail_wl S)}› and
    ‹n ∈ {i. i < count_decided (get_trail_l S')}› and
    TK: ‹(TK, TK') ∈ find_decomp_target_wl_ref S› and
    [simp]: ‹TK' = (T, K)› and
    [simp]: ‹TK = (T', K')› and
    ij: ‹(i, j) ∈ {(i, j). i = j ∧ i ∉# dom_m (get_clauses_wl T') ∧ i > 0 ∧
       (∀L ∈# all_lits_of_mm (mset `# ran_mf (get_clauses_wl T') + get_unit_clauses_wl T') .
          i ∉ fst ` set (watched_by T' L))}›
  shows ‹restart_nonunit_and_add_wl (DECO_clause_l (get_trail_wl S)) i T'
         ≤ SPEC (λc. (c, restart_nonunit_and_add_l
                          (DECO_clause_l (get_trail_l S')) j T)
                     ∈ {(S, S'').
                        (S, S'') ∈ state_wl_l None ∧ correct_watching S})›
proof -
  have [simp]: ‹i = j›
    using ij by auto
  (* The decision clause has at least two literals (from the list-level invariant),
     so it can be written C!0 # C!1 # drop 2 C (fact 1). *)
  have le: ‹length (DECO_clause_l (get_trail_wl S)) > 1›
    using SS' l_inv unfolding negate_mode_restart_nonunit_l_inv_def by auto
  then have 1: ‹DECO_clause_l (get_trail_wl S) =
        DECO_clause_l (get_trail_wl S) ! 0 #
           DECO_clause_l (get_trail_wl S) ! Suc 0 # drop 2 (DECO_clause_l (get_trail_wl S))›
    by (cases ‹DECO_clause_l (get_trail_wl S)›; cases ‹tl (DECO_clause_l (get_trail_wl S))›)
      auto
  (* NOTE(review): the T obtained here, a list-based state refined by S, shadows the
     T of the lemma statement; the later case splits on T refer to this obtained T. *)
  obtain T U b b' where
      ST: ‹(S, T) ∈ state_wl_l b› and
      ‹no_dup (trail (state⇩W_of U))› and
      TU: ‹(T, U) ∈ twl_st_l b'›
    using inv unfolding negate_mode_restart_nonunit_wl_inv_def negate_mode_restart_nonunit_l_inv_def
    unfolding twl_struct_invs_def cdcl⇩W_restart_mset.cdcl⇩W_all_struct_inv_def
      cdcl⇩W_restart_mset.cdcl⇩W_M_level_inv_def
    by fast
  (* The two watched literals differ because the trail has no duplicates. *)
  then have neq: False if ‹DECO_clause_l (get_trail_wl S) ! 0 = DECO_clause_l (get_trail_wl S) ! Suc 0›
    using that
    apply (subst (asm) nth_eq_iff_index_eq)
    using le ST TU by (auto simp: twl_st twl_st_l twl_st_wl distinct_DECO_clause_l)

  (* Unfold both step functions and discharge the assertion. Then prove the state
     relation and correct_watching of the extended watch lists. *)
  show ?thesis
    apply (simp only:  twl_st_wl_splitD[of ‹restart_nonunit_and_add_wl _ _›,
        OF restart_nonunit_and_add_wl.simps]
       twl_st_l_splitD[of ‹restart_nonunit_and_add_l _ _›,
        OF restart_nonunit_and_add_l.simps] Let_def
       assert_bind_spec_conv)
    apply (intro conjI)
    subgoal
      using TK SS' l_inv unfolding negate_mode_restart_nonunit_l_inv_def
         restart_nonunit_and_add_wl_inv_def
      by (cases T') auto
    subgoal
      unfolding RETURN_def less_eq_nres.simps mem_Collect_eq prod.simps singleton_iff
      apply (subst subset_iff)
      unfolding RETURN_def less_eq_nres.simps mem_Collect_eq prod.simps singleton_iff
      apply (intro conjI impI allI)
      subgoal using TK SS' by (auto simp: state_wl_l_def)
      subgoal
        apply (simp only: )
        apply (subst 1)
        apply (subst One_nat_def[symmetric])+
        apply (subst fun_upd_other)
        subgoal
          using SS' length_DECO_clause_l[of ‹get_trail_wl S›] le TK
          by (cases ‹DECO_clause_l (get_trail_wl S)›; cases ‹tl (DECO_clause_l (get_trail_wl S))›)
            (auto simp: DECO_clause_l_DECO_clause[symmetric] twl_st_l twl_st
            simp del: DECO_clause_l_DECO_clause)
        (* Reduce to correct_watching of T' with correct_watching_learn_no_propa; the
           remaining subgoals are its side conditions. *)
        apply (rule correct_watching_learn_no_propa[THEN iffD2])
        apply (rule atms_of_subset_in_atms_ofI[of ‹DECO_clause (get_trail_wl S)›])
        subgoal using TK by (solves ‹auto simp add: mset_take_mset_drop_mset'›)
        subgoal using TK le by (solves ‹auto simp add: mset_take_mset_drop_mset'
          DECO_clause_l_DECO_clause[symmetric]
           simp del: DECO_clause_l_DECO_clause›)
        subgoal apply (use TK le in ‹auto simp add: mset_take_mset_drop_mset' DECO_clause_l_DECO_clause[symmetric]
           simp del: DECO_clause_l_DECO_clause›)
           apply (smt "1" UnE add_mset_add_single image_eqI mset.simps(2) set_mset_mset subsetCE
              union_iff union_single_eq_member)
           done
        subgoal ― ‹TODO Proof›
          using TK le apply (auto simp: mset_take_mset_drop_mset' in_DECO_clause_l_in_DECO_clause_iff
           dest!: in_set_dropD)
           by (metis UnE atms_of_ms_union atms_of_subset_in_atms_ofI)
        subgoal using SS' TK neq by (auto simp add: equality_except_trail_wl_get_clauses_wl)
        subgoal using inv TK SS' ij unfolding negate_mode_restart_nonunit_wl_inv_def
          by (cases S; cases T; cases T')
           (auto simp: state_wl_l_def correct_watching.simps
             clause_to_update_def)
        subgoal using inv TK SS' ij unfolding negate_mode_restart_nonunit_wl_inv_def
          by (cases S; cases T; cases T')
            (auto simp: state_wl_l_def correct_watching.simps
             clause_to_update_def)
        subgoal by (subst 1) auto
        subgoal using inv TK SS' unfolding negate_mode_restart_nonunit_wl_inv_def
          by (cases S; cases T; cases T')
            (auto simp: state_wl_l_def correct_watching.simps
             clause_to_update_def)
        done
      done
    done
qed

(* negate_mode_restart_nonunit_wl refines negate_mode_restart_nonunit_l for states
   related by state_wl_l None with correct watch lists, and preserves correct
   watch lists. *)
lemma negate_mode_restart_nonunit_wl_negate_mode_restart_nonunit_l:
  fixes S :: ‹'v twl_st_wl› and S' :: ‹'v twl_st_l›
  assumes
    SS': ‹(S, S') ∈ {(S, S''). (S, S'') ∈ state_wl_l None ∧ correct_watching S}›
  shows
    ‹negate_mode_restart_nonunit_wl S ≤
      ⇓ {(S, S''). (S, S'') ∈ state_wl_l None ∧ correct_watching S}
       (negate_mode_restart_nonunit_l S')›
proof -
  (* Same fresh-index refinement as in the non-unit backjump case. *)
  have fresh: ‹get_fresh_index_wl (get_clauses_wl T) (get_unit_clauses_wl T) (get_watched_wl T)
    ≤ ⇓ {(i, j). i = j ∧ i ∉# dom_m (get_clauses_wl T) ∧ i > 0 ∧
       (∀L ∈# all_lits_of_mm (mset `# ran_mf (get_clauses_wl T) + get_unit_clauses_wl T) .
          i ∉ fst ` set (watched_by T L))}
        (get_fresh_index (get_clauses_l T'))›
    if ‹(TK, TK') ∈ find_decomp_target_wl_ref S› and
      ‹TK = (T, K)› and
      ‹TK' =(T', K')›
    for T T' K K' TK TK'
    using that by (auto simp: get_fresh_index_def equality_except_trail_wl_get_clauses_wl
        get_fresh_index_wl_def watched_by_alt_def
      intro!: RES_refine)
  (* Refine step by step: level choice, backjump, fresh index, final clause-adding step. *)
  show ?thesis
    unfolding negate_mode_restart_nonunit_wl_def negate_mode_restart_nonunit_l_def
    apply (refine_rcg find_decomp_target_wl_find_decomp_target_l fresh
      restart_nonunit_and_add_wl_restart_nonunit_and_add_l)
    subgoal using SS' unfolding negate_mode_restart_nonunit_wl_inv_def by blast
    subgoal using SS' by auto
    subgoal using SS' by simp
    subgoal unfolding negate_mode_restart_nonunit_l_inv_def by blast
    subgoal using SS' by fast
    apply assumption+
    apply (rule SS')
    apply assumption+
    done
qed

(* negate_mode_wl refines negate_mode_l for states related by state_wl_l None with
   correct watch lists. It combines the refinement lemmas of the three variants.
   NOTE(review): the assumption confl is not referenced in the proof text below;
   confirm whether callers rely on it. *)
lemma negate_mode_wl_negate_mode_l:
  fixes S :: ‹'v twl_st_wl› and S' :: ‹'v twl_st_l›
  assumes
    SS': ‹(S, S') ∈ {(S, S''). (S, S'') ∈ state_wl_l None ∧ correct_watching S}› and
    confl: ‹get_conflict_wl S = None›
  shows
    ‹negate_mode_wl S ≤
      ⇓ {(S, S''). (S, S'') ∈ state_wl_l None ∧ correct_watching S}
       (negate_mode_l S')›
proof -
  show ?thesis
    using SS'
    unfolding negate_mode_wl_def negate_mode_l_def
    apply (refine_vcg negate_mode_bj_nonunit_wl_negate_mode_bj_nonunit_l
      negate_mode_bj_unit_wl_negate_mode_bj_unit_l
      negate_mode_restart_nonunit_wl_negate_mode_restart_nonunit_l)
    subgoal unfolding negate_mode_wl_inv_def by blast
    subgoal by auto
    subgoal by auto
    done
qed

(* The enumeration loop is parameterised by a predicate P on sets of literals. The
   loop stops as soon as P holds for the literals of the current trail. *)
context
  fixes P :: ‹'v literal set ⇒ bool›
begin

(* Loop invariant: S refines via state_wl_l None a list-based state satisfying
   cdcl_twl_enum_inv_l, and S has correct watch lists. *)
definition cdcl_twl_enum_inv_wl :: ‹'v twl_st_wl ⇒ bool› where
  ‹cdcl_twl_enum_inv_wl S ⟷
    (∃S'. (S, S') ∈ state_wl_l None ∧ cdcl_twl_enum_inv_l S') ∧
       correct_watching S›

(* Enumeration at watch-list level.
   1. Run the CDCL strategy once.
   2. Loop while there is no conflict, the trail contains at least one decision and
      P does not hold for the trail literals. Each iteration negates the decisions
      (negate_mode_wl) and runs the strategy again.
   3. Result: False if a conflict remains; P of the trail literals if there is no
      conflict and no decision; True otherwise, which is the case where the loop
      stopped because P holds.
   The loop uses the total-correctness while combinator with invariant
   cdcl_twl_enum_inv_wl. *)
definition cdcl_twl_enum_wl :: ‹'v twl_st_wl ⇒ bool nres› where
  ‹cdcl_twl_enum_wl S = do {
     S ← cdcl_twl_stgy_prog_wl S;
     S ← WHILE⇩T⇗cdcl_twl_enum_inv_wl⇖
       (λS. get_conflict_wl S = None ∧ count_decided(get_trail_wl S) > 0 ∧
            ¬P (lits_of_l (get_trail_wl S)))
       (λS. do {
             S ← negate_mode_wl S;
             cdcl_twl_stgy_prog_wl S
           })
       S;
     if get_conflict_wl S = None
     then RETURN (if count_decided(get_trail_wl S) = 0 then P (lits_of_l (get_trail_wl S)) else True)
     else RETURN (False)
    }›

(* The watch-list enumeration refines the list-based cdcl_twl_enum_l P for a related
   state with correct watch lists. *)
lemma cdcl_twl_enum_wl_cdcl_twl_enum_l:
  assumes
    SS': ‹(S, S') ∈ state_wl_l None› and
    corr: ‹correct_watching S›
  shows
    ‹cdcl_twl_enum_wl S ≤ ⇓ bool_rel
       (cdcl_twl_enum_l P S')›
  unfolding cdcl_twl_enum_wl_def cdcl_twl_enum_l_def
  apply (refine_vcg cdcl_twl_stgy_prog_wl_spec'[unfolded fref_param1, THEN fref_to_Down]
    negate_mode_wl_negate_mode_l)
  subgoal by fast
  subgoal using SS' corr by auto
  subgoal using corr unfolding cdcl_twl_enum_inv_wl_def by blast
  subgoal by auto
  subgoal by auto
  subgoal by auto
  subgoal by auto
  subgoal by auto
  done

end

end