theory IsaSAT_Rephase
  imports IsaSAT_Setup IsaSAT_Show
begin

chapter ‹Rephasing›

text ‹
  We implement CaDiCaL's idea of rephasing:
  ▪ We remember the best model found so far and use it as a base.
  ▪ We cycle the saved phases between \<^term>‹False›, the best model,
   \<^term>‹True›, random phases (this step is skipped in some restart modes), and the best model again.
›

(* Overwrite every saved phase with the constant polarity b.
   The loop runs over all indices 0..<length φ and only updates entries,
   so the length of φ is unchanged. *)
definition rephase_init :: ‹bool ⇒ bool list ⇒ bool list nres› where
‹rephase_init b φ = do {
  let n = length φ;
  nfoldli [0..<n]
    (λ_. True)
    (λ a φ. do {
       ASSERT(a < length φ);
       RETURN (φ[a := b])
   })
   φ
 }›

(* rephase_init preserves the length of the phase list.
   Proved with the loop invariant that the current list keeps length φ. *)
lemma rephase_init_spec:
  ‹rephase_init b φ ≤ SPEC(λψ. length ψ = length φ)›
proof -
  show ?thesis
  unfolding rephase_init_def Let_def
  apply (rule nfoldli_rule[where I = ‹λ_ _ ψ. length φ = length ψ›])
  apply (auto dest: in_list_in_setD)
  done
qed


