Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sugar for SatisfiesM #1029

Merged
merged 13 commits into from
Nov 13, 2024
2 changes: 2 additions & 0 deletions Batteries.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import Batteries.Data.UnionFind
import Batteries.Data.Vector
import Batteries.Lean.AttributeExtra
import Batteries.Lean.Delaborator
import Batteries.Lean.EStateM
import Batteries.Lean.Except
import Batteries.Lean.Expr
import Batteries.Lean.Float
Expand All @@ -59,6 +60,7 @@ import Batteries.Lean.NameMapAttribute
import Batteries.Lean.PersistentHashMap
import Batteries.Lean.PersistentHashSet
import Batteries.Lean.Position
import Batteries.Lean.SatisfiesM
import Batteries.Lean.Syntax
import Batteries.Lean.System.IO
import Batteries.Lean.TagAttribute
Expand Down
130 changes: 124 additions & 6 deletions Batteries/Classes/SatisfiesM.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
/-
Copyright (c) 2022 Mario Carneiro. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro
Authors: Mario Carneiro, Kim Morrison
-/
import Batteries.Lean.EStateM
import Batteries.Lean.Except
import Batteries.Tactic.Lint

/-!
## SatisfiesM
Expand All @@ -12,6 +15,13 @@ and enables Hoare-like reasoning over monadic expressions. For example, given a
function `f : α → m β`, to say that the return value of `f` satisfies `Q` whenever
the input satisfies `P`, we write `∀ a, P a → SatisfiesM Q (f a)`.

For any monad equipped with `MonadSatisfying m`
one can lift `SatisfiesM` to a monadic value in `Subtype`,
using `satisfying x h : m {a // p a}`, where `x : m α` and `h : SatisfiesM p x`.
This includes `Option`, `ReaderT`, `StateT`, and `ExceptT`, and the Lean monad stack.
(Although it is not entirely clear one should treat the Lean monad stack as lawful,
even though Lean accepts this.)

## Notes

`SatisfiesM` is not yet a satisfactory solution for verifying the behaviour of large scale monadic
Expand All @@ -23,7 +33,7 @@ presumably requiring more syntactic support (and smarter `do` blocks) from Lean.
Or it may be that such a solution will look different!
This is an open research program, and for now one should not be overly ambitious using `SatisfiesM`.

In particular lemmas about pure operations on data structures in `batteries` except for `HashMap`
In particular lemmas about pure operations on data structures in `Batteries` except for `HashMap`
should avoid `SatisfiesM` for now, so that it is easy to migrate to other approaches in future.
-/

Expand Down Expand Up @@ -158,25 +168,133 @@ end SatisfiesM
⟨by revert x; intro | .ok _, ⟨.ok ⟨_, h⟩, rfl⟩, _, rfl => exact h,
fun h => match x with | .ok a => ⟨.ok ⟨a, h _ rfl⟩, rfl⟩ | .error e => ⟨.error e, rfl⟩⟩

theorem SatisfiesM_EStateM_eq :
SatisfiesM (m := EStateM ε σ) p x ↔ ∀ s a s', x.run s = .ok a s' → p a := by
constructor
· rintro ⟨x, rfl⟩ s a s' h
match w : x.run s with
| .ok a s' => simp at h; exact h.1
| .error e s' => simp [w] at h
· intro w
refine ⟨?_, ?_⟩
· intro s
match q : x.run s with
| .ok a s' => exact .ok ⟨a, w s a s' q⟩ s'
| .error e s' => exact .error e s'
· ext s
rw [EStateM.run_map, EStateM.run]
split <;> simp_all

@[simp] theorem SatisfiesM_ReaderT_eq [Monad m] :
SatisfiesM (m := ReaderT ρ m) p x ↔ ∀ s, SatisfiesM p (x s) :=
SatisfiesM (m := ReaderT ρ m) p x ↔ ∀ s, SatisfiesM p (x.run s) :=
(exists_congr fun a => by exact ⟨fun eq _ => eq ▸ rfl, funext⟩).trans Classical.skolem.symm

