Theory CDCL_W_Abstract_State

theory CDCL_W_Abstract_State
imports CDCL_W_Restart
theory CDCL_W_Abstract_State
imports CDCL_W_Full CDCL_W_Restart

begin

section ‹Instantiation of Weidenbach's CDCL by Multisets›

text ‹We first instantiate Weidenbach's locale with multisets. Then we refine it to a 2-WL program.›

(* Concrete CDCL_W state as a plain tuple: the trail, the initial clauses, the learned clauses,
   and the optional conflicting clause (in that order, see the accessor functions). *)
type_synonym 'v cdclW_restart_mset =
  ‹('v, 'v clause) ann_lit list × 'v clauses × 'v clauses × 'v clause option›

text ‹We use actual (function) definitions for the state accessors; otherwise, we could not use
  the simplification theorems we have already shown.›
(* Projection onto the first component of the state: the trail. *)
fun trail :: ‹'v cdclW_restart_mset ⇒ ('v, 'v clause) ann_lit list›
where
  ‹trail (M, _) = M›

(* Projection onto the second component of the state: the initial clauses. *)
fun init_clss :: ‹'v cdclW_restart_mset ⇒ 'v clauses›
where
  ‹init_clss (_, N, _) = N›

(* Projection onto the third component of the state: the learned clauses. *)
fun learned_clss :: ‹'v cdclW_restart_mset ⇒ 'v clauses›
where
  ‹learned_clss (_, _, U, _) = U›

(* Projection onto the last component of the state: the optional conflicting clause. *)
fun conflicting :: ‹'v cdclW_restart_mset ⇒ 'v clause option›
where
  ‹conflicting (_, _, _, C) = C›

(* Pushes an annotated literal onto the front of the trail; everything else is kept. *)
fun cons_trail ::
  ‹('v, 'v clause) ann_lit ⇒ 'v cdclW_restart_mset ⇒ 'v cdclW_restart_mset›