(* Copy the source list φ into the destination φ' element by element.
   Argument order: source first, destination second.
   Both lists must have the same length (asserted up front). Each index a
   of φ' is replaced by φ!a, so the result equals φ. It is computed
   in place on φ' rather than by returning φ directly. *)
definition copy_phase :: ‹bool list ⇒ bool list ⇒ bool list nres› where
‹copy_phase φ φ' = do {
  ASSERT(length φ = length φ');
  let n = length φ';
  nfoldli [0..<n]
    (λ_. True)
    (λ a φ'. do {
       ASSERT(a < length φ);
       ASSERT(a < length φ');
       RETURN (φ'[a := φ!a])
   })
   φ'
 }›

(* Same as copy_phase_def, except that the loop bound is taken from the
   source φ instead of the destination φ'. The two forms coincide because
   the initial ASSERT already fixes length φ = length φ'. *)
lemma copy_phase_alt_def:
‹copy_phase φ φ' = do {
  ASSERT(length φ = length φ');
  let n = length φ;
  nfoldli [0..<n]
    (λ_. True)
    (λ a φ'. do {
       ASSERT(a < length φ);
       ASSERT(a < length φ');
       RETURN (φ'[a := φ!a])
   })
   φ'
 }›
  unfolding copy_phase_def
  by (auto simp: ASSERT_same_eq_conv)

(* If source and destination have equal length, copy_phase returns a list
   of that same length. The first subgoal discharges the initial ASSERT.
   The loop invariant keeps the length equal to length φ. *)
lemma copy_phase_spec:
  ‹length φ = length φ' ⟹ copy_phase φ φ' ≤ SPEC(λψ. length ψ = length φ)›
  unfolding copy_phase_def Let_def
  apply (intro ASSERT_leI)
  subgoal by auto
  apply (rule nfoldli_rule[where I = ‹λ_ _ ψ. length φ = length ψ›])
  apply (auto dest: in_list_in_setD)
  done


(* Overwrite every saved phase with a pseudo-random polarity.
   The generator is the 64-bit linear congruential generator
     state' = state * 6364136223846793005 + 1442695040888963407  (mod 2^64),
   seeded with b. One generator step is taken per index of φ.

   The polarity is read from the top bit of the 64-bit state
   (state < 2^63, i.e. roughly half True / half False).
   A threshold of 2^31 would hold only for about a 2^-33 fraction of the
   64-bit states, so nearly every phase would be False. The high bits are
   also the best-distributed ones: for a power-of-two modulus, the low
   bits of an LCG have very short periods.

   Only entries are updated, so the length of φ is preserved. *)
definition rephase_random :: ‹64 word ⇒ bool list ⇒ bool list nres› where
‹rephase_random b φ = do {
  let n = length φ;
  (_, φ) ← nfoldli [0..<n]
      (λ_. True)
      (λa (state, φ). do {
        ASSERT(a < length φ);
       let state = state * 6364136223846793005 + 1442695040888963407;
       RETURN (state, φ[a := (state < 9223372036854775808)])
     })
     (b, φ);
  RETURN φ
 }›


(* Random rephasing preserves the length of the phase list.
   The loop invariant ignores the generator state and tracks only the list. *)
lemma rephase_random_spec:
  ‹rephase_random b φ ≤ SPEC(λψ. length ψ = length φ)›
  unfolding rephase_random_def Let_def
  apply (refine_vcg nfoldli_rule[where I = ‹λ_ _ (_, ψ). length φ = length ψ›])
  apply (auto dest: in_list_in_setD)
  done


(* One rephasing step on the phase-saving state
     (φ, target_assigned, target, best_assigned, best, end_of_phase,
      curr_phase, length_phase).
   curr_phase selects what is written into the saved phases φ:
     0          : all False                      -- next curr_phase is 1
     1          : copy of best                   -- next 2
     2          : all True                       -- next 3
     3 (b = 0)  : random, seeded with end_of_phase -- next 4
     otherwise  : copy of best; length_phase is incremented and
                  curr_phase restarts at 0.
   When b ≠ 0 the random step is skipped: curr_phase 3 already falls into
   the last case.
   In every case, end_of_phase is advanced by length_phase * 100, using the
   incremented length_phase in the last case.
   target_assigned, target, best_assigned and best are returned unchanged.
   b is the restart phase: rephase_heur_st passes current_restart_phase.
   NOTE(review): end_of_phase is presumably compared against a counter
   elsewhere to schedule the next rephasing -- TODO confirm. *)
definition phase_rephase :: ‹64 word ⇒ phase_save_heur ⇒ phase_save_heur nres› where
‹phase_rephase = (λb (φ, target_assigned, target, best_assigned, best, end_of_phase, curr_phase, length_phase).
    if b = 0
    then do {
      if curr_phase = 0
      then do {
         φ ← rephase_init False φ;
         RETURN (φ, target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 1, length_phase)
      }
      else if curr_phase = 1
      then do {
         φ ← copy_phase best φ;
         RETURN (φ, target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 2, length_phase)
      }
      else if curr_phase = 2
      then do {
         φ ← rephase_init True φ;
         RETURN (φ, target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 3, length_phase)
      }
      else if curr_phase = 3
      then do {
         φ ← rephase_random end_of_phase φ;
         RETURN (φ, target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 4, length_phase)
      }
      else do {
         φ ← copy_phase best φ;
         RETURN (φ, target_assigned, target, best_assigned, best, (1+length_phase)*100+end_of_phase, 0,
            length_phase+1)
      }
    }
    else do {
      if curr_phase = 0
      then do {
         φ ← rephase_init False φ;
         RETURN (φ, target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 1, length_phase)
      }
      else if curr_phase = 1
      then do {
         φ ← copy_phase best φ;
         RETURN (φ, target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 2, length_phase)
      }
      else if curr_phase = 2
      then do {
         φ ← rephase_init True φ;
         RETURN (φ, target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 3, length_phase)
      }
      else do {
         φ ← copy_phase best φ;
         RETURN (φ, target_assigned, target, best_assigned, best, (1+length_phase)*100+end_of_phase, 0,
           length_phase+1)
     }
    })›

(* phase_rephase preserves phase_save_heur_rel: every step only replaces φ
   by a list of the same length.
   Proof outline:
   - Fact 1 shows that an abstract version of phase_rephase refines
     SPEC (phase_save_heur_rel 𝒜). In that version, each rephasing call is
     replaced by SPEC (λφ'. length φ = length φ').
   - phase_rephase is then refined into that version using
     rephase_init_spec, copy_phase_spec and rephase_random_spec. *)
lemma phase_rephase_spec:
  assumes ‹phase_save_heur_rel 𝒜 φ›
  shows ‹phase_rephase b φ ≤ ⇓Id (SPEC(phase_save_heur_rel 𝒜))›
proof -
  (* NOTE(review): only seven components are named here, while phase_rephase
     destructures eight. The variable curr_phase therefore binds the
     trailing pair (curr_phase, length_phase). *)
  obtain φ' target_assigned target best_assigned best end_of_phase curr_phase where
    φ: ‹φ = (φ', target_assigned, target, best_assigned, best, end_of_phase, curr_phase)›
    by (cases φ) auto
  then have [simp]: ‹length φ' = length best›
    using assms by (auto simp: phase_save_heur_rel_def)
  (* Abstract version: each rephasing action only promises to preserve the length. *)
  have 1: ‹⇓Id (SPEC(phase_save_heur_rel 𝒜)) ≥
    ⇓Id((λ(φ, target_assigned, target, best_assigned, best, end_of_phase, curr_phase, length_phase).
      if b = 0
      then do {
        if curr_phase = 0 then  do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best,length_phase*100+end_of_phase, 1, length_phase)
        }
       else if curr_phase = 1 then  do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 2, length_phase)
       }
       else if curr_phase = 2 then  do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 3, length_phase)
       }
       else if curr_phase = 3 then  do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 4, length_phase)
       }
       else do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best, (1+length_phase)*100+end_of_phase, 0, length_phase+1)
       }
     }
     else do {
        if curr_phase = 0 then  do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best,length_phase*100+end_of_phase, 1, length_phase)
        }
       else if curr_phase = 1 then  do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 2, length_phase)
       }
       else if curr_phase = 2 then  do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best, length_phase*100+end_of_phase, 3, length_phase)
       }
       else do {
          φ' ← SPEC (λφ'. length φ = length φ');
          RETURN (φ', target_assigned, target, best_assigned, best, (1+length_phase)*100+end_of_phase, 0,
            length_phase+1)
       }
     }
     ) φ)›
   using assms
   by (cases φ)
    (auto simp: phase_save_heur_rel_def phase_saving_def RES_RETURN_RES)

  (* Refine the concrete definition into the abstract version from fact 1. *)
  show ?thesis
    unfolding phase_rephase_def φ
    apply (simp only: prod.case)
    apply (rule order_trans)
    defer
    apply (rule 1)
    apply (simp only: prod.case φ)
    apply (refine_vcg if_mono rephase_init_spec copy_phase_spec rephase_random_spec)
    (* NOTE(review): phase_rephase_def was already unfolded above, so the
       simp argument below is presumably redundant. *)
    apply (auto simp: phase_rephase_def)
    done
