Skip to content

Commit

Permalink
feat: Switch out all of mem_omega with MetaM [4/?] (#233)
Browse files Browse the repository at this point in the history
### Description:

  This ensures that one part of the infrastructure is written
  completely in terms of MetaM. We will give `simp_mem` the same
  treatment next.

PR stacked on top of #232

### Testing:

No executable defs have changed, cosim succeeds.

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.
  • Loading branch information
bollu authored Oct 31, 2024
1 parent 486bd6e commit 49914dd
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 75 deletions.
124 changes: 69 additions & 55 deletions Arm/Memory/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,8 @@ def TacticM.traceLargeMsg


/-- TacticM's omega invoker -/
def omega (bvToNatSimpCtx : Simp.Context) (bvToNatSimprocs : Array Simp.Simprocs) : TacticM Unit := do
withMainContext do
-- https://leanprover.zulipchat.com/#narrow/stream/326056-ICERM22-after-party/topic/Regression.20tests/near/290131280
let .some goal ← LNSymSimpAtStar (← getMainGoal) bvToNatSimpCtx bvToNatSimprocs
| trace[simp_mem.info] "simp [bv_toNat] at * managed to close goal."
replaceMainGoal [goal]
TacticM.withTraceNode' m!"goal post `bv_toNat` reductions (Note: can be large)" do
trace[simp_mem.info] "{goal}"
-- @bollu: TODO: understand what precisely we are recovering from.
withoutRecover do
evalTactic (← `(tactic| bv_omega_bench))
def omega (g : MVarId) (bvToNatSimpCtx : Simp.Context) (bvToNatSimprocs : Array Simp.Simprocs) : MetaM Unit := do
BvOmegaBench.run g bvToNatSimpCtx bvToNatSimprocs

/-
Introduce a new definition into the local context, simplify it using `simp`,
Expand Down Expand Up @@ -326,6 +317,33 @@ def simpAndIntroDef (name : String) (hdefVal : Expr) : TacticM FVarId := do
replaceMainGoal [goal]
return fvar

def simpAndIntroDef' (g : MVarId) (name : String) (hdefVal : Expr) : MetaM (FVarId × MVarId) := do
g.withContext do
let name ← mkFreshUserName <| .mkSimple name
let hdefTy ← inferType hdefVal

/- Simp to gain some more juice out of the defn.. -/
let mut simpTheorems : Array SimpTheorems := #[]
for a in #[`minimal_theory, `bitvec_rules, `bv_toNat] do
let some ext ← (getSimpExtension? a)
| throwError m!"[simp_mem] Internal error: simp attribute {a} not found!"
simpTheorems := simpTheorems.push (← ext.getTheorems)

-- unfold `state_value.
simpTheorems := simpTheorems.push <| ← ({} : SimpTheorems).addDeclToUnfold `state_value
let simpCtx : Simp.Context := {
simpTheorems,
config := { decide := true, failIfUnchanged := false },
congrTheorems := (← Meta.getSimpCongrTheorems)
}
let (simpResult, _stats) ← simp hdefTy simpCtx (simprocs := #[])
let hdefVal ← simpResult.mkCast hdefVal
let hdefTy ← inferType hdefVal

let g ← g.assert name hdefTy hdefVal
let (fvar, g) ← g.intro1P
return (fvar, g)

section Hypotheses

/--
Expand Down Expand Up @@ -467,12 +485,11 @@ def MemLegalProof.omega_def (h : MemLegalProof e) : Expr :=
mkAppN (Expr.const ``mem_legal'.omega_def []) #[e.span.base, e.span.n, h.h]

/-- Add the omega fact from `mem_legal'.def`. -/
def MemLegalProof.addOmegaFacts (h : MemLegalProof e) (args : Array Expr) :
TacticM (Array Expr) := do
withMainContext do
let fvar ← simpAndIntroDef "hmemLegal_omega" h.omega_def
trace[simp_mem.info] "{h}: added omega fact ({h.omega_def})"
return args.push (Expr.fvar fvar)
def MemLegalProof.addOmegaFacts (h : MemLegalProof e) (g : MVarId) (args : Array Expr) :
MetaM (Array Expr × MVarId) := do
let (fvar, g) ← simpAndIntroDef' g "hmemLegal_omega" h.omega_def
trace[simp_mem.info] "{h}: added omega fact ({h.omega_def})"
return (args.push (Expr.fvar fvar), g)

/--
info: mem_subset'.omega_def {a : BitVec 64} {an : Nat} {b : BitVec 64} {bn : Nat} (h : mem_subset' a an b bn) :
Expand All @@ -489,12 +506,11 @@ def MemSubsetProof.omega_def (h : MemSubsetProof e) : Expr :=
#[e.sa.base, e.sa.n, e.sb.base, e.sb.n, h.h]

/-- Add the omega fact from `mem_legal'.omega_def` into the main goal. -/
def MemSubsetProof.addOmegaFacts (h : MemSubsetProof e) (args : Array Expr) :
TacticM (Array Expr) := do
withMainContext do
let fvar ← simpAndIntroDef "hmemSubset_omega" h.omega_def
trace[simp_mem.info] "{h}: added omega fact ({h.omega_def})"
return args.push (Expr.fvar fvar)
def MemSubsetProof.addOmegaFacts (h : MemSubsetProof e) (g : MVarId) (args : Array Expr) :
MetaM (Array Expr × MVarId) := do
let (fvar, g) ← simpAndIntroDef' g "hmemSubset_omega" h.omega_def
trace[simp_mem.info] "{h}: added omega fact ({h.omega_def})"
return (args.push (Expr.fvar fvar), g)

/--
Build a term corresponding to `mem_separate'.omega_def` which has facts written
Expand All @@ -505,13 +521,11 @@ def MemSeparateProof.omega_def (h : MemSeparateProof e) : Expr :=
#[e.sa.base, e.sa.n, e.sb.base, e.sb.n, h.h]

/-- Add the omega fact from `mem_legal'.omega_def`. -/
def MemSeparateProof.addOmegaFacts (h : MemSeparateProof e) (args : Array Expr) :
TacticM (Array Expr) := do
withMainContext do
-- simp only [bitvec_rules] (failIfUnchanged := false)
let fvar ← simpAndIntroDef "hmemSeparate_omega" h.omega_def
trace[simp_mem.info] "{h}: added omega fact ({h.omega_def})"
return args.push (Expr.fvar fvar)
def MemSeparateProof.addOmegaFacts (h : MemSeparateProof e) (g : MVarId) (args : Array Expr) :
MetaM (Array Expr × MVarId) := do
let (fvar, g) ← simpAndIntroDef' g "hmemSeparate_omega" h.omega_def
trace[simp_mem.info] "{h}: added omega fact ({h.omega_def})"
return (args.push (Expr.fvar fvar), g)



Expand All @@ -537,7 +551,7 @@ info: Memory.Region.separate'_of_pairwiseSeprate_of_mem_of_mem {mems : List Memo
/-- make `Memory.Region.separate'_of_pairwiseSeprate_of_mem_of_mem i j (by decide) a b rfl rfl`. -/
def MemPairwiseSeparateProof.mem_separate'_of_pairwiseSeparate_of_mem_of_mem
(self : MemPairwiseSeparateProof mems) (i j : Nat) (a b : MemSpanExpr) :
TacticM <| MemSeparateProof ⟨a, b⟩ := do
MetaM <| MemSeparateProof ⟨a, b⟩ := do
let iexpr := mkNatLit i
let jexpr := mkNatLit j
-- i ≠ j
Expand All @@ -562,41 +576,44 @@ Currently, if the list is syntacticaly of the form [x1, ..., xn],
we create hypotheses of the form `mem_separate' xi xj` for all i, j..
This can be generalized to pairwise separation given hypotheses x ∈ xs, x' ∈ xs.
-/
def MemPairwiseSeparateProof.addOmegaFacts (h : MemPairwiseSeparateProof e) (args : Array Expr) :
TacticM (Array Expr) := do
def MemPairwiseSeparateProof.addOmegaFacts (h : MemPairwiseSeparateProof e) (g : MVarId) (args : Array Expr) :
MetaM (Array Expr × MVarId) := do
-- We need to loop over i, j where i < j and extract hypotheses.
-- We need to find the length of the list, and return an `Array MemRegion`.
let mut args := args
let mut g := g
for i in [0:e.xs.size] do
for j in [i+1:e.xs.size] do
let a := e.xs[i]!
let b := e.xs[j]!
args ← TacticM.withTraceNode' m!"Exploiting ({i}, {j}) : {a} ⟂ {b}" do
(args, g) ← TacticM.withTraceNode' m!"Exploiting ({i}, {j}) : {a} ⟂ {b}" do
let proof ← h.mem_separate'_of_pairwiseSeparate_of_mem_of_mem i j a b
TacticM.traceLargeMsg m!"added {← inferType proof.h}" m!"{proof.h}"
proof.addOmegaFacts args
return args
proof.addOmegaFacts g args
return (args, g)
/--
Given a hypothesis, add declarations that would be useful for omega-blasting
-/
def Hypothesis.addOmegaFactsOfHyp (h : Hypothesis) (args : Array Expr) : TacticM (Array Expr) :=
def Hypothesis.addOmegaFactsOfHyp (g : MVarId) (h : Hypothesis) (args : Array Expr) :
MetaM (Array Expr × MVarId) :=
match h with
| Hypothesis.legal h => h.addOmegaFacts args
| Hypothesis.subset h => h.addOmegaFacts args
| Hypothesis.separate h => h.addOmegaFacts args
| Hypothesis.pairwiseSeparate h => h.addOmegaFacts args
| Hypothesis.read_eq _h => return args -- read has no extra `omega` facts.
| Hypothesis.legal h => h.addOmegaFacts g args
| Hypothesis.subset h => h.addOmegaFacts g args
| Hypothesis.separate h => h.addOmegaFacts g args
| Hypothesis.pairwiseSeparate h => h.addOmegaFacts g args
| Hypothesis.read_eq _h => return (args, g) -- read has no extra `omega` facts.

/--
Accumulate all omega defs in `args`.
-/
def Hypothesis.addOmegaFactsOfHyps (hs : List Hypothesis) (args : Array Expr)
: TacticM (Array Expr) := do
def Hypothesis.addOmegaFactsOfHyps (g : MVarId) (hs : List Hypothesis) (args : Array Expr)
: MetaM (Array Expr × MVarId) := do
TacticM.withTraceNode' m!"Adding omega facts from hypotheses" do
let mut args := args
let mut g := g
for h in hs do
args ← h.addOmegaFactsOfHyp args
return args
(args, g) ← h.addOmegaFactsOfHyp g args
return (args, g)

end Hypotheses

Expand Down Expand Up @@ -741,7 +758,7 @@ An example is `mem_lega'.of_omega n a`, which has type:
-/
def proveWithOmega? {α : Type} [ToMessageData α] [OmegaReducible α] (e : α)
(bvToNatSimpCtx : Simp.Context) (bvToNatSimprocs : Array Simp.Simprocs)
(hyps : Array Memory.Hypothesis) : TacticM (Option (Proof α e)) := do
(hyps : Array Memory.Hypothesis) : MetaM (Option (Proof α e)) := do
let proofFromOmegaVal := (OmegaReducible.reduceToOmega e)
-- (h : a.toNat + n ≤ 2 ^ 64) → mem_legal' a n
let proofFromOmegaTy ← inferType (OmegaReducible.reduceToOmega e)
Expand All @@ -753,19 +770,16 @@ def proveWithOmega? {α : Type} [ToMessageData α] [OmegaReducible α] (e : α)
trace[simp_mem.info] "omega obligation '{omegaObligationTy}'"
let omegaObligationVal ← mkFreshExprMVar (type? := omegaObligationTy)
let factProof := mkAppN proofFromOmegaVal #[omegaObligationVal]
let oldGoals := (← getGoals)

let g := omegaObligationVal.mvarId!
g.withContext do
try
setGoals (omegaObligationVal.mvarId! :: (← getGoals))
withMainContext do
let _ ← Hypothesis.addOmegaFactsOfHyps hyps.toList #[]
let (_, g) ← Hypothesis.addOmegaFactsOfHyps g hyps.toList #[]
trace[simp_mem.info] m!"Executing `omega` to close {e}"
omega bvToNatSimpCtx bvToNatSimprocs
omega g bvToNatSimpCtx bvToNatSimprocs
trace[simp_mem.info] "{checkEmoji} `omega` succeeded."
return (.some <| Proof.mk (← instantiateMVars factProof))
catch e =>
trace[simp_mem.info] "{crossEmoji} `omega` failed with error:\n{e.toMessageData}"
setGoals oldGoals
return none
end ReductionToOmega

Expand All @@ -775,7 +789,7 @@ and simplifying all other expressions. return `true` if goal has been closed, an
-/
partial def closeMemSideCondition (g : MVarId)
(bvToNatSimpCtx : Simp.Context) (bvToNatSimprocs : Array Simp.Simprocs)
(hyps : Array Memory.Hypothesis) : TacticM Bool := do
(hyps : Array Memory.Hypothesis) : MetaM Bool := do
g.withContext do
trace[simp_mem.info] "{processingEmoji} Matching on ⊢ {← g.getType}"
let gt ← g.getType
Expand Down
25 changes: 11 additions & 14 deletions Arm/Memory/MemOmega.lean
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,13 @@ def init (cfg : Config) : MetaM Context := do
return {cfg, bvToNatSimpCtx, bvToNatSimprocs}
end Context

abbrev MemOmegaM := (ReaderT Context TacticM)
abbrev MemOmegaM := (ReaderT Context MetaM)

namespace MemOmegaM

def run (ctx : Context) (x : MemOmegaM α) : TacticM α := ReaderT.run x ctx

def run (ctx : Context) (x : MemOmegaM α) : MetaM α := ReaderT.run x ctx
end MemOmegaM

def memOmegaTac : MemOmegaM Unit := do
let g ← getMainGoal
def memOmega (g : MVarId) : MemOmegaM Unit := do
g.withContext do
/- We need to explode all pairwise separate hyps -/
let rawHyps ← getLocalHyps
Expand All @@ -85,22 +82,20 @@ def memOmegaTac : MemOmegaM Unit := do
hyps := hyps.filter (!·.isPairwiseSeparate || isPairwiseEnabled)

-- used specialized procedure that doesn't unfold everything for the easy case.
if ← closeMemSideCondition (← getMainGoal) (← readThe Context).bvToNatSimpCtx (← readThe Context).bvToNatSimprocs hyps then
if ← closeMemSideCondition g (← readThe Context).bvToNatSimpCtx (← readThe Context).bvToNatSimprocs hyps then
return ()
else
-- in the bad case, just rip through everything.
-- let _ ← Hypothesis.addOmegaFactsOfHyps (hyps.toList.filter (fun h => h.isPairwiseSeparate)) #[]
let _ ← Hypothesis.addOmegaFactsOfHyps hyps.toList #[]
let (_, g) ← Hypothesis.addOmegaFactsOfHyps g hyps.toList #[]

TacticM.withTraceNode' m!"Reducion to omega" do
try
TacticM.traceLargeMsg m!"goal (Note: can be large)" m!"{← getMainGoal}"
omega (← readThe Context).bvToNatSimpCtx (← readThe Context).bvToNatSimprocs
TacticM.traceLargeMsg m!"goal (Note: can be large)" m!"{g}"
omega g (← readThe Context).bvToNatSimpCtx (← readThe Context).bvToNatSimprocs
trace[simp_mem.info] "{checkEmoji} `omega` succeeded."
catch e =>
trace[simp_mem.info] "{crossEmoji} `omega` failed with error:\n{e.toMessageData}"


/--
Allow elaboration of `MemOmegaConfig` arguments to tactics.
-/
Expand All @@ -124,14 +119,16 @@ syntax (name := mem_omega_bang) "mem_omega!" (Lean.Parser.Tactic.config)? : tact
def evalMemOmega : Tactic := fun
| `(tactic| mem_omega $[$cfg]?) => do
let cfg ← elabMemOmegaConfig (mkOptionalNode cfg)
memOmegaTac.run (← Context.init cfg)
liftMetaFinishingTactic fun g => do
memOmega g |>.run (← Context.init cfg)
| _ => throwUnsupportedSyntax

@[tactic mem_omega_bang]
def evalMemOmegaBang : Tactic := fun
| `(tactic| mem_omega! $[$cfg]?) => do
let cfg ← elabMemOmegaConfig (mkOptionalNode cfg)
memOmegaTac.run (← Context.init cfg.mkBang)
liftMetaFinishingTactic fun g => do
memOmega g |>.run (← Context.init cfg.mkBang)
| _ => throwUnsupportedSyntax

end MemOmega
18 changes: 12 additions & 6 deletions Tactics/BvOmegaBench.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ open Lean Elab Meta Tactic Omega

namespace BvOmegaBench


-- Adapted mkSimpContext:
-- from https://github.com/leanprover/lean4/blob/master/src/Lean/Elab/Tactic/Simp.lean#L287
/-
Make the `SimpContext` that corresponds to using `bv_toNat`
Adapted mkSimpContext:
from https://github.com/leanprover/lean4/blob/master/src/Lean/Elab/Tactic/Simp.lean#L287
-/
def bvOmegaSimpCtx : MetaM (Simp.Context × Array Simp.Simprocs) := do
let mut simprocs := #[]
let mut simpTheorems := #[]
Expand All @@ -47,14 +49,13 @@ Code adapted from:
- https://github.com/leanprover/lean4/blob/master/src/Lean/Elab/Tactic/Simp.lean#L406
- https://github.com/leanprover/lean4/blob/master/src/Lean/Elab/Tactic/Omega/Frontend.lean#L685
-/
def run (g : MVarId) : MetaM Unit := do
def run (g : MVarId) (bvToNatSimpCtx : Simp.Context) (bvToNatSimprocs : Array Simp.Simprocs) : MetaM Unit := do
let minMs ← getBvOmegaBenchMinMs
let goalStr ← ppGoal g
let startTime ← IO.monoMsNow
let filePath ← getBvOmegaBenchFilePath
try
g.withContext do
let (bvToNatSimpCtx, bvToNatSimprocs) ← bvOmegaSimpCtx
let nonDepHyps ← g.getNondepPropHyps
let mut g := g

Expand Down Expand Up @@ -101,6 +102,11 @@ def run (g : MVarId) : MetaM Unit := do
h.putStrLn s!"enderror"
throw e

/-- Build the default simp context (bv_toNat) and run omega -/
def runWithDefaultSimpContext (g : MVarId) : MetaM Unit := do
let (bvToNatSimpCtx, bvToNatSimprocs) ← bvOmegaSimpCtx
run g bvToNatSimpCtx bvToNatSimprocs

end BvOmegaBench

syntax (name := bvOmegaBenchTac) "bv_omega_bench" : tactic
Expand All @@ -109,5 +115,5 @@ syntax (name := bvOmegaBenchTac) "bv_omega_bench" : tactic
def bvOmegaBenchImpl : Tactic
| `(tactic| bv_omega_bench) =>
liftMetaFinishingTactic fun g => do
BvOmegaBench.run g
BvOmegaBench.runWithDefaultSimpContext g
| _ => throwUnsupportedSyntax

0 comments on commit 49914dd

Please sign in to comment.