theorem SatisfiesM_StateRefT_eq [Monad m] :
SatisfiesM (m := StateRefT' ω σ m) p x ↔ ∀ s, SatisfiesM p (x s) := by simp
SatisfiesM (m := StateRefT' ω σ m) p x ↔ ∀ s, SatisfiesM p (x s) := by simp [ReaderT.run]

@[simp] theorem SatisfiesM_StateT_eq [Monad m] [LawfulMonad m] :
SatisfiesM (m := StateT ρ m) (α := α) p x ↔ ∀ s, SatisfiesM (m := m) (p ·.1) (x s) := by
SatisfiesM (m := StateT ρ m) (α := α) p x ↔ ∀ s, SatisfiesM (m := m) (p ·.1) (x.run s) := by
change SatisfiesM (m := StateT ρ m) (α := α) p x ↔ ∀ s, SatisfiesM (m := m) (p ·.1) (x s)
refine .trans ⟨fun ⟨f, eq⟩ => eq ▸ ?_, fun ⟨f, h⟩ => ?_⟩ Classical.skolem.symm
· refine ⟨fun s => (fun ⟨⟨a, h⟩, s'⟩ => ⟨⟨a, s'⟩, h⟩) <$> f s, fun s => ?_⟩
rw [← comp_map, map_eq_pure_bind]; rfl
· refine ⟨fun s => (fun ⟨⟨a, s'⟩, h⟩ => ⟨⟨a, h⟩, s'⟩) <$> f s, funext fun s => ?_⟩
show _ >>= _ = _; simp [← h]

@[simp] theorem SatisfiesM_ExceptT_eq [Monad m] [LawfulMonad m] :
SatisfiesM (m := ExceptT ρ m) (α := α) p x ↔ SatisfiesM (m := m) (∀ a, · = .ok a → p a) x := by
SatisfiesM (m := ExceptT ρ m) (α := α) p x ↔
SatisfiesM (m := m) (∀ a, · = .ok a → p a) x.run := by
change _ ↔ SatisfiesM (m := m) (∀ a, · = .ok a → p a) x
refine ⟨fun ⟨f, eq⟩ => eq ▸ ?_, fun ⟨f, eq⟩ => eq ▸ ?_⟩
· exists (fun | .ok ⟨a, h⟩ => ⟨.ok a, fun | _, rfl => h⟩ | .error e => ⟨.error e, nofun⟩) <$> f
show _ = _ >>= _; rw [← comp_map, map_eq_pure_bind]; congr; funext a; cases a <;> rfl
· exists ((fun | ⟨.ok a, h⟩ => .ok ⟨a, h _ rfl⟩ | ⟨.error e, _⟩ => .error e) <$> f : m _)
show _ >>= _ = _; simp [← comp_map, ← bind_pure_comp]; congr; funext ⟨a, h⟩; cases a <;> rfl

/--
If a monad has `MonadSatisfying m`, then we can lift a `h : SatisfiesM (m := m) p x` predicate
to monadic value `satisfying x p : m { x // p x }`.

Reader, state, and exception monads have `MonadSatisfying` instances if the base monad does.
-/
class MonadSatisfying (m : Type u → Type v) [Functor m] [LawfulFunctor m]where
digama0 marked this conversation as resolved.
Show resolved Hide resolved
/-- Lift a `SatisfiesM` predicate to a monadic value. -/
satisfying {p : α → Prop} {x : m α} (h : SatisfiesM (m := m) p x) : m {a // p a}
/-- The value of the lifted monadic value is equal to the original monadic value. -/
val_eq {p : α → Prop} {x : m α} (h : SatisfiesM (m := m) p x) : Subtype.val <$> satisfying h = x

export MonadSatisfying (satisfying)

namespace MonadSatisfying

instance : MonadSatisfying Id where
satisfying {α p x} h := ⟨x, by obtain ⟨⟨_, h⟩, rfl⟩ := h; exact h⟩
val_eq {α p x} h := rfl

instance : MonadSatisfying Option where
satisfying {α p x?} h :=
have h' := SatisfiesM_Option_eq.mp h
match x? with
| none => none
| some x => some ⟨x, h' x rfl⟩
val_eq {α p x?} h := by cases x? <;> simp

instance : MonadSatisfying (Except ε) where
satisfying {α p x?} h :=
have h' := SatisfiesM_Except_eq.mp h
match x? with
| .ok x => .ok ⟨x, h' x rfl⟩
| .error e => .error e
val_eq {α p x?} h := by cases x? <;> simp

-- This will be redundant after nightly-2024-11-08.
attribute [ext] ReaderT.ext

instance [Monad m] [LawfulMonad m][MonadSatisfying m] : MonadSatisfying (ReaderT ρ m) where
satisfying {α p x} h :=
have h' := SatisfiesM_ReaderT_eq.mp h
fun r => satisfying (h' r)
val_eq {α p x} h := by
have h' := SatisfiesM_ReaderT_eq.mp h
ext r
rw [ReaderT.run_map, ← MonadSatisfying.val_eq (h' r)]
rfl

instance [Monad m] [LawfulMonad m] [MonadSatisfying m] : MonadSatisfying (StateRefT' ω σ m) :=
inferInstanceAs <| MonadSatisfying (ReaderT _ _)

-- This will be redundant after nightly-2024-11-08.
attribute [ext] StateT.ext

instance [Monad m] [LawfulMonad m] [MonadSatisfying m] : MonadSatisfying (StateT ρ m) where
satisfying {α p x} h :=
have h' := SatisfiesM_StateT_eq.mp h
fun r => (fun ⟨⟨a, r'⟩, h⟩ => ⟨⟨a, h⟩, r'⟩) <$> satisfying (h' r)
val_eq {α p x} h := by
have h' := SatisfiesM_StateT_eq.mp h
ext r
rw [← MonadSatisfying.val_eq (h' r), StateT.run_map]
simp [StateT.run]

instance [Monad m] [LawfulMonad m] [MonadSatisfying m] : MonadSatisfying (ExceptT ε m) where
satisfying {α p x} h :=
let x' := satisfying (SatisfiesM_ExceptT_eq.mp h)
ExceptT.mk ((fun ⟨y, w⟩ => y.pmap fun a h => ⟨a, w _ h⟩) <$> x')
val_eq {α p x} h:= by
ext
rw [← MonadSatisfying.val_eq (SatisfiesM_ExceptT_eq.mp h)]
simp

instance : MonadSatisfying (EStateM ε σ) where
satisfying {α p x} h :=
have h' := SatisfiesM_EStateM_eq.mp h
fun s => match w : x.run s with
| .ok a s' => .ok ⟨a, h' s a s' w⟩ s'
| .error e s' => .error e s'
val_eq {α p x} h := by
ext s
rw [EStateM.run_map, EStateM.run]
simp only
split <;> simp_all

end MonadSatisfying
2 changes: 1 addition & 1 deletion Batteries/Data/HashMap/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ Applies `f` to each key-value pair `a, b` in the map. If it returns `some c` the
have : m'.1.size > 0 := by
have := Array.size_mapM (m := StateT (ULift Nat) Id) (go .nil) m.buckets.1
simp [SatisfiesM_StateT_eq, SatisfiesM_Id_eq] at this
simp [this, Id.run, StateT.run, m.2.2, m']
simp [this, Id.run, m.2.2, m']
⟨m'.2.1, m'.1, this⟩
where
/-- Inner loop of `filterMap`. Note that this reverses the bucket lists,
Expand Down
40 changes: 40 additions & 0 deletions Batteries/Lean/EStateM.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/

namespace EStateM

namespace Result

/-- Map a function over an `EStateM.Result`, preserving states and errors. -/
def map {ε σ α β : Type u} (f : α → β) (x : Result ε σ α) : Result ε σ β :=
match x with
| .ok a s' => .ok (f a) s'
| .error e s' => .error e s'

@[simp] theorem map_ok {ε σ α β : Type u} (f : α → β) (a : α) (s : σ) :
(Result.ok a s : Result ε σ α).map f = .ok (f a) s := rfl

@[simp] theorem map_error {ε σ α β : Type u} (f : α → β) (e : ε) (s : σ) :
(Result.error e s : Result ε σ α).map f = .error e s := rfl

@[simp] theorem map_eq_ok {ε σ α β : Type u} (f : α → β) (x : Result ε σ α) (b : β) (s : σ) :
x.map f = .ok b s ↔ ∃ a, x = .ok a s ∧ b = f a := by
cases x <;> simp [and_assoc, and_comm, eq_comm]

@[simp] theorem map_eq_error {ε σ α β : Type u} (f : α → β) (x : Result ε σ α) (e : ε) (s : σ) :
x.map f = .error e s ↔ x = .error e s := by
cases x <;> simp [eq_comm]

end Result

@[simp] theorem run_map (f : α → β) (x : EStateM ε σ α) :
(f <$> x).run s = (x.run s).map f := rfl

@[ext] theorem ext {ε σ α : Type u} (x y : EStateM ε σ α) (h : ∀ s, x.run s = y.run s) : x = y := by
funext s
exact h s

end EStateM
51 changes: 50 additions & 1 deletion Batteries/Lean/Except.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,56 @@ import Lean.Util.Trace

open Lean

namespace Except

/-- Visualize an `Except` using a checkmark or a cross. -/
def Except.emoji : Except ε α → String
def emoji : Except ε α → String
| .error _ => crossEmoji
| .ok _ => checkEmoji

@[simp] theorem map_error {ε : Type u} (f : α → β) (e : ε) :
f <$> (.error e : Except ε α) = .error e := rfl

@[simp] theorem map_ok {ε : Type u} (f : α → β) (x : α) :
f <$> (.ok x : Except ε α) = .ok (f x) := rfl

/-- Map a function over an `Except` value, using a proof that the value is `.ok`. -/
def pmap {ε : Type u} {α β : Type v} (x : Except ε α) (f : (a : α) → x = .ok a → β) : Except ε β :=
match x with
| .error e => .error e
| .ok a => .ok (f a rfl)

@[simp] theorem pmap_error {ε : Type u} {α β : Type v} (e : ε)
(f : (a : α) → Except.error e = Except.ok a → β) :
Except.pmap (.error e) f = .error e := rfl

@[simp] theorem pmap_ok {ε : Type u} {α β : Type v} (a : α)
(f : (a' : α) → (.ok a : Except ε α) = .ok a' → β) :
Except.pmap (.ok a) f = .ok (f a rfl) := rfl

@[simp] theorem pmap_id {ε : Type u} {α : Type v} (x : Except ε α) :
x.pmap (fun a _ => a) = x := by cases x <;> simp

@[simp] theorem map_pmap (g : β → γ) (x : Except ε α) (f : (a : α) → x = .ok a → β) :
g <$> x.pmap f = x.pmap fun a h => g (f a h) := by
cases x <;> simp

end Except

namespace ExceptT

-- This will be redundant after nightly-2024-11-08.
attribute [ext] ExceptT.ext

@[simp] theorem run_mk {m : Type u → Type v} (x : m (Except ε α)) : (ExceptT.mk x).run = x := rfl
@[simp] theorem mk_run (x : ExceptT ε m α) : ExceptT.mk (ExceptT.run x) = x := rfl

@[simp]
theorem map_mk [Monad m] [LawfulMonad m] (f : α → β) (x : m (Except ε α)) :
f <$> ExceptT.mk x = ExceptT.mk ((f <$> · ) <$> x) := by
simp only [Functor.map, Except.map, ExceptT.map, map_eq_pure_bind]
congr
funext a
split <;> simp

end ExceptT
36 changes: 36 additions & 0 deletions Batteries/Lean/SatisfiesM.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Kim Morrison
-/
import Batteries.Classes.SatisfiesM
import Lean.Elab.Command

/-!
# Construct `MonadSatisfying` instances for the Lean monad stack.
-/

open Lean Elab Term Tactic Command

instance : LawfulMonad (EIO ε) := inferInstanceAs <| LawfulMonad (EStateM _ _)
instance : LawfulMonad BaseIO := inferInstanceAs <| LawfulMonad (EIO _)
instance : LawfulMonad IO := inferInstanceAs <| LawfulMonad (EIO _)

instance : MonadSatisfying (EIO ε) := inferInstanceAs <| MonadSatisfying (EStateM _ _)
instance : MonadSatisfying BaseIO := inferInstanceAs <| MonadSatisfying (EIO _)
instance : MonadSatisfying IO := inferInstanceAs <| MonadSatisfying (EIO _)

instance : MonadSatisfying CoreM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ <| StateRefT' _ _ (EIO _))

instance : MonadSatisfying MetaM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ <| StateRefT' _ _ CoreM)

instance : MonadSatisfying TermElabM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ <| StateRefT' _ _ MetaM)

instance : MonadSatisfying TacticM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ $ StateRefT' _ _ TermElabM)

instance : MonadSatisfying CommandElabM :=
inferInstanceAs <| MonadSatisfying (ReaderT _ $ StateRefT' _ _ (EIO _))
23 changes: 23 additions & 0 deletions BatteriesTest/satisfying.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import Batteries.Lean.SatisfiesM
import Batteries.Data.Array.Monadic

open Lean Meta Array Elab Term Tactic Command

-- For now these live in the test file, as it's not really clear we want people relying on them!
kim-em marked this conversation as resolved.
Show resolved Hide resolved
instance : LawfulMonad (EIO ε) := inferInstanceAs <| LawfulMonad (EStateM _ _)
instance : LawfulMonad BaseIO := inferInstanceAs <| LawfulMonad (EIO _)
instance : LawfulMonad IO := inferInstanceAs <| LawfulMonad (EIO _)
instance : LawfulMonad CoreM :=
inferInstanceAs <| LawfulMonad (ReaderT _ <| StateRefT' _ _ (EIO Exception))
instance : LawfulMonad MetaM :=
inferInstanceAs <| LawfulMonad (ReaderT _ <| StateRefT' _ _ CoreM)
instance : LawfulMonad TermElabM :=
inferInstanceAs <| LawfulMonad (ReaderT _ <| StateRefT' _ _ MetaM)
instance : LawfulMonad TacticM :=
inferInstanceAs <| LawfulMonad (ReaderT _ $ StateRefT' _ _ $ TermElabM)
instance : LawfulMonad CommandElabM :=
inferInstanceAs <| LawfulMonad (ReaderT _ $ StateRefT' _ _ $ EIO _)

example (xs : Array Expr) : MetaM { ts : Array Expr // ts.size = xs.size } := do
let r ← satisfying (xs.size_mapM inferType)
return r