qed

(* Lift phase_rephase to the restart heuristics tuple.
   Only the last component (the phase-saving state) is rephased; the EMAs,
   restart_info and wasted are passed through unchanged. *)
definition rephase_heur :: ‹64 word ⇒ restart_heuristics ⇒ restart_heuristics nres› where
  ‹rephase_heur = (λb (fast_ema, slow_ema, restart_info, wasted, φ).
    do {
      φ ← phase_rephase b φ;
      RETURN (fast_ema, slow_ema, restart_info, wasted, φ)
   })›

(* rephase_heur preserves heuristic_rel; this follows from phase_rephase_spec. *)
lemma rephase_heur_spec:
  ‹heuristic_rel 𝒜 heur ⟹ rephase_heur b heur ≤  ⇓Id (SPEC(heuristic_rel 𝒜))›
  unfolding rephase_heur_def
  apply (refine_vcg phase_rephase_spec[THEN order_trans])
  apply (auto simp: heuristic_rel_def)
  done

(* Rephase the heuristics stored in a full solver state.
   The current restart phase selects the rephasing variant (see
   phase_rephase). Progress is reported with the rephasing phase of the
   updated heuristics. Every other state component is returned unchanged. *)
definition rephase_heur_st :: ‹twl_st_wl_heur ⇒ twl_st_wl_heur nres› where
  ‹rephase_heur_st = (λ(M', arena, D', j, W', vm, clvls, cach, lbd, outl, stats, heur,
       vdom, avdom, lcount, opts, old_arena). do {
      let b = current_restart_phase heur;
      heur ← rephase_heur b heur;
      let _ = isasat_print_progress (current_rephasing_phase heur) b stats lcount;
      RETURN (M', arena, D', j, W', vm, clvls, cach, lbd, outl, stats, heur,
       vdom, avdom, lcount, opts, old_arena)
   })›

