Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Define dependent version of Fin.foldl #1071

Merged
merged 15 commits into from
Dec 4, 2024
Merged
48 changes: 48 additions & 0 deletions Batteries/Data/Fin/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,51 @@ alias enum := Array.finRange

@[deprecated (since := "2024-11-15")]
alias list := List.finRange

/-- Dependent version of `Fin.foldr`. -/
@[inline] def dfoldr (n : Nat) (α : Fin (n + 1) → Sort _)
(f : ∀ (i : Fin n), α i.succ → α i.castSucc) (init : α (last n)) : α 0 :=
loop n (Nat.lt_succ_self n) init where
/-- Inner loop for `Fin.dfoldr`.
`Fin.dfoldr.loop n α f i h x = f 0 (f 1 (... (f i x)))` -/
@[specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α 0 :=
match i with
| i + 1 => loop i (Nat.lt_of_succ_lt h) (f ⟨i, Nat.lt_of_succ_lt_succ h⟩ x)
| 0 => x

/-- Dependent version of `Fin.foldrM`. -/
@[inline] def dfoldrM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _)
(f : ∀ (i : Fin n), α i.succ → m (α i.castSucc)) (init : α (last n)) : m (α 0) :=
dfoldr n (fun i => m (α i)) (fun i x => x >>= f i) (pure init)

