(* Theory CDCL_W_MaxSAT: partial MaxSAT, its encoding as a weighted satisfiability
   problem (weight_sat), and the connection to cdcl_bnb_stgy from CDCL_W_Optimal_Model.
   The theory header must occur exactly once. *)

theory CDCL_W_MaxSAT
  imports CDCL_W_Optimal_Model
begin


subsection ‹Partial MAX-SAT›

(* Total weight of the soft clauses satisfied by the interpretation I: the sum of
   ρ C over the clauses C of the multiset NS with I ⊨ C. Since NS is a multiset, a
   clause occurring several times contributes ρ C once per occurrence.
   No type signature is given; the types are inferred from the right-hand side. *)
definition weight_on_clauses where
  ‹weight_on_clauses NS ρ I = (∑C ∈# (filter_mset (λC. I ⊨ C) NS). ρ C)›

(* atms_exactly_m I N: the interpretation I is total over the clauses of N
   (total_over_m) and mentions no atom outside of N (atms_of_s I ⊆ atms_of_mm N),
   i.e. I is defined on exactly the atoms of N. *)
definition atms_exactly_m :: ‹'v partial_interp ⇒ 'v clauses ⇒ bool› where
  ‹atms_exactly_m I N ⟷
  total_over_m I (set_mset N) ∧
  atms_of_s I ⊆ atms_of_mm N›

text ‹The word partial in the name refers to the fact that not all clauses are soft clauses:
  the hard clauses must be satisfied, and only the soft clauses are weighted. It does not refer
  to the fact that we consider partial models.›
(* partial_max_sat NH NS ρ (Some I): I is a consistent model of the hard clauses NH,
   defined on exactly the atoms of NH + NS, such that the weight of the satisfied soft
   clauses (weight_on_clauses NS ρ) is maximal among all such models.
   partial_max_sat NH NS ρ None: the hard clauses NH are unsatisfiable. *)
inductive partial_max_sat :: ‹'v clauses ⇒ 'v clauses ⇒ ('v clause ⇒ nat) ⇒
  'v partial_interp option ⇒ bool› where
  partial_max_sat:
  ‹partial_max_sat NH NS ρ (Some I)›
if
  ‹I ⊨sm NH› and
  ‹atms_exactly_m I ((NH + NS))› and
  ‹consistent_interp I› and
  ‹⋀I'. consistent_interp I' ⟹ atms_exactly_m I' (NH + NS) ⟹ I' ⊨sm NH ⟹
      weight_on_clauses NS ρ I' ≤ weight_on_clauses NS ρ I› |
  partial_max_unsat:
  ‹partial_max_sat NH NS ρ None›
if
  ‹unsatisfiable (set_mset NH)›

(* partial_min_sat NH NS ρ (Some I): I is a consistent model of the hard clauses NH,
   defined on exactly the atoms of NH + NS, such that the weight of the satisfied soft
   clauses (weight_on_clauses NS ρ) is minimal among all such models.
   partial_min_sat NH NS ρ None: the hard clauses NH are unsatisfiable. *)
inductive partial_min_sat :: ‹'v clauses ⇒ 'v clauses ⇒ ('v clause ⇒ nat) ⇒
  'v partial_interp option ⇒ bool› where
  partial_min_sat:
  ‹partial_min_sat NH NS ρ (Some I)›
if
  ‹I ⊨sm NH› and
  ‹atms_exactly_m I (NH + NS)› and
  ‹consistent_interp I› and
  ‹⋀I'. consistent_interp I' ⟹ atms_exactly_m I' (NH + NS) ⟹ I' ⊨sm NH ⟹
      weight_on_clauses NS ρ I' ≥ weight_on_clauses NS ρ I› |
  partial_min_unsat:
  ‹partial_min_sat NH NS ρ None›
if
  ‹unsatisfiable (set_mset NH)›

(* An interpretation defined on exactly the atoms of N is finite: it is contained in
   Pos ` atms_of_mm N ∪ Neg ` atms_of_mm N, and finite_subset concludes. *)
lemma atms_exactly_m_finite:
  assumes ‹atms_exactly_m I N›
  shows ‹finite I›
proof -
  have ‹I ⊆ Pos ` (atms_of_mm N) ∪ Neg ` atms_of_mm N›
    using assms by (force simp: total_over_m_def  atms_exactly_m_def lit_in_set_iff_atm
        atms_of_s_def)
  from finite_subset[OF this] show ?thesis by auto
qed


(* If the hard clauses NH are satisfiable, both partial_max_sat and partial_min_sat have
   a solution of the form Some I. The candidate models ?Is (consistent models of NH
   defined on exactly the atoms of NH + NS) form a finite, non-empty set, so the minimum
   and the maximum of weight_on_clauses NS ρ over ?Is are attained. *)
lemma
  fixes NH :: ‹'v clauses›
  assumes ‹satisfiable (set_mset NH)›
  shows sat_partial_max_sat: ‹∃I. partial_max_sat NH NS ρ (Some I)› and
    sat_partial_min_sat: ‹∃I. partial_min_sat NH NS ρ (Some I)›
proof -
  (* Candidate models; ?Is' is the same set with finiteness made explicit. *)
  let ?Is = ‹{I. atms_exactly_m I ((NH + NS)) ∧  consistent_interp I ∧
     I ⊨sm NH}›
  let ?Is'= ‹{I. atms_exactly_m I ((NH + NS)) ∧ consistent_interp I ∧
    I ⊨sm NH ∧ finite I}›
  have Is: ‹?Is = ?Is'›
    by (auto simp: atms_of_s_def atms_exactly_m_finite)
  (* Finiteness: every candidate is the set of some simple clause over the atoms of NH + NS. *)
  have ‹?Is' ⊆ set_mset ` simple_clss (atms_of_mm (NH + NS))›
    apply rule
    unfolding image_iff
    by (rule_tac x= ‹mset_set x› in bexI)
      (auto simp: simple_clss_def atms_exactly_m_def image_iff
        atms_of_s_def atms_of_def distinct_mset_mset_set consistent_interp_tuatology_mset_set)
  from finite_subset[OF this] have fin: ‹finite ?Is› unfolding Is
    by (auto simp: simple_clss_finite)
  then have fin': ‹finite (weight_on_clauses NS ρ ` ?Is)›
    by auto
  (* Minimisation: ρI is the least weight reached by a candidate. *)
  define ρI where
    ‹ρI = Min (weight_on_clauses NS ρ ` ?Is)›
  (* Non-emptiness: extend a model of NH by setting every atom of NS that it does not
     assign to true. *)
  have nempty: ‹?Is ≠ {}›
  proof -
    obtain I where I:
      ‹total_over_m I (set_mset NH)›
      ‹I ⊨sm NH›
      ‹consistent_interp I›
      ‹atms_of_s I ⊆ atms_of_mm NH›
      using assms unfolding satisfiable_def_min atms_exactly_m_def
      by (auto simp: atms_of_s_def atm_of_def total_over_m_def)
    let ?I = ‹I ∪ Pos ` {x ∈ atms_of_mm NS. x ∉ atm_of ` I}›
    have ‹?I ∈ ?Is›
      using I
      by (auto simp: atms_exactly_m_def total_over_m_alt_def image_iff
          lit_in_set_iff_atm)
        (auto simp: consistent_interp_def uminus_lit_swap)
    then show ?thesis
      by blast
  qed
  have ‹ρI ∈ weight_on_clauses NS ρ ` ?Is›
    unfolding ρI_def
    by (rule Min_in[OF fin']) (use nempty in auto)
  then obtain I :: ‹'v partial_interp› where
    ‹weight_on_clauses NS ρ I = ρI› and
    ‹I ∈ ?Is›
    by blast
  then have H: ‹consistent_interp I' ⟹ atms_exactly_m I' (NH + NS) ⟹ I' ⊨sm NH ⟹
      weight_on_clauses NS ρ I' ≥ weight_on_clauses NS ρ I› for I'
    using Min_le[OF fin', of ‹weight_on_clauses NS ρ I'›]
    unfolding ρI_def[symmetric]
    by auto
  then have ‹partial_min_sat NH NS ρ (Some I)›
    apply -
    by (rule partial_min_sat)
      (use fin ‹I ∈ ?Is› in ‹auto simp: atms_exactly_m_finite›)
  then show ‹∃I. partial_min_sat NH NS ρ (Some I)›
    by fast

  (* Maximisation: same argument with Max. The premise on NH is stated with ⊨sm, exactly
     as in the partial_max_sat rule and in ?Is. *)
  define ρI where
    ‹ρI = Max (weight_on_clauses NS ρ ` ?Is)›
  have ‹ρI ∈ weight_on_clauses NS ρ ` ?Is›
    unfolding ρI_def
    by (rule Max_in[OF fin']) (use nempty in auto)
  then obtain I :: ‹'v partial_interp› where
    ‹weight_on_clauses NS ρ I = ρI› and
    ‹I ∈ ?Is›
    by blast
  then have H: ‹consistent_interp I' ⟹ atms_exactly_m I' (NH + NS) ⟹ I' ⊨sm NH ⟹
      weight_on_clauses NS ρ I' ≤ weight_on_clauses NS ρ I› for I'
    using Max_ge[OF fin', of ‹weight_on_clauses NS ρ I'›]
    unfolding ρI_def[symmetric]
    by auto
  then have ‹partial_max_sat NH NS ρ (Some I)›
    apply -
    by (rule partial_max_sat)
      (use fin ‹I ∈ ?Is› in ‹auto simp: atms_exactly_m_finite
        consistent_interp_tuatology_mset_set›)
  then show ‹∃I. partial_max_sat NH NS ρ (Some I)›
    by fast
qed

(* weight_sat N ρ (Some I): I, a multiset of literals without duplicates, is a consistent
   model of N defined on exactly the atoms of N, and every other such model I' satisfies
   ρ I' ≥ ρ I, i.e. I minimises ρ.
   weight_sat N ρ None: N is unsatisfiable.
   The second rule is named partial_max_unsat (qualified as weight_sat.partial_max_unsat),
   like the rule of partial_max_sat; full_cdcl_bnb_stgy_weight_sat refers to it by index
   (weight_sat.intros(2)). *)
inductive weight_sat
  :: ‹'v clauses ⇒ ('v literal multiset ⇒ 'a :: linorder) ⇒
    'v literal multiset option ⇒ bool›
where
  weight_sat:
  ‹weight_sat N ρ (Some I)›
if
  ‹set_mset I ⊨sm N› and
  ‹atms_exactly_m (set_mset I) N› and
  ‹consistent_interp (set_mset I)› and
  ‹distinct_mset I› and
  ‹⋀I'. consistent_interp (set_mset I') ⟹ atms_exactly_m (set_mset I') N ⟹ distinct_mset I' ⟹
      set_mset I' ⊨sm N ⟹ ρ I' ≥ ρ I› |
  partial_max_unsat:
  ‹weight_sat N ρ None›
if
  ‹unsatisfiable (set_mset N)›

(* Reduction of partial MaxSAT to weight_sat.
   Every soft clause C ∈# NS is extended with the positive literal Pos (additional_atm C),
   and the extended clauses are added to the hard clauses NH. The weight ρ' of a multiset
   interpretation sums, over its literals L, count NS D * ρ D for the soft clause D whose
   atom additional_atm D gives L = Pos (additional_atm D), and 0 for all other literals.
   If I is a weight_sat solution (a ρ'-minimal model) of the extended problem, then the
   restriction of I to the atoms of NH + NS is a partial_max_sat solution.
   Assumptions (both named add): the new atoms do not occur in NH + NS, and
   additional_atm is injective on NS. *)
lemma partial_max_sat_is_weight_sat:
  fixes additional_atm :: ‹'v clause ⇒ 'v› and
    ρ :: ‹'v clause ⇒ nat› and
    NS :: ‹'v clauses›
  defines
    ‹ρ' ≡ (λC. sum_mset
       ((λL. if L ∈ Pos ` additional_atm ` set_mset NS
         then count NS (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)
           * ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)
         else 0) `# C))›
  assumes
    add: ‹⋀C. C ∈# NS ⟹ additional_atm C ∉ atms_of_mm (NH + NS)›
    ‹⋀C D. C ∈# NS ⟹ D ∈# NS ⟹ additional_atm C = additional_atm D ⟷ C = D› and
    w: ‹weight_sat (NH + (λC. add_mset (Pos (additional_atm C)) C) `# NS) ρ' (Some I)›
  shows
    ‹partial_max_sat NH NS ρ (Some {L ∈ set_mset I. atm_of L ∈ atms_of_mm (NH + NS)})›
proof -
  define N where ‹N ≡ NH + (λC. add_mset (Pos (additional_atm C)) C) `# NS›
  define cl_of where ‹cl_of L = (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)› for L
  (* Unpack the weight_sat solution I of the extended problem N. *)
  from w
  have
    ent: ‹set_mset I ⊨sm N› and
    bi: ‹atms_exactly_m (set_mset I) N› and
    cons: ‹consistent_interp (set_mset I)› and
    dist: ‹distinct_mset I› and
    weight: ‹⋀I'. consistent_interp (set_mset I') ⟹ atms_exactly_m (set_mset I') N ⟹
      distinct_mset I' ⟹ set_mset I' ⊨sm N ⟹ ρ' I' ≥ ρ' I›
    unfolding N_def[symmetric]
    by (auto simp: weight_sat.simps)
  (* The restriction of I to the original atoms satisfies NH, is defined on exactly the
     atoms of NH + NS, and is consistent. *)
  let ?I = ‹{L. L ∈# I ∧ atm_of L ∈ atms_of_mm (NH + NS)}›
  have ent': ‹set_mset I ⊨sm NH›
    using ent unfolding true_clss_restrict
    by (auto simp: N_def)
  then have ent': ‹?I ⊨sm NH›
    apply (subst (asm) true_clss_restrict[symmetric])
    apply (rule true_clss_mono_left, assumption)
    apply auto
    done
  have [simp]: ‹atms_of_ms ((λC. add_mset (Pos (additional_atm C)) C) ` set_mset NS) =
    additional_atm ` set_mset NS ∪ atms_of_ms (set_mset NS)›
    by (auto simp: atms_of_ms_def)
  have bi': ‹atms_exactly_m ?I (NH + NS)›
    using bi
    by (auto simp: atms_exactly_m_def total_over_m_def total_over_set_def
        atms_of_s_def N_def)
  have cons': ‹consistent_interp ?I›
    using cons by (auto simp: consistent_interp_def)
  (* cl_of inverts additional_atm on NS. *)
  have [simp]: ‹cl_of (Pos (additional_atm xb)) = xb›
    if ‹xb ∈# NS› for xb
    using someI[of ‹λC. additional_atm xb = additional_atm C› xb] add that
    unfolding cl_of_def
    by auto

  (* Normalised version of I: keep its literals over the original atoms and set the atom
     additional_atm C to true iff the soft clause C is not satisfied by I. It is again a
     candidate for weight_sat, and its weight ρ' is equal to ρ' I (I_I). *)
  let ?I = ‹{L. L ∈# I ∧ atm_of L ∈ atms_of_mm (NH + NS)} ∪ Pos ` additional_atm ` {C ∈ set_mset NS. ¬set_mset I ⊨ C}
    ∪ Neg ` additional_atm ` {C ∈ set_mset NS. set_mset I ⊨ C}›
  have ‹consistent_interp ?I›
    using cons add by (auto simp: consistent_interp_def
        atms_exactly_m_def uminus_lit_swap
        dest: add)
  moreover have ‹atms_exactly_m ?I N›
    using bi
    by (auto simp: N_def atms_exactly_m_def total_over_m_def
        total_over_set_def image_image)
  moreover have ‹?I ⊨sm N›
    using ent by (auto simp: N_def true_clss_def image_image
          atm_of_lit_in_atms_of true_cls_def
        dest!: multi_member_split)
  moreover have ‹set_mset (mset_set ?I) = ?I› and fin: ‹finite ?I›
    by (auto simp: atms_exactly_m_finite)
  moreover have ‹distinct_mset (mset_set ?I)›
    by (auto simp: distinct_mset_mset_set)
  ultimately have ‹ρ' (mset_set ?I) ≥ ρ' I›
    using weight[of ‹mset_set ?I›]
    by argo
  moreover have ‹ρ' (mset_set ?I) ≤ ρ' I›
    using ent
    by (auto simp: ρ'_def sum_mset_inter_restrict[symmetric] mset_set_subset_iff N_def
        intro!: sum_image_mset_mono
        dest!: multi_member_split)
  ultimately have I_I: ‹ρ' (mset_set ?I) = ρ' I›
    by linarith

  (* Optimality: any other partial_max_sat candidate I' is extended in the same way to a
     model ?I' of N; minimality of I w.r.t. ρ' bounds the weight of the soft clauses that
     I' falsifies, hence the weight of the soft clauses it satisfies. *)
  have min: ‹weight_on_clauses NS ρ I'
      ≤ weight_on_clauses NS ρ {L. L ∈# I ∧ atm_of L ∈ atms_of_mm (NH + NS)}›
    if
      cons: ‹consistent_interp I'› and
      bit: ‹atms_exactly_m I' (NH + NS)› and
      I': ‹I' ⊨sm NH›
    for I'
  proof -
    let ?I' = ‹I' ∪ Pos ` additional_atm ` {C ∈ set_mset NS. ¬I' ⊨ C}
      ∪ Neg ` additional_atm ` {C ∈ set_mset NS. I' ⊨ C}›
    have ‹consistent_interp ?I'›
      using cons bit add by (auto simp: consistent_interp_def
          atms_exactly_m_def uminus_lit_swap
          dest: add)
    moreover have ‹atms_exactly_m ?I' N›
      using bit
      by (auto simp: N_def atms_exactly_m_def total_over_m_def
          total_over_set_def image_image)
    moreover have ‹?I' ⊨sm N›
      using I' by (auto simp: N_def true_clss_def image_image
          dest!: multi_member_split)
    moreover have ‹set_mset (mset_set ?I') = ?I'› and fin: ‹finite ?I'›
      using bit by (auto simp: atms_exactly_m_finite)
    moreover have ‹distinct_mset (mset_set ?I')›
      by (auto simp: distinct_mset_mset_set)
    ultimately have I'_I: ‹ρ' (mset_set ?I') ≥ ρ' I›
      using weight[of ‹mset_set ?I'›]
      by argo
    have inj: ‹inj_on cl_of (I' ∩ (λx. Pos (additional_atm x)) ` set_mset NS)› for I'
      using add by (auto simp: inj_on_def)

    (* weight of satisfied soft clauses = total weight - weight of falsified ones *)
    have we: ‹weight_on_clauses NS ρ I' = sum_mset (ρ `# NS) -
      sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') NS)› for I'
      unfolding weight_on_clauses_def
      apply (subst (3) multiset_partition[of _ ‹(⊨) I'›])
      unfolding image_mset_union sum_mset.union
      by (auto simp: comp_def)
    (* ρ' I is the weight of the soft clauses falsified by the restriction of I *)
    have H: ‹sum_mset
       (ρ `#
        filter_mset (Not ∘ (⊨) {L. L ∈# I ∧ atm_of L ∈ atms_of_mm (NH + NS)})
         NS) = ρ' I›
            unfolding I_I[symmetric] unfolding ρ'_def cl_of_def[symmetric]
              sum_mset_sum_count if_distrib
            apply (auto simp: sum_mset_sum_count image_image simp flip: sum.inter_restrict
                cong: if_cong)
            apply (subst comm_monoid_add_class.sum.reindex_cong[symmetric, of cl_of, OF _ refl])
            apply ((use inj in auto; fail)+)[2]
            apply (rule sum.cong)
            apply auto[]
            using inj[of ‹set_mset I›] ‹set_mset I ⊨sm N› assms(2)
            apply (auto dest!: multi_member_split simp: N_def image_Int
                atm_of_lit_in_atms_of true_cls_def)[]
            using add apply (auto simp: true_cls_def)
            done
    have ‹(∑x∈(I' ∪ (λx. Pos (additional_atm x)) ` {C. C ∈# NS ∧ ¬ I' ⊨ C} ∪
         (λx. Neg (additional_atm x)) ` {C. C ∈# NS ∧ I' ⊨ C}) ∩
        (λx. Pos (additional_atm x)) ` set_mset NS.
       count NS (cl_of x) * ρ (cl_of x))
    ≤ (∑A∈{a. a ∈# NS ∧ ¬ I' ⊨ a}. count NS A * ρ A)›
      apply (subst comm_monoid_add_class.sum.reindex_cong[symmetric, of cl_of, OF _ refl])
      apply ((use inj in auto; fail)+)[2]
      apply (rule ordered_comm_monoid_add_class.sum_mono2)
      using that add by (auto dest:  simp: N_def
          atms_exactly_m_def)
    then have ‹sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') NS) ≥ ρ' (mset_set ?I')›
      using fin unfolding cl_of_def[symmetric] ρ'_def
      by (auto simp: ρ'_def
          simp add: sum_mset_sum_count image_image simp flip: sum.inter_restrict)
    then have ‹ρ' I ≤ sum_mset (ρ `# filter_mset (Not ∘ (⊨) I') NS)›
      using I'_I by auto
    then show ?thesis
      unfolding we H I_I apply -
      by auto
  qed

  show ?thesis
    apply (rule partial_max_sat.intros)
    subgoal using ent' by auto
    subgoal using bi' by fast
    subgoal using cons' by fast
    subgoal for I'
      by (rule min)
    done
qed

(* Congruence for sums over a multiset: if f and g agree on every element of A, then the
   sums of f and of g over A are equal. Proved by induction on A. *)
lemma sum_mset_cong:
  ‹(⋀a. a ∈# A ⟹ f a = g a) ⟹ (∑a∈#A. f a) = (∑a∈#A. g a)›
  by (induction A) auto

(* Variant of partial_max_sat_is_weight_sat for soft clauses without duplicates
   (distinct_mset NS). The weight of a new literal is then ρ of its soft clause, without
   the count NS factor. The proof shows that this ρ' coincides with the weight function of
   partial_max_sat_is_weight_sat and then applies that lemma. *)
lemma partial_max_sat_is_weight_sat_distinct:
  fixes additional_atm :: ‹'v clause ⇒ 'v› and
    ρ :: ‹'v clause ⇒ nat› and
    NS :: ‹'v clauses›
  defines
    ‹ρ' ≡ (λC. sum_mset
       ((λL. if L ∈ Pos ` additional_atm ` set_mset NS
         then ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)
         else 0) `# C))›
  assumes
    ‹distinct_mset NS› and ―‹This is implicit on paper›
    add: ‹⋀C. C ∈# NS ⟹ additional_atm C ∉ atms_of_mm (NH + NS)›
    ‹⋀C D. C ∈# NS ⟹ D ∈# NS ⟹ additional_atm C = additional_atm D ⟷ C = D› and
    w: ‹weight_sat (NH + (λC. add_mset (Pos (additional_atm C)) C) `# NS) ρ' (Some I)›
  shows
    ‹partial_max_sat NH NS ρ (Some {L ∈ set_mset I. atm_of L ∈ atms_of_mm (NH + NS)})›
proof -
  define cl_of where ‹cl_of L = (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)› for L
  (* cl_of inverts additional_atm on NS. *)
  have [simp]: ‹cl_of (Pos (additional_atm xb)) = xb›
    if ‹xb ∈# NS› for xb
    using someI[of ‹λC. additional_atm xb = additional_atm C› xb] add that
    unfolding cl_of_def
    by auto
  (* With distinct NS every count is 1, so ρ' equals the weight function used in
     partial_max_sat_is_weight_sat. *)
  have ρ': ‹ρ' = (λC. ∑L∈#C. if L ∈ Pos ` additional_atm ` set_mset NS
                 then count NS
                       (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS) *
                      ρ (SOME C. L = Pos (additional_atm C) ∧ C ∈# NS)
                 else 0)›
    unfolding cl_of_def[symmetric] ρ'_def
    using assms(2,4) by (auto intro!: ext sum_mset_cong simp: ρ'_def not_in_iff dest!: multi_member_split)
  show ?thesis
    apply (rule partial_max_sat_is_weight_sat[where additional_atm=additional_atm])
    subgoal by (rule assms(3))
    subgoal by (rule assms(4))
    subgoal unfolding ρ'[symmetric] by (rule assms(5))
    done
qed

(* atms_exactly_m for an interpretation given as a multiset y of literals, restated with
   atms_of y in place of atms_of_s (set_mset y). *)
lemma atms_exactly_m_alt_def:
  ‹atms_exactly_m (set_mset y) N ⟷ atms_of y ⊆ atms_of_mm N ∧
        total_over_m (set_mset y) (set_mset N)›
  by (auto simp: atms_exactly_m_def atms_of_s_def atms_of_def
      atms_of_ms_def dest!: multi_member_split)

(* For an interpretation given as a multiset y of literals, atms_exactly_m holds iff the
   atoms of y are exactly the atoms of N. *)
lemma atms_exactly_m_alt_def2:
  ‹atms_exactly_m (set_mset y) N ⟷ atms_of y = atms_of_mm N›
  by (metis atms_of_def atms_of_s_def atms_exactly_m_alt_def equalityI order_refl total_over_m_def
      total_over_set_alt_def)

(* Inside the locale conflict_driven_clause_learningW_optimal_weight: if cdcl_bnb_stgy runs
   to a final state T (full) from init_state N and distinct_mset_mset N holds (presumably:
   no clause of N contains a duplicate literal — TODO confirm), then weight T is a
   weight_sat solution of N for ρ (presumably the locale's weight function).
   The proof distinguishes weight T = None (second weight_sat rule) from weight T = Some _
   (first rule), using full_cdcl_bnb_stgy_no_conflicting_clause_from_init_state. *)
lemma (in conflict_driven_clause_learningW_optimal_weight) full_cdcl_bnb_stgy_weight_sat:
  ‹full cdcl_bnb_stgy (init_state N) T ⟹ distinct_mset_mset N ⟹ weight_sat N ρ (weight T)›
  using full_cdcl_bnb_stgy_no_conflicting_clause_from_init_state[of N T]
  apply (cases ‹weight T = None›)
  subgoal
    by (auto intro!: weight_sat.intros(2))
  subgoal premises p
    using p(1-4,6)
    apply (clarsimp simp only:)
    apply (rule weight_sat.intros(1))
    subgoal by auto
    subgoal by (auto simp: atms_exactly_m_alt_def)
    subgoal by auto
    subgoal by auto
    subgoal for J I'
      using p(5)[of I'] by (auto simp: atms_exactly_m_alt_def2)
    done
  done

end