(* Rephasing keeps the state in the twl_st_heur relation.
   rephase_heur_spec is instantiated with the atoms of the abstract state. *)
lemma rephase_heur_st_spec:
  ‹(S, S') ∈ twl_st_heur ⟹ rephase_heur_st S ≤ SPEC(λS. (S, S') ∈ twl_st_heur)›
  unfolding rephase_heur_st_def
  apply (cases S')
  apply (refine_vcg rephase_heur_spec[THEN order_trans, of ‹all_atms_st S'›])
  apply (simp_all add:  twl_st_heur_def)
  done

(* Record the current saved phases φ as target and/or best phase.
   When n is larger than target_assigned, φ is copied into target and
   target_assigned becomes n. The same is done, independently, for best and
   best_assigned. Both tests use the values held before the update.
   φ and end_of_phase are passed through unchanged.
   NOTE(review): only seven components are named. The last binder
   (curr_phase) therefore carries all remaining fields of the
   eight-component state destructured in phase_rephase, and they are
   returned unchanged.
   n is the trail length passed by save_phase_st. *)
definition phase_save_phase :: ‹nat ⇒ phase_save_heur ⇒ phase_save_heur nres› where
‹phase_save_phase = (λn (φ, target_assigned, target, best_assigned, best, end_of_phase, curr_phase). do {
       target ← (if n > target_assigned
          then copy_phase φ target else RETURN target);
       target_assigned ← (if n > target_assigned
          then RETURN n else RETURN target_assigned);
       best ← (if n > best_assigned
          then copy_phase φ best else RETURN best);
       best_assigned ← (if n > best_assigned
          then RETURN n else RETURN best_assigned);
       RETURN (φ, target_assigned, target, best_assigned, best, end_of_phase, curr_phase)
   })›

(* phase_save_phase preserves phase_save_heur_rel.
   Proof outline:
   - Fact 1 shows that an abstract version refines
     SPEC (phase_save_heur_rel 𝒜). In that version, each copy_phase is
     replaced by SPEC (λφ'. length φ = length φ').
   - The concrete copies are then refined into it with copy_phase_spec.
   Only if_mono and copy_phase_spec are needed: phase_save_phase calls
   neither rephase_init nor rephase_random, and phase_rephase does not
   occur in the goal. *)
lemma phase_save_phase_spec:
  assumes ‹phase_save_heur_rel 𝒜 φ›
  shows ‹phase_save_phase n φ ≤ ⇓Id (SPEC(phase_save_heur_rel 𝒜))›
proof -
  obtain φ' target_assigned target best_assigned best end_of_phase curr_phase where
    φ: ‹φ = (φ', target_assigned, target, best_assigned, best, end_of_phase, curr_phase)›
    by (cases φ) auto
  (* The lengths needed for the copy_phase preconditions. *)
  then have [simp]: ‹length φ' = length best›  ‹length target = length best›
    using assms by (auto simp: phase_save_heur_rel_def)
  (* Abstract version: each copy only promises to preserve the length. *)
  have 1: ‹⇓Id (SPEC(phase_save_heur_rel 𝒜)) ≥
    ⇓Id((λ(φ, target_assigned, target, best_assigned, best, end_of_phase, curr_phase). do {
        target ← (if n > target_assigned
          then SPEC (λφ'. length φ = length φ') else RETURN target);
        target_assigned ← (if n > target_assigned
          then RETURN n else RETURN target_assigned);
        best ← (if n > best_assigned
          then SPEC (λφ'. length φ = length φ') else RETURN best);
        best_assigned ← (if n > best_assigned
          then RETURN n else RETURN best_assigned);
        RETURN (φ', target_assigned, target, best_assigned, best, end_of_phase, curr_phase)
     }) φ)›
   using assms
   by  (auto simp: phase_save_heur_rel_def phase_saving_def RES_RETURN_RES φ RES_RES_RETURN_RES)

  (* Refine the concrete definition into the abstract version from fact 1. *)
  show ?thesis
    unfolding phase_save_phase_def φ
    apply (simp only: prod.case)
    apply (rule order_trans)
    defer
    apply (rule 1)
    apply (simp only: prod.case φ)
    apply (refine_vcg if_mono copy_phase_spec)
    apply auto
    done
qed

(* Lift phase_save_phase to the restart heuristics tuple.
   Only the phase-saving state (last component) is updated; the EMAs,
   restart_info and wasted are passed through unchanged. *)
definition save_rephase_heur :: ‹nat ⇒ restart_heuristics ⇒ restart_heuristics nres› where
  ‹save_rephase_heur = (λn (fast_ema, slow_ema, restart_info, wasted, φ).
    do {
      φ ← phase_save_phase n φ;
      RETURN (fast_ema, slow_ema, restart_info, wasted, φ)
   })›

(* save_rephase_heur preserves heuristic_rel; this follows from phase_save_phase_spec. *)
lemma save_phase_heur_spec:
  ‹heuristic_rel 𝒜 heur ⟹ save_rephase_heur n heur ≤  ⇓Id (SPEC(heuristic_rel 𝒜))›
  unfolding save_rephase_heur_def
  apply (refine_vcg phase_save_phase_spec[THEN order_trans])
  apply (auto simp: heuristic_rel_def)
  done


(* Save the current phases of a full solver state as target/best phases.
   The trail length n = isa_length_trail M' is compared against the lengths
   recorded for the target and best phases (see phase_save_phase).
   Every other state component is returned unchanged. *)
definition save_phase_st :: ‹twl_st_wl_heur ⇒ twl_st_wl_heur nres› where
  ‹save_phase_st = (λ(M', arena, D', j, W', vm, clvls, cach, lbd, outl, stats, heur,
       vdom, avdom, lcount, opts, old_arena). do {
      ASSERT(isa_length_trail_pre M');
      let n = isa_length_trail M';
      heur ← save_rephase_heur n heur;
      RETURN (M', arena, D', j, W', vm, clvls, cach, lbd, outl, stats, heur,
       vdom, avdom, lcount, opts, old_arena)
   })›

(* Saving phases keeps the state in the twl_st_heur relation.
   The remaining subgoal is the trail-length precondition
   isa_length_trail_pre, discharged via the lemma of the same name. *)
lemma save_phase_st_spec:
  ‹(S, S') ∈ twl_st_heur ⟹ save_phase_st S ≤ SPEC(λS. (S, S') ∈ twl_st_heur)›
  unfolding save_phase_st_def
  apply (cases S')
  apply (refine_vcg save_phase_heur_spec[THEN order_trans, of ‹all_atms_st S'›])
  apply (simp_all add:  twl_st_heur_def isa_length_trail_pre)
  apply (rule isa_length_trail_pre)
  apply blast
  done



end