/-- Dependent version of `Fin.foldl`. -/
@[inline] def dfoldl (n : Nat) (α : Fin (n + 1) → Sort _)
(f : ∀ (i : Fin n), α i.castSucc → α i.succ) (init : α 0) : α (last n) :=
loop 0 (Nat.zero_lt_succ n) init where
/-- Inner loop for `Fin.dfoldl`. `Fin.dfoldl.loop n α f i h x = f n (f (n-1) (... (f i x)))` -/
@[semireducible, specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : α (last n) :=
if h' : i < n then
loop (i + 1) (Nat.succ_lt_succ h') (f ⟨i, h'⟩ x)
else
haveI : ⟨i, h⟩ = last n := by ext; simp; omega
_root_.cast (congrArg α this) x

/-- Dependent version of `Fin.foldlM`. -/
@[inline] def dfoldlM [Monad m] (n : Nat) (α : Fin (n + 1) → Sort _)
(f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (init : α 0) : m (α (last n)) :=
loop 0 (Nat.zero_lt_succ n) init where
/-- Inner loop for `Fin.dfoldlM`.
```
Fin.foldM.loop n α f i h xᵢ = do
let xᵢ₊₁ ← f i xᵢ
...
let xₙ ← f (n-1) xₙ₋₁
pure xₙ
```
-/
@[semireducible, specialize] loop (i : Nat) (h : i < n + 1) (x : α ⟨i, h⟩) : m (α (last n)) :=
if h' : i < n then
(f ⟨i, h'⟩ x) >>= loop (i + 1) (Nat.succ_lt_succ h')
else
haveI : ⟨i, h⟩ = last n := by ext; simp; omega
_root_.cast (congrArg (fun i => m (α i)) this) (pure x)
137 changes: 136 additions & 1 deletion Batteries/Data/Fin/Fold.lean
Original file line number Diff line number Diff line change
@@ -1,12 +1,147 @@
/-
Copyright (c) 2024 François G. Dorais. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: François G. Dorais
Authors: François G. Dorais, Quang Dao
-/
import Batteries.Data.List.FinRange
import Batteries.Data.Fin.Basic

namespace Fin

/-! ### dfoldr -/

theorem dfoldr_loop_zero (f : (i : Fin n) → α i.succ → α i.castSucc) (x) :
dfoldr.loop n α f 0 (Nat.zero_lt_succ n) x = x := rfl

theorem dfoldr_loop_succ (f : (i : Fin n) → α i.succ → α i.castSucc) (h : i < n) (x) :
dfoldr.loop n α f (i+1) (Nat.add_lt_add_right h 1) x =
dfoldr.loop n α f i (Nat.lt_add_right 1 h) (f ⟨i, h⟩ x) := rfl

theorem dfoldr_loop (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (h : i+1 ≤ n+1) (x) :
dfoldr.loop (n+1) α f (i+1) (Nat.add_lt_add_right h 1) x =
f 0 (dfoldr.loop n (α ∘ succ) (f ·.succ) i h x) := by
induction i with
| zero => rfl
| succ i ih => exact ih ..

@[simp] theorem dfoldr_zero (f : (i : Fin 0) → α i.succ → α i.castSucc) (x) :
dfoldr 0 α f x = x := rfl

theorem dfoldr_succ (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) :
dfoldr (n+1) α f x = f 0 (dfoldr n (α ∘ succ) (f ·.succ) x) := dfoldr_loop ..

theorem dfoldr_succ_last (f : (i : Fin (n+1)) → α i.succ → α i.castSucc) (x) :
dfoldr (n+1) α f x = dfoldr n (α ∘ castSucc) (f ·.castSucc) (f (last n) x) := by
induction n with
| zero => simp only [dfoldr_succ, dfoldr_zero, last, zero_eta]
| succ n ih => rw [dfoldr_succ, ih (α := α ∘ succ) (f ·.succ), dfoldr_succ]; congr

theorem dfoldr_eq_dfoldrM (f : (i : Fin n) → α i.succ → α i.castSucc) (x) :
dfoldr n α f x = dfoldrM (m:=Id) n α f x := rfl

theorem dfoldr_eq_foldr (f : Fin n → α → α) (x : α) : dfoldr n (fun _ => α) f x = foldr n f x := by
induction n with
| zero => simp only [dfoldr_zero, foldr_zero]
| succ n ih => simp only [dfoldr_succ, foldr_succ, Function.comp_apply, Function.comp_def, ih]

/-! ### dfoldrM -/

@[simp] theorem dfoldrM_zero [Monad m] (f : (i : Fin 0) → α i.succ → m (α i.castSucc)) (x) :
dfoldrM 0 α f x = pure x := rfl

theorem dfoldrM_succ [Monad m] (f : (i : Fin (n+1)) → α i.succ → m (α i.castSucc))
(x) : dfoldrM (n+1) α f x = dfoldrM n (α ∘ succ) (f ·.succ) x >>= f 0 := dfoldr_succ ..

theorem dfoldrM_eq_foldrM [Monad m] [LawfulMonad m] (f : (i : Fin n) → α → m α) (x : α) :
dfoldrM n (fun _ => α) f x = foldrM n f x := by
induction n generalizing x with
| zero => simp only [dfoldrM_zero, foldrM_zero]
| succ n ih => simp only [dfoldrM_succ, foldrM_succ, Function.comp_def, ih]

/-! ### dfoldl -/

theorem dfoldl_loop_lt (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (h : i < n) (x) :
dfoldl.loop n α f i (Nat.lt_add_right 1 h) x =
dfoldl.loop n α f (i+1) (Nat.add_lt_add_right h 1) (f ⟨i, h⟩ x) := by
rw [dfoldl.loop, dif_pos h]

theorem dfoldl_loop_eq (f : ∀ (i : Fin n), α i.castSucc → α i.succ) (x) :
dfoldl.loop n α f n (Nat.le_refl _) x = x := by
rw [dfoldl.loop, dif_neg (Nat.lt_irrefl _), cast_eq]

@[simp] theorem dfoldl_zero (f : (i : Fin 0) → α i.castSucc → α i.succ) (x) :
dfoldl 0 α f x = x := dfoldl_loop_eq ..

theorem dfoldl_loop (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (h : i < n+1) (x) :
dfoldl.loop (n+1) α f i (Nat.lt_add_right 1 h) x =
dfoldl.loop n (α ∘ succ) (f ·.succ ·) i h (f ⟨i, h⟩ x) := by
if h' : i < n then
rw [dfoldl_loop_lt _ h _]
rw [dfoldl_loop_lt _ h' _, dfoldl_loop]; rfl
else
cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h')
rw [dfoldl_loop_lt]
rw [dfoldl_loop_eq, dfoldl_loop_eq]

theorem dfoldl_succ (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (x) :
dfoldl (n+1) α f x = dfoldl n (α ∘ succ) (f ·.succ ·) (f 0 x) := dfoldl_loop ..

theorem dfoldl_succ_last (f : (i : Fin (n+1)) → α i.castSucc → α i.succ) (x) :
dfoldl (n+1) α f x = f (last n) (dfoldl n (α ∘ castSucc) (f ·.castSucc ·) x) := by
rw [dfoldl_succ]
induction n with
| zero => simp [dfoldl_succ, last]
| succ n ih => rw [dfoldl_succ, @ih (α ∘ succ) (f ·.succ ·), dfoldl_succ]; congr

theorem dfoldl_eq_foldl (f : Fin n → α → α) (x : α) :
dfoldl n (fun _ => α) f x = foldl n (fun x i => f i x) x := by
induction n generalizing x with
| zero => simp only [dfoldl_zero, foldl_zero]
| succ n ih =>
simp only [dfoldl_succ, foldl_succ, Function.comp_apply, Function.comp_def]
congr; simp only [ih]

/-! ### dfoldlM -/

theorem dfoldlM_loop_lt [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (h : i < n) (x) :
dfoldlM.loop n α f i (Nat.lt_add_right 1 h) x =
(f ⟨i, h⟩ x) >>= (dfoldlM.loop n α f (i+1) (Nat.add_lt_add_right h 1)) := by
rw [dfoldlM.loop, dif_pos h]

theorem dfoldlM_loop_eq [Monad m] (f : ∀ (i : Fin n), α i.castSucc → m (α i.succ)) (x) :
dfoldlM.loop n α f n (Nat.le_refl _) x = pure x := by
rw [dfoldlM.loop, dif_neg (Nat.lt_irrefl _), cast_eq]

@[simp] theorem dfoldlM_zero [Monad m] (f : (i : Fin 0) → α i.castSucc → m (α i.succ)) (x) :
dfoldlM 0 α f x = pure x := dfoldlM_loop_eq ..

theorem dfoldlM_loop [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (h : i < n+1)
(x) : dfoldlM.loop (n+1) α f i (Nat.lt_add_right 1 h) x =
f ⟨i, h⟩ x >>= (dfoldlM.loop n (α ∘ succ) (f ·.succ ·) i h .) := by
if h' : i < n then
rw [dfoldlM_loop_lt _ h _]
congr; funext
rw [dfoldlM_loop_lt _ h' _, dfoldlM_loop]; rfl
else
cases Nat.le_antisymm (Nat.le_of_lt_succ h) (Nat.not_lt.1 h')
rw [dfoldlM_loop_lt]
congr; funext
rw [dfoldlM_loop_eq, dfoldlM_loop_eq]

theorem dfoldlM_succ [Monad m] (f : (i : Fin (n+1)) → α i.castSucc → m (α i.succ)) (x) :
dfoldlM (n+1) α f x = f 0 x >>= (dfoldlM n (α ∘ succ) (f ·.succ ·) .) :=
dfoldlM_loop ..

theorem dfoldlM_eq_foldlM [Monad m] (f : (i : Fin n) → α → m α) (x : α) :
dfoldlM n (fun _ => α) f x = foldlM n (fun x i => f i x) x := by
induction n generalizing x with
| zero => simp only [dfoldlM_zero, foldlM_zero]
| succ n ih =>
simp only [dfoldlM_succ, foldlM_succ, Function.comp_apply, Function.comp_def]
congr; ext; simp only [ih]

/-! ### `Fin.fold{l/r}{M}` equals `List.fold{l/r}{M}` -/

theorem foldlM_eq_foldlM_finRange [Monad m] (f : α → Fin n → m α) (x) :
foldlM n f x = (List.finRange n).foldlM f x := by
induction n generalizing x with
Expand Down
Loading