where
  ‹cons_trail L (M, R) = (L # M, R)›

(* Removes the first element of the trail (via tl); the remaining components are unchanged.
   The type is left to inference, exactly as in the original definition. *)
fun tl_trail
where
  ‹tl_trail (M, R) = (tl M, R)›

(* Adds clause C to the learned clauses (third component). *)
fun add_learned_cls
where
  ‹add_learned_cls C (M, N, U, R) = (M, N, {#C#} + U, R)›

(* Removes every copy of clause C (removeAll_mset) from both the initial and the learned clauses. *)
fun remove_cls
where
  ‹remove_cls C (M, N, U, R) = (M, removeAll_mset C N, removeAll_mset C U, R)›

(* Overwrites the conflicting component with D, discarding its previous value. *)
fun update_conflicting
where
  ‹update_conflicting D (M, N, U,  _) = (M, N, U, D)›

(* Initial state for clause set N: empty trail, no learned clauses, and no conflict. *)
fun init_state
where
  ‹init_state N = ([], N, {#}, None)›

(* The tuple equations of the accessors and updaters are removed from the default simpset;
   proofs that need them use the explicit collection cdclW_restart_mset_state. *)
declare
  trail.simps[simp del]
  cons_trail.simps[simp del]
  tl_trail.simps[simp del]
  add_learned_cls.simps[simp del]
  remove_cls.simps[simp del]
  update_conflicting.simps[simp del]
  init_clss.simps[simp del]
  learned_clss.simps[simp del]
  conflicting.simps[simp del]
  init_state.simps[simp del]

(* All tuple equations of the state functions, bundled for explicit unfolding.
   The order of the facts is kept unchanged so that indexed references stay valid. *)
lemmas cdclW_restart_mset_state =
  trail.simps
  cons_trail.simps
  tl_trail.simps
  add_learned_cls.simps
  remove_cls.simps
  update_conflicting.simps
  init_clss.simps
  learned_clss.simps
  conflicting.simps
  init_state.simps

(* Packs the four observable components, followed by the unit "extension" component,
   into the shape expected by the state parameter of the CDCL_W locales. *)
definition state ::
  ‹'v cdclW_restart_mset ⇒
    ('v, 'v clause) ann_lit list × 'v clauses × 'v clauses × 'v clause option × unit›
where
  ‹state S = (trail S, init_clss S, learned_clss S, conflicting S, ())›

(* Interprets the operations-only state locale with the tuple accessors and updaters;
   its obligations are closed by the immediate proof ".". *)
interpretation cdclW_restart_mset: stateW_ops where
  state = state and
  trail = trail and
  init_clss = init_clss and
  learned_clss = learned_clss and
  conflicting = conflicting and
  cons_trail = cons_trail and
  tl_trail = tl_trail and
  add_learned_cls = add_learned_cls and
  remove_cls = remove_cls and
  update_conflicting = update_conflicting and
  init_state = init_state
  .

(* Two tuple states are related by ∼m exactly when their state projections coincide. *)
definition state_eq ::
  ‹'v cdclW_restart_mset ⇒ 'v cdclW_restart_mset ⇒ bool› (infix "∼m" 50)
where
  ‹S ∼m T ⟷ state S = state T›

(* Interprets the full state locale (with state equivalence state_eq) on tuple states.
   The locale assumptions are discharged by unfolding the tuple equations
   (cdclW_restart_mset_state) together with the definitions of state_eq and state. *)
interpretation cdclW_restart_mset: stateW where
  state = state and
  trail = trail and
  init_clss = init_clss and
  learned_clss = learned_clss and
  conflicting = conflicting and
  state_eq = state_eq and
  cons_trail = cons_trail and
  tl_trail = tl_trail and
  add_learned_cls = add_learned_cls and
  remove_cls = remove_cls and
  update_conflicting = update_conflicting and
  init_state = init_state
  by unfold_locales (auto simp: cdclW_restart_mset_state state_eq_def state_def)


(* Short, unqualified name for the backtrack level of the interpreted locale. *)
abbreviation backtrack_lvl :: ‹'v cdclW_restart_mset ⇒ nat›
where
  ‹backtrack_lvl ≡ cdclW_restart_mset.backtrack_lvl›

(* Interprets the CDCL_W calculus locale on tuple states; no proof obligations remain
   after unfold_locales. *)
interpretation cdclW_restart_mset: conflict_driven_clause_learningW where
  state = state and
  trail = trail and
  init_clss = init_clss and
  learned_clss = learned_clss and
  conflicting = conflicting and
  state_eq = state_eq and
  cons_trail = cons_trail and
  tl_trail = tl_trail and
  add_learned_cls = add_learned_cls and
  remove_cls = remove_cls and
  update_conflicting = update_conflicting and
  init_state = init_state
  by unfold_locales

(* On the tuple representation, the locale's state equivalence is plain equality. *)
lemma cdclW_restart_mset_state_eq_eq: "state_eq = (=)"
   (* state_eq_def is stated for applied arguments (S ∼m T), so apply extensionality first. *)
   apply (intro ext)
   unfolding state_eq_def
   by (auto simp: cdclW_restart_mset_state state_def)


(* The clause set of a tuple state is the sum of its initial and learned clauses.
   NOTE(review): this top-level fact shares its name with the locale's clauses_def; the proof
   below refers to the latter through the qualified name cdclW_restart_mset.clauses_def. *)
lemma clauses_def: ‹cdclW_restart_mset.clauses (M, N, U, C) = N + U›
  by (subst cdclW_restart_mset.clauses_def) (simp add: cdclW_restart_mset_state)

(* Closed form of reduce_trail_to on tuple states.
   - If the trail is at least as long as F, the result trail is obtained by dropping
     length (trail S) - length F elements from the front of the trail, i.e. the result
     has length F.
   - Otherwise the result trail is empty.
   - The initial clauses, learned clauses and conflicting clause are unchanged. *)
lemma cdclW_restart_mset_reduce_trail_to:
  "cdclW_restart_mset.reduce_trail_to F S =
    ((if length (trail S) ≥ length F
    then drop (length (trail S) - length F) (trail S)
    else []), init_clss S, learned_clss S, conflicting S)"
    (is "?S = _")
(* Induction following the recursion scheme of reduce_trail_to. *)
proof (induction F S rule: cdclW_restart_mset.reduce_trail_to.induct)
  case (1 F S) note IH = this
  show ?case
  proof (cases "trail S")
    case Nil
    then show ?thesis using IH by (cases S) (auto simp: cdclW_restart_mset_state)
  next
    case (Cons L M)
    (* Here trail S = L # M: split on whether the trail is strictly longer than F. *)
    then show ?thesis
      apply (cases "Suc (length M) > length F")
      subgoal
        apply (subgoal_tac "Suc (length M) - length F = Suc (length M - length F)")
        using cdclW_restart_mset.reduce_trail_to_length_ne[of S F] IH by auto
      subgoal
        using IH cdclW_restart_mset.reduce_trail_to_length_ne[of S F]
          apply (cases S)
        by (simp add: cdclW_restart_mset.trail_reduce_trail_to_drop cdclW_restart_mset_state)
      done
  qed
qed


(* The only full cdclW_stgy run from the initial state with no clauses is the trivial one.
   The proof unfolds full and the reflexive-transitive closure. It then unfolds the rule
   definitions and the tuple equations, so that auto can refute any first step. *)
lemma full_cdclW_init_state:
  ‹full cdclW_restart_mset.cdclW_stgy (init_state {#}) S ⟷ S = init_state {#}›
  unfolding full_def rtranclp_unfold
  by (subst tranclp_unfold_begin)
     (auto simp:  cdclW_restart_mset.cdclW_stgy.simps
      cdclW_restart_mset.conflict.simps cdclW_restart_mset.cdclW_o.simps
       cdclW_restart_mset.propagate.simps cdclW_restart_mset.decide.simps
       cdclW_restart_mset.cdclW_bj.simps cdclW_restart_mset.backtrack.simps
      cdclW_restart_mset.skip.simps cdclW_restart_mset.resolve.simps
      cdclW_restart_mset_state clauses_def)

(* Locale fixing a function f :: nat ⇒ nat. Inside it, the tuple states are interpreted as an
   instance of cdclW_restart_restart_ops, with f passed on as that locale's f parameter. *)
locale twl_restart_ops =
  fixes f :: ‹nat ⇒ nat›
begin

interpretation cdclW_restart_mset: cdclW_restart_restart_ops where
  state = state and
  trail = trail and
  init_clss = init_clss and
  learned_clss = learned_clss and
  conflicting = conflicting and
  state_eq = state_eq and
  cons_trail = cons_trail and
  tl_trail = tl_trail and
  add_learned_cls = add_learned_cls and
  remove_cls = remove_cls and
  update_conflicting = update_conflicting and
  init_state = init_state and
  f = f
  by unfold_locales

end

(* Extends twl_restart_ops with the assumption that f is unbounded. This assumption is exactly
   what discharges the interpretation of cdclW_restart_restart below. *)
locale twl_restart =
  twl_restart_ops f
  for f :: ‹nat ⇒ nat› +
  assumes f: ‹unbounded f›
begin

interpretation cdclW_restart_mset: cdclW_restart_restart where
  state = state and
  trail = trail and
  init_clss = init_clss and
  learned_clss = learned_clss and
  conflicting = conflicting and
  state_eq = state_eq and
  cons_trail = cons_trail and
  tl_trail = tl_trail and
  add_learned_cls = add_learned_cls and
  remove_cls = remove_cls and
  update_conflicting = update_conflicting and
  init_state = init_state and
  f = f
  by unfold_locales (rule f)

end

context conflict_driven_clause_learningW
begin

(* Unfolded characterisation of distinct_cdclW_state in terms of the state accessors. The state
   satisfies it iff all three of the following hold:
   - the conflicting clause, if any, is distinct;
   - every clause is distinct;
   - every propagation reason on the trail is distinct.
   Here clauses_def is the locale's own definition, since we are inside the locale context. *)
lemma distinct_cdclW_state_alt_def:
  ‹distinct_cdclW_state S =
    ((∀T. conflicting S = Some T ⟶ distinct_mset T) ∧
     distinct_mset_mset (clauses S) ∧
     (∀L mark. Propagated L mark ∈ set (trail S) ⟶ distinct_mset mark))›
  unfolding distinct_cdclW_state_def clauses_def
  by auto
end


(* No single cdclW_stgy step is possible from the initial state with no clauses.
   The proof unfolds every rule definition together with the tuple equations and lets auto
   refute the side conditions. The former "unfolding rtranclp_unfold" step was dropped:
   the goal contains no reflexive-transitive closure, so that step rewrote nothing. *)
lemma cdclW_stgy_cdclW_init_state_empty_no_step:
  ‹cdclW_restart_mset.cdclW_stgy (init_state {#}) S ⟷ False›
  by (auto simp: cdclW_restart_mset.cdclW_stgy.simps
      cdclW_restart_mset.conflict.simps cdclW_restart_mset.cdclW_o.simps
      cdclW_restart_mset.propagate.simps cdclW_restart_mset.decide.simps
      cdclW_restart_mset.cdclW_bj.simps cdclW_restart_mset.backtrack.simps
      cdclW_restart_mset.skip.simps cdclW_restart_mset.resolve.simps
      cdclW_restart_mset_state clauses_def)

(* Every cdclW_stgy** run from the initial state with no clauses stays at that state.
   The proof unfolds ** into "equal or ++", then splits off the first step of ++.
   That first step is impossible by cdclW_stgy_cdclW_init_state_empty_no_step.
   init_state.simps is already removed from the simpset by the global [simp del] declaration.
   Hence init_state {#} is not rewritten away, and the no-step lemma matches without an extra
   "simp del: init_state.simps". That clause was redundant and has been dropped. *)
lemma cdclW_stgy_cdclW_init_state:
  ‹cdclW_restart_mset.cdclW_stgy** (init_state {#}) S ⟷ S = init_state {#}›
  unfolding rtranclp_unfold
  by (subst tranclp_unfold_begin)
     (auto simp: cdclW_stgy_cdclW_init_state_empty_no_step)

end