theory CDCL_W_MaxSAT
  imports CDCL_W_Optimal_Model
begin
subsection ‹Partial MAX-SAT›
(* Total weight, under ρ, of the soft clauses of N⇩S that are true in the
   interpretation I.  The sum ranges over a multiset, so a clause occurring
   several times in N⇩S contributes once per occurrence. *)
definition weight_on_clauses where
  ‹weight_on_clauses N⇩S ρ I = (∑D ∈# {#D ∈# N⇩S. I ⊨ D#}. ρ D)›
(* atms_exactly_m I N holds when the interpretation I is total over the clauses
   of N (total_over_m) and, conversely, every atom occurring in I is an atom
   of N. *)
definition atms_exactly_m :: ‹'v partial_interp ⇒ 'v clauses ⇒ bool› where
  ‹atms_exactly_m I N =
    (total_over_m I (set_mset N) ∧ atms_of_s I ⊆ atms_of_mm N)›
text ‹The word ``partial'' in the name refers to the fact that only some of the clauses are soft
  clauses (the remaining ones are hard); it does not refer to the fact that we consider partial
  models.›
(* Specification of partial MAX-SAT for hard clauses N⇩H, soft clauses N⇩S and a
   weight function ρ on clauses.
   - Some I: I is a consistent interpretation over exactly the atoms of
     N⇩H + N⇩S that satisfies all hard clauses and whose weight of satisfied soft
     clauses is at least that of every other such interpretation.
   - None: the hard clauses are unsatisfiable. *)
inductive partial_max_sat :: ‹'v clauses ⇒ 'v clauses ⇒ ('v clause ⇒ nat) ⇒
  'v partial_interp option ⇒ bool› where
  partial_max_sat:
  ‹partial_max_sat N⇩H N⇩S ρ (Some I)›
if
  ‹I ⊨sm N⇩H› and
  ‹atms_exactly_m I ((N⇩H + N⇩S))› and
  ‹consistent_interp I› and
  (* Optimality: no admissible I' has a larger weight of satisfied soft clauses. *)
  ‹⋀I'. consistent_interp I' ⟹ atms_exactly_m I' (N⇩H + N⇩S) ⟹ I' ⊨sm N⇩H ⟹
      weight_on_clauses N⇩S ρ I' ≤ weight_on_clauses N⇩S ρ I› |
  partial_max_unsat:
  ‹partial_max_sat N⇩H N⇩S ρ None›
if
  ‹unsatisfiable (set_mset N⇩H)›
(* Specification of partial MIN-SAT: same shape as partial_max_sat, except that
   the returned interpretation minimises the weight of the satisfied soft
   clauses.
   - Some I: I is a consistent interpretation over exactly the atoms of
     N⇩H + N⇩S that satisfies all hard clauses and whose weight of satisfied soft
     clauses is at most that of every other such interpretation.
   - None: the hard clauses are unsatisfiable. *)
inductive partial_min_sat :: ‹'v clauses ⇒ 'v clauses ⇒ ('v clause ⇒ nat) ⇒
  'v partial_interp option ⇒ bool› where
  partial_min_sat:
  ‹partial_min_sat N⇩H N⇩S ρ (Some I)›
if
  ‹I ⊨sm N⇩H› and
  ‹atms_exactly_m I (N⇩H + N⇩S)› and
  ‹consistent_interp I› and
  (* Optimality: no admissible I' has a smaller weight of satisfied soft clauses. *)
  ‹⋀I'. consistent_interp I' ⟹ atms_exactly_m I' (N⇩H + N⇩S) ⟹ I' ⊨sm N⇩H ⟹
      weight_on_clauses N⇩S ρ I' ≥ weight_on_clauses N⇩S ρ I› |
  partial_min_unsat:
  ‹partial_min_sat N⇩H N⇩S ρ None›
if
  ‹unsatisfiable (set_mset N⇩H)›
(* An interpretation I satisfying atms_exactly_m I N is a finite set of
   literals. *)
lemma atms_exactly_m_finite:
  assumes ‹atms_exactly_m I N›
  shows ‹finite I›
proof -
  (* Every literal of I is the positive or the negative literal of an atom of N;
     finiteness of I then follows by finite_subset. *)
  have ‹I ⊆ Pos ` (atms_of_mm N) ∪ Neg ` atms_of_mm N›
    using assms by (force simp: total_over_m_def  atms_exactly_m_def lit_in_set_iff_atm
        atms_of_s_def)
  from finite_subset[OF this] show ?thesis by auto
qed
(* Existence of optimal solutions: if the hard clauses N⇩H are satisfiable, then
   both the partial MAX-SAT and the partial MIN-SAT problem for N⇩H, N⇩S and ρ
   have a solution of the form Some I.
   Proof idea: the set ?Is of admissible interpretations is finite and
   non-empty, so the image of the weight function on ?Is has a minimum and a
   maximum, each attained by some element of ?Is. *)
lemma
  fixes N⇩H :: ‹'v clauses›
  assumes ‹satisfiable (set_mset N⇩H)›
  shows sat_partial_max_sat: ‹∃I. partial_max_sat N⇩H N⇩S ρ (Some I)› and
    sat_partial_min_sat: ‹∃I. partial_min_sat N⇩H N⇩S ρ (Some I)›
proof -
  (* ?Is: the admissible interpretations, i.e. the consistent models of N⇩H over
     exactly the atoms of N⇩H + N⇩S.  ?Is' is the same set with finiteness of
     each interpretation made explicit. *)
  let ?Is = ‹{I. atms_exactly_m I ((N⇩H + N⇩S)) ∧  consistent_interp I ∧
     I ⊨sm N⇩H}›
  let ?Is'= ‹{I. atms_exactly_m I ((N⇩H + N⇩S)) ∧ consistent_interp I ∧
    I ⊨sm N⇩H ∧ finite I}›
  have Is: ‹?Is = ?Is'›
    by (auto simp: atms_of_s_def atms_exactly_m_finite)
  (* Each admissible interpretation is the literal set of a simple clause over
     the atoms of N⇩H + N⇩S; this is used to derive finiteness of ?Is. *)
  have ‹?Is' ⊆ set_mset ` simple_clss (atms_of_mm (N⇩H + N⇩S))›
    apply rule
    unfolding image_iff
    by (rule_tac x= ‹mset_set x› in bexI)
      (auto simp: simple_clss_def atms_exactly_m_def image_iff
        atms_of_s_def atms_of_def distinct_mset_mset_set consistent_interp_tuatology_mset_set)
  from finite_subset[OF this] have fin: ‹finite ?Is› unfolding Is
    by (auto simp: simple_clss_finite)
  then have fin': ‹finite (weight_on_clauses N⇩S ρ ` ?Is)›
    by auto
  (* First part: the minimum weight over ?Is. *)
  define ρI where
    ‹ρI = Min (weight_on_clauses N⇩S ρ ` ?Is)›
  have nempty: ‹?Is ≠ {}›
  proof -
    (* Start from a model I of the hard clauses over the atoms of N⇩H ... *)
    obtain I where I:
      ‹total_over_m I (set_mset N⇩H)›
      ‹I ⊨sm N⇩H›
      ‹consistent_interp I›
      ‹atms_of_s I ⊆ atms_of_mm N⇩H›
      using assms unfolding satisfiable_def_min atms_exactly_m_def
      by (auto simp: atms_of_s_def atm_of_def total_over_m_def)
    (* ... and set every atom of N⇩S that I does not mention to true. *)
    let ?I = ‹I ∪ Pos ` {x ∈ atms_of_mm N⇩S. x ∉ atm_of ` I}›
    have ‹?I ∈ ?Is›
      using I
      by (auto simp: atms_exactly_m_def total_over_m_alt_def image_iff
          lit_in_set_iff_atm)
        (auto simp: consistent_interp_def uminus_lit_swap)
    then show ?thesis
      by blast
  qed
  have ‹ρI ∈ weight_on_clauses N⇩S ρ ` ?Is›
    unfolding ρI_def
    by (rule Min_in[OF fin']) (use nempty in auto)
  (* Pick an admissible interpretation attaining the minimum. *)
  then obtain I :: ‹'v partial_interp› where
    ‹weight_on_clauses N⇩S ρ I = ρI› and
    ‹I ∈ ?Is›
    by blast
  then have H: ‹consistent_interp I' ⟹ atms_exactly_m I' (N⇩H + N⇩S) ⟹ I' ⊨sm N⇩H ⟹
      weight_on_clauses N⇩S ρ I' ≥ weight_on_clauses N⇩S ρ I› for I'
    using Min_le[OF fin', of ‹weight_on_clauses N⇩S ρ I'›]
    unfolding ρI_def[symmetric]
    by auto
  then have ‹partial_min_sat N⇩H N⇩S ρ (Some I)›
    apply -
    by (rule partial_min_sat)
      (use fin ‹I ∈ ?Is› in ‹auto simp: atms_exactly_m_finite›)
  then show ‹∃I. partial_min_sat N⇩H N⇩S ρ (Some I)›
    by fast
  (* Second part: the same argument with Max instead of Min.  ρI, I and H are
     deliberately re-bound here. *)
  define ρI where
    ‹ρI = Max (weight_on_clauses N⇩S ρ ` ?Is)›
  have ‹ρI ∈ weight_on_clauses N⇩S ρ ` ?Is›
    unfolding ρI_def
    by (rule Max_in[OF fin']) (use nempty in auto)
  (* Pick an admissible interpretation attaining the maximum. *)
  then obtain I :: ‹'v partial_interp› where
    ‹weight_on_clauses N⇩S ρ I = ρI› and
    ‹I ∈ ?Is›
    by blast
  (* Stated with ⊨sm, like the Min case above and like the optimality premise of
     the partial_max_sat introduction rule. *)
  then have H: ‹consistent_interp I' ⟹ atms_exactly_m I' (N⇩H + N⇩S) ⟹ I' ⊨sm N⇩H ⟹
      weight_on_clauses N⇩S ρ I' ≤ weight_on_clauses N⇩S ρ I› for I'
    using Max_ge[OF fin', of ‹weight_on_clauses N⇩S ρ I'›]
    unfolding ρI_def[symmetric]
    by auto
  then have ‹partial_max_sat N⇩H N⇩S ρ (Some I)›
    apply -
    by (rule partial_max_sat)
      (use fin ‹I ∈ ?Is› in ‹auto simp: atms_exactly_m_finite
        consistent_interp_tuatology_mset_set›)
  then show ‹∃I. partial_max_sat N⇩H N⇩S ρ (Some I)›
    by fast
qed
(* Specification of weighted SAT over a single clause set N, where the weight ρ
   is a function of the model itself (given as a multiset of literals) and
   takes values in a linear order.
   - Some I: I is a duplicate-free multiset of literals whose literal set is a
     consistent model of N over exactly the atoms of N, and ρ I is minimal
     among all such models.
   - None: N is unsatisfiable.
   NOTE(review): the second rule is named partial_max_unsat, the same name as
   the second rule of partial_max_sat; this looks like a copy-paste leftover
   -- confirm before renaming, since weight_sat.partial_max_unsat may be
   referenced elsewhere. *)
inductive weight_sat
  :: ‹'v clauses ⇒ ('v literal multiset ⇒ 'a :: linorder) ⇒
    'v literal multiset option ⇒ bool›
where
  weight_sat:
  ‹weight_sat N ρ (Some I)›
if
  ‹set_mset I ⊨sm N› and
  ‹atms_exactly_m (set_mset I) N› and
  ‹consistent_interp (set_mset I)› and
  ‹distinct_mset I›
  (* Optimality: every other admissible model I' has weight at least ρ I. *)
  ‹⋀I'. consistent_interp (set_mset I') ⟹ atms_exactly_m (set_mset I') N ⟹ distinct_mset I' ⟹
      set_mset I' ⊨sm N ⟹ ρ I' ≥ ρ I› |
  partial_max_unsat:
  ‹weight_sat N ρ None›
if
  ‹unsatisfiable (set_mset N)›
(* Reduction of partial MAX-SAT to weight_sat.
   Every soft clause C of N⇩S is extended with the positive literal of its own
   atom additional_atm C.  These atoms are assumed not to occur in N⇩H + N⇩S and
   to be pairwise different for different soft clauses.  The weight ρ' of a
   model sums, over its literals of the form Pos (additional_atm C) with
   C ∈# N⇩S, the value count N⇩S C * ρ C; all other literals cost 0.
   If I is a weight_sat solution (minimal ρ') of N⇩H together with the extended
   soft clauses, then the literals of I over the atoms of N⇩H + N⇩S form a
   partial_max_sat solution of the original problem. *)
lemma partial_max_sat_is_weight_sat: 
  fixes additional_atm :: ‹'v clause ⇒ 'v› and
    ρ :: ‹'v clause ⇒ nat› and
    N⇩S :: ‹'v clauses›
  defines
    ‹ρ' ≡ (λC. sum_mset
       ((λL. if L ∈ Pos ` additional_atm ` set_mset N⇩S
         then count N⇩S (SOME C. L = Pos (additional_atm C) ∧ C ∈# N⇩S)
           * ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# N⇩S)
         else 0) `# C))›
  assumes
    add: ‹⋀C. C ∈# N⇩S ⟹ additional_atm C ∉ atms_of_mm (N⇩H + N⇩S)›
    ‹⋀C D. C ∈# N⇩S ⟹ D ∈# N⇩S ⟹ additional_atm C = additional_atm D ⟷ C = D› and
    w: ‹weight_sat (N⇩H + (λC. add_mset (Pos (additional_atm C)) C) `# N⇩S) ρ' (Some I)›
  shows
    ‹partial_max_sat N⇩H N⇩S ρ (Some {L ∈ set_mset I. atm_of L ∈ atms_of_mm (N⇩H + N⇩S)})›
proof -
  (* N: the extended clause set; cl_of L: the soft clause whose additional
     literal is L (chosen with SOME). *)
  define N where ‹N ≡ N⇩H + (λC. add_mset (Pos (additional_atm C)) C) `# N⇩S›
  define cl_of where ‹cl_of L = (SOME C. L = Pos (additional_atm C) ∧ C ∈# N⇩S)› for L
  (* Unpack the premises of the weight_sat rule for Some I. *)
  from w
  have
    ent: ‹set_mset I ⊨sm N› and
    bi: ‹atms_exactly_m (set_mset I) N› and
    cons: ‹consistent_interp (set_mset I)› and
    dist: ‹distinct_mset I› and
    weight: ‹⋀I'. consistent_interp (set_mset I') ⟹ atms_exactly_m (set_mset I') N ⟹
      distinct_mset I' ⟹ set_mset I' ⊨sm N ⟹ ρ' I' ≥ ρ' I›
    unfolding N_def[symmetric]
    by (auto simp: weight_sat.simps)
  (* ?I: the literals of I over the atoms of the original problem. *)
  let ?I = ‹{L. L ∈# I ∧ atm_of L ∈ atms_of_mm (N⇩H + N⇩S)}›
  have ent': ‹set_mset I ⊨sm N⇩H›
    using ent unfolding true_clss_restrict
    by (auto simp: N_def)
  then have ent': ‹?I ⊨sm N⇩H›
    apply (subst (asm) true_clss_restrict[symmetric])
    apply (rule true_clss_mono_left, assumption)
    apply auto
    done
  (* The atoms of the extended soft clauses are the additional atoms together
     with the atoms of the soft clauses. *)
  have [simp]: ‹atms_of_ms ((λC. add_mset (Pos (additional_atm C)) C) ` set_mset N⇩S) =
    additional_atm ` set_mset N⇩S ∪ atms_of_ms (set_mset N⇩S)›
    by (auto simp: atms_of_ms_def)
  have bi': ‹atms_exactly_m ?I (N⇩H + N⇩S)›
    using bi
    by (auto simp: atms_exactly_m_def total_over_m_def total_over_set_def
        atms_of_s_def N_def)
  have cons': ‹consistent_interp ?I›
    using cons by (auto simp: consistent_interp_def)
  (* On additional literals of soft clauses, cl_of recovers the clause. *)
  have [simp]: ‹cl_of (Pos (additional_atm xb)) = xb›
    if ‹xb ∈# N⇩S› for xb
    using someI[of ‹λC. additional_atm xb = additional_atm C› xb] add that
    unfolding cl_of_def
    by auto
  (* ?I is re-bound: the restriction of I to the original atoms, completed on
     the additional atoms -- positive for the soft clauses that I does not
     satisfy, negative for those it satisfies. *)
  let ?I = ‹{L. L ∈# I ∧ atm_of L ∈ atms_of_mm (N⇩H + N⇩S)} ∪ Pos ` additional_atm ` {C ∈ set_mset N⇩S. ¬set_mset I ⊨ C}
    ∪ Neg ` additional_atm ` {C ∈ set_mset N⇩S. set_mset I ⊨ C}›
  (* This completed interpretation is admissible for N, so by optimality of I
     its weight is at least ρ' I ... *)
  have ‹consistent_interp ?I›
    using cons add by (auto simp: consistent_interp_def
        atms_exactly_m_def uminus_lit_swap
        dest: add)
  moreover have ‹atms_exactly_m ?I N›
    using bi
    by (auto simp: N_def atms_exactly_m_def total_over_m_def
        total_over_set_def image_image)
  moreover have ‹?I ⊨sm N›
    using ent by (auto simp: N_def true_clss_def image_image
          atm_of_lit_in_atms_of true_cls_def
        dest!: multi_member_split)
  moreover have ‹set_mset (mset_set ?I) = ?I› and fin: ‹finite ?I›
    by (auto simp: atms_exactly_m_finite)
  moreover have ‹distinct_mset (mset_set ?I)›
    by (auto simp: distinct_mset_mset_set)
  ultimately have ‹ρ' (mset_set ?I) ≥ ρ' I›
    using weight[of ‹mset_set ?I›]
    by argo
  (* ... and conversely at most ρ' I, hence both weights coincide. *)
  moreover have ‹ρ' (mset_set ?I) ≤ ρ' I›
    using ent
    by (auto simp: ρ'_def sum_mset_inter_restrict[symmetric] mset_set_subset_iff N_def
        intro!: sum_image_mset_mono
        dest!: multi_member_split)
  ultimately have I_I: ‹ρ' (mset_set ?I) = ρ' I›
    by linarith
  (* Optimality for the original problem: no admissible I' satisfies soft
     clauses of larger total weight than the restriction of I. *)
  have min: ‹weight_on_clauses N⇩S ρ I'
      ≤ weight_on_clauses N⇩S ρ {L. L ∈# I ∧ atm_of L ∈ atms_of_mm (N⇩H + N⇩S)}›
    if
      cons: ‹consistent_interp I'› and
      bit: ‹atms_exactly_m I' (N⇩H + N⇩S)› and
      I': ‹I' ⊨sm N⇩H›
    for I'
  proof -
    (* Complete I' on the additional atoms in the same way as above. *)
    let ?I' = ‹I' ∪ Pos ` additional_atm ` {C ∈ set_mset N⇩S. ¬I' ⊨ C}
      ∪ Neg ` additional_atm ` {C ∈ set_mset N⇩S. I' ⊨ C}›
    have ‹consistent_interp ?I'›
      using cons bit add by (auto simp: consistent_interp_def
          atms_exactly_m_def uminus_lit_swap
          dest: add)
    moreover have ‹atms_exactly_m ?I' N›
      using bit
      by (auto simp: N_def atms_exactly_m_def total_over_m_def
          total_over_set_def image_image)
    moreover have ‹?I' ⊨sm N›
      using I' by (auto simp: N_def true_clss_def image_image
          dest!: multi_member_split)
    moreover have ‹set_mset (mset_set ?I') = ?I'› and fin: ‹finite ?I'›
      using bit by (auto simp: atms_exactly_m_finite)
    moreover have ‹distinct_mset (mset_set ?I')›
      by (auto simp: distinct_mset_mset_set)
    (* The completed I' is admissible for N, so optimality of I applies. *)
    ultimately have I'_I: ‹ρ' (mset_set ?I') ≥ ρ' I›
      using weight[of ‹mset_set ?I'›]
      by argo
    have inj: ‹inj_on cl_of (I' ∩ (λx. Pos (additional_atm x)) ` set_mset N⇩S)› for I'
      using add by (auto simp: inj_on_def)
    (* Weight of the satisfied soft clauses = total weight of the soft clauses
       minus the weight of the unsatisfied ones. *)
    have we: ‹weight_on_clauses N⇩S ρ I' = sum_mset (ρ `# N⇩S) -
      sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') N⇩S)› for I'
      unfolding weight_on_clauses_def
      apply (subst (3) multiset_partition[of _ ‹(⊨) I'›])
      unfolding image_mset_union sum_mset.union
      by (auto simp: comp_def)
    (* The weight of the soft clauses falsified by the restriction of I is
       exactly ρ' I. *)
    have H: ‹sum_mset
       (ρ `#
        filter_mset (Not ∘ (⊨) {L. L ∈# I ∧ atm_of L ∈ atms_of_mm (N⇩H + N⇩S)})
         N⇩S) = ρ' I›
            unfolding I_I[symmetric] unfolding ρ'_def cl_of_def[symmetric]
              sum_mset_sum_count if_distrib
            apply (auto simp: sum_mset_sum_count image_image simp flip: sum.inter_restrict
                cong: if_cong)
            apply (subst comm_monoid_add_class.sum.reindex_cong[symmetric, of cl_of, OF _ refl])
            apply ((use inj in auto; fail)+)[2]
            apply (rule sum.cong)
            apply auto[]
            using inj[of ‹set_mset I›] ‹set_mset I ⊨sm N› assms(2)
            apply (auto dest!: multi_member_split simp: N_def image_Int
                atm_of_lit_in_atms_of true_cls_def)[]
            using add apply (auto simp: true_cls_def)
            done
    (* The ρ'-cost of the completed I' is bounded by the weight of the soft
       clauses that I' falsifies. *)
    have ‹(∑x∈(I' ∪ (λx. Pos (additional_atm x)) ` {C. C ∈# N⇩S ∧ ¬ I' ⊨ C} ∪
         (λx. Neg (additional_atm x)) ` {C. C ∈# N⇩S ∧ I' ⊨ C}) ∩
        (λx. Pos (additional_atm x)) ` set_mset N⇩S.
       count N⇩S (cl_of x) * ρ (cl_of x))
    ≤ (∑A∈{a. a ∈# N⇩S ∧ ¬ I' ⊨ a}. count N⇩S A * ρ A)›
      apply (subst comm_monoid_add_class.sum.reindex_cong[symmetric, of cl_of, OF _ refl])
      apply ((use inj in auto; fail)+)[2]
      apply (rule ordered_comm_monoid_add_class.sum_mono2)
      using that add by (auto dest:  simp: N_def
          atms_exactly_m_def)
    then have ‹sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') N⇩S) ≥ ρ' (mset_set ?I')›
      using fin unfolding cl_of_def[symmetric] ρ'_def
      by (auto simp: ρ'_def
          simp add: sum_mset_sum_count image_image simp flip: sum.inter_restrict)
    then have ‹ρ' I ≤ sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') N⇩S)›
      using I'_I by auto
    (* Combine via we and H: a smaller falsified weight means a larger
       satisfied weight. *)
    then show ?thesis
      unfolding we H I_I apply -
      by auto
  qed
  (* Discharge the four premises of the partial_max_sat introduction rule. *)
  show ?thesis
    apply (rule partial_max_sat.intros)
    subgoal using ent' by auto
    subgoal using bi' by fast
    subgoal using cons' by fast
    subgoal for I'
      by (rule min)
    done
qed
(* Congruence rule for sums over a multiset: two summands that agree on every
   element of A yield the same sum over A. *)
lemma sum_mset_cong:
  assumes ‹⋀a. a ∈# A ⟹ f a = g a›
  shows ‹(∑a∈#A. f a) = (∑a∈#A. g a)›
  using assms by (induction A) auto
(* Variant of partial_max_sat_is_weight_sat for a duplicate-free multiset of
   soft clauses: the weight ρ' then charges ρ C (without the factor
   count N⇩S C) for each literal Pos (additional_atm C) with C ∈# N⇩S. *)
lemma partial_max_sat_is_weight_sat_distinct: 
  fixes additional_atm :: ‹'v clause ⇒ 'v› and
    ρ :: ‹'v clause ⇒ nat› and
    N⇩S :: ‹'v clauses›
  defines
    ‹ρ' ≡ (λC. sum_mset
       ((λL. if L ∈ Pos ` additional_atm ` set_mset N⇩S
         then ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# N⇩S)
         else 0) `# C))›
  assumes
    ‹distinct_mset N⇩S› and 
    add: ‹⋀C. C ∈# N⇩S ⟹ additional_atm C ∉ atms_of_mm (N⇩H + N⇩S)›
    ‹⋀C D. C ∈# N⇩S ⟹ D ∈# N⇩S ⟹ additional_atm C = additional_atm D ⟷ C = D› and
    w: ‹weight_sat (N⇩H + (λC. add_mset (Pos (additional_atm C)) C) `# N⇩S) ρ' (Some I)›
  shows
    ‹partial_max_sat N⇩H N⇩S ρ (Some {L ∈ set_mset I. atm_of L ∈ atms_of_mm (N⇩H + N⇩S)})›
proof -
  (* cl_of L: the soft clause whose additional literal is L (chosen with SOME). *)
  define cl_of where ‹cl_of L = (SOME C. L = Pos (additional_atm C) ∧ C ∈# N⇩S)› for L
  have [simp]: ‹cl_of (Pos (additional_atm xb)) = xb›
    if ‹xb ∈# N⇩S› for xb
    using someI[of ‹λC. additional_atm xb = additional_atm C› xb] add that
    unfolding cl_of_def
    by auto
  (* ρ' coincides with the weight function used in
     partial_max_sat_is_weight_sat, i.e. the one with the count factor. *)
  have ρ': ‹ρ' = (λC. ∑L∈#C. if L ∈ Pos ` additional_atm ` set_mset N⇩S
                 then count N⇩S
                       (SOME C. L = Pos (additional_atm C) ∧ C ∈# N⇩S) *
                      ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# N⇩S)
                 else 0)›
    unfolding cl_of_def[symmetric] ρ'_def
    using assms(2,4) by (auto intro!: ext sum_mset_cong simp: ρ'_def not_in_iff dest!: multi_member_split)
  (* The general lemma now applies; its three premises are discharged by
     assms(3), assms(4) and, after rewriting with ρ', assms(5). *)
  show ?thesis
    apply (rule partial_max_sat_is_weight_sat[where additional_atm=additional_atm])
    subgoal by (rule assms(3))
    subgoal by (rule assms(4))
    subgoal unfolding ρ'[symmetric] by (rule assms(5))
    done
qed
(* Characterisation of atms_exactly_m for an interpretation given as a multiset
   of literals y: the atoms of y are atoms of N, and the literal set of y is
   total over the clauses of N. *)
lemma atms_exactly_m_alt_def:
  ‹atms_exactly_m (set_mset y) N ⟷ atms_of y ⊆ atms_of_mm N ∧
        total_over_m (set_mset y) (set_mset N)›
  by (auto simp: atms_exactly_m_def atms_of_s_def atms_of_def
      atms_of_ms_def dest!: multi_member_split)
(* For an interpretation given as a multiset of literals y, atms_exactly_m
   holds iff the atoms of y are exactly the atoms of N. *)
lemma atms_exactly_m_alt_def2:
  ‹atms_exactly_m (set_mset y) N ⟷ atms_of y = atms_of_mm N›
  by (metis atms_of_def atms_of_s_def atms_exactly_m_alt_def equalityI order_refl total_over_m_def
      total_over_set_alt_def)
(* Link between the calculus and the specification: if T is the result of a
   full run of cdcl_bnb_stgy from the initial state of N, and N satisfies
   distinct_mset_mset, then the weight component of T is a weight_sat answer
   for N and ρ.
   The proof relies on
   full_cdcl_bnb_stgy_no_conflicting_clause_from_init_state and distinguishes
   whether weight T is None or not. *)
lemma (in conflict_driven_clause_learning⇩W_optimal_weight) full_cdcl_bnb_stgy_weight_sat:
  ‹full cdcl_bnb_stgy (init_state N) T ⟹ distinct_mset_mset N ⟹ weight_sat N ρ (weight T)›
  using full_cdcl_bnb_stgy_no_conflicting_clause_from_init_state[of N T]
  apply (cases ‹weight T = None›)
  (* weight T = None: second weight_sat rule. *)
  subgoal
    by (auto intro!: weight_sat.intros(2))
  (* weight T is Some: first weight_sat rule, one subgoal per premise. *)
  subgoal premises p
    using p(1-4,6)
    apply (clarsimp simp only:)
    apply (rule weight_sat.intros(1))
    subgoal by auto
    subgoal by (auto simp: atms_exactly_m_alt_def)
    subgoal by auto
    subgoal by auto
    (* Optimality premise, obtained from p(5) instantiated with I'. *)
    subgoal for J I'
      using p(5)[of I'] by (auto simp: atms_exactly_m_alt_def2)
    done
  done
end