Skip to content

Commit

Permalink
feat: more vector lemmas (#1062)
Browse files Browse the repository at this point in the history
  • Loading branch information
fgdorais authored Nov 24, 2024
1 parent 0dc51ac commit 44e2d2e
Show file tree
Hide file tree
Showing 2 changed files with 215 additions and 53 deletions.
7 changes: 3 additions & 4 deletions Batteries/Data/Vector/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Batteries.Data.List.Basic
import Batteries.Data.List.Lemmas
import Batteries.Tactic.Alias
import Batteries.Tactic.Lint.Misc
import Batteries.Tactic.PrintPrefix

/-!
# Vectors
Expand Down Expand Up @@ -252,7 +253,7 @@ Compares two vectors of the same size using a given boolean relation `r`. `isEqv
`true` if and only if `r v[i] w[i]` is true for all indices `i`.
-/
@[inline] def isEqv (v w : Vector α n) (r : α → α → Bool) : Bool :=
Array.isEqvAux v.toArray w.toArray (by simp) r 0 (by simp)
Array.isEqvAux v.toArray w.toArray (by simp) r n (by simp)

instance [BEq α] : BEq (Vector α n) where
beq a b := isEqv a b (· == ·)
Expand Down Expand Up @@ -294,9 +295,7 @@ Finds the first index of a given value in a vector using `==` for comparison. Re
no element of the index matches the given value.
-/
@[inline] def indexOf? [BEq α] (v : Vector α n) (x : α) : Option (Fin n) :=
match v.toArray.indexOf? x with
| some res => some (res.cast v.size_toArray)
| none => none
(v.toArray.indexOf? x).map (Fin.cast v.size_toArray)

/-- Returns `true` when `v` is a prefix of the vector `w`. -/
@[inline] def isPrefixOf [BEq α] (v : Vector α m) (w : Vector α n) : Bool :=
Expand Down
261 changes: 212 additions & 49 deletions Batteries/Data/Vector/Lemmas.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/-
Copyright (c) 2024 Shreyas Srinivas. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Shreyas Srinivas, Francois Dorais
Authors: Shreyas Srinivas, François G. Dorais
-/

import Batteries.Data.Vector.Basic
Expand All @@ -10,50 +10,13 @@ import Batteries.Data.List.Lemmas
import Batteries.Data.Array.Lemmas
import Batteries.Tactic.Lint.Simp

/-!
## Vectors
Lemmas about `Vector α n`
-/

namespace Batteries

namespace Vector

theorem length_toList {α n} (xs : Vector α n) : xs.toList.length = n := by simp

@[simp] theorem getElem_mk {data : Array α} {size : data.size = n} {i : Nat} (h : i < n) :
(Vector.mk data size)[i] = data[i] := rfl

@[simp] theorem getElem_toArray {α n} (xs : Vector α n) (i : Nat) (h : i < xs.toArray.size) :
xs.toArray[i] = xs[i]'(by simpa using h) := by
cases xs
simp

theorem getElem_toList {α n} (xs : Vector α n) (i : Nat) (h : i < xs.toList.length) :
xs.toList[i] = xs[i]'(by simpa using h) := by simp

@[simp] theorem getElem_ofFn {α n} (f : Fin n → α) (i : Nat) (h : i < n) :
(Vector.ofFn f)[i] = f ⟨i, by simpa using h⟩ := by
simp [ofFn]

/-- An `empty` vector maps to a `empty` vector. -/
@[simp]
theorem map_empty (f : α → β) : map f empty = empty := by
rw [map, empty, mk.injEq]
exact Array.map_empty f

theorem toArray_injective : ∀ {v w : Vector α n}, v.toArray = w.toArray → v = w
| {..}, {..}, rfl => rfl

/-- A vector of length `0` is an `empty` vector. -/
protected theorem eq_empty (v : Vector α 0) : v = empty := by
apply Vector.toArray_injective
apply Array.eq_empty_of_size_eq_zero v.2

/--
`Vector.ext` is an extensionality theorem.
Vectors `a` and `b` are equal to each other if their elements are equal for each valid index.
-/
@[ext]
protected theorem ext {a b : Vector α n} (h : (i : Nat) → (_ : i < n) → a[i] = b[i]) : a = b := by
apply Vector.toArray_injective
Expand All @@ -63,26 +26,217 @@ protected theorem ext {a b : Vector α n} (h : (i : Nat) → (_ : i < n) → a[i
rw [a.size_toArray] at hi
exact h i hi

@[simp] theorem push_mk {data : Array α} {size : data.size = n} {x : α} :
(Vector.mk data size).push x =
Vector.mk (data.push x) (by simp [size, Nat.succ_eq_add_one]) := rfl
/-! ### mk lemmas -/

theorem toArray_mk (a : Array α) (h : a.size = n) : (Vector.mk a h).toArray = a := rfl

@[simp] theorem allDiff_mk [BEq α] (a : Array α) (h : a.size = n) :
(Vector.mk a h).allDiff = a.allDiff := rfl

@[simp] theorem mk_append_mk (a b : Array α) (ha : a.size = n) (hb : b.size = m) :
Vector.mk a ha ++ Vector.mk b hb = Vector.mk (a ++ b) (by simp [ha, hb]) := rfl

@[simp] theorem back!_mk [Inhabited α] (a : Array α) (h : a.size = n) :
(Vector.mk a h).back! = a.back! := rfl

@[simp] theorem back?_mk (a : Array α) (h : a.size = n) :
(Vector.mk a h).back? = a.back? := rfl

@[simp] theorem drop_mk (a : Array α) (h : a.size = n) (m) :
(Vector.mk a h).drop m = Vector.mk (a.extract m a.size) (by simp [h]) := rfl

@[simp] theorem eraseIdx!_mk (a : Array α) (h : a.size = n) (i) (hi : i < n) :
(Vector.mk a h).eraseIdx! i = Vector.mk (a.eraseIdx! i) (by simp [h, hi]) := by
simp [Vector.eraseIdx!, hi]

@[simp] theorem feraseIdx_mk (a : Array α) (h : a.size = n) (i) :
(Vector.mk a h).feraseIdx i = Vector.mk (a.feraseIdx (i.cast h.symm)) (by simp [h]) := rfl

@[simp] theorem extract_mk (a : Array α) (h : a.size = n) (start stop) :
(Vector.mk a h).extract start stop = Vector.mk (a.extract start stop) (by simp [h]) := rfl

@[simp] theorem getElem_mk (a : Array α) (h : a.size = n) (i) (hi : i < n) :
(Vector.mk a h)[i] = a[i] := rfl

@[simp] theorem get_mk (a : Array α) (h : a.size = n) (i) :
(Vector.mk a h).get i = a.get (i.cast h.symm) := rfl

@[simp] theorem getD_mk (a : Array α) (h : a.size = n) (i x) :
(Vector.mk a h).getD i x = a.getD i x := rfl

@[simp] theorem uget_mk (a : Array α) (h : a.size = n) (i) (hi : i.toNat < n) :
(Vector.mk a h).uget i hi = a.uget i (by simp [h, hi]) := rfl

@[simp] theorem indexOf?_mk [BEq α] (a : Array α) (h : a.size = n) (x : α) :
(Vector.mk a h).indexOf? x = (a.indexOf? x).map (Fin.cast h) := rfl

@[simp] theorem mk_isEqv_mk (r : α → α → Bool) (a b : Array α) (ha : a.size = n) (hb : b.size = n) :
Vector.isEqv (Vector.mk a ha) (Vector.mk b hb) r = Array.isEqv a b r := by
simp [Vector.isEqv, Array.isEqv, ha, hb]

@[simp] theorem mk_isPrefixOf_mk [BEq α] (a b : Array α) (ha : a.size = n) (hb : b.size = m) :
(Vector.mk a ha).isPrefixOf (Vector.mk b hb) = a.isPrefixOf b := rfl

@[simp] theorem map_mk (a : Array α) (h : a.size = n) (f : α → β) :
(Vector.mk a h).map f = Vector.mk (a.map f) (by simp [h]) := rfl

@[simp] theorem pop_mk (a : Array α) (h : a.size = n) :
(Vector.mk a h).pop = Vector.mk a.pop (by simp [h]) := rfl

@[simp] theorem push_mk (a : Array α) (h : a.size = n) (x) :
(Vector.mk a h).push x = Vector.mk (a.push x) (by simp [h]) := rfl

@[simp] theorem reverse_mk (a : Array α) (h : a.size = n) :
(Vector.mk a h).reverse = Vector.mk a.reverse (by simp [h]) := rfl

@[simp] theorem set_mk (a : Array α) (h : a.size = n) (i x) :
(Vector.mk a h).set i x = Vector.mk (a.set (i.cast h.symm) x) (by simp [h]) := rfl

@[simp] theorem set!_mk (a : Array α) (h : a.size = n) (i x) :
(Vector.mk a h).set! i x = Vector.mk (a.set! i x) (by simp [h]) := rfl

@[simp] theorem setD_mk (a : Array α) (h : a.size = n) (i x) :
(Vector.mk a h).setD i x = Vector.mk (a.setD i x) (by simp [h]) := rfl

@[simp] theorem setN_mk (a : Array α) (h : a.size = n) (i x) (hi : i < n) :
(Vector.mk a h).setN i x = Vector.mk (a.setN i x) (by simp [h]) := rfl

@[simp] theorem swap_mk (a : Array α) (h : a.size = n) (i j) :
(Vector.mk a h).swap i j = Vector.mk (a.swap (i.cast h.symm) (j.cast h.symm)) (by simp [h]) :=
rfl

@[simp] theorem swap!_mk (a : Array α) (h : a.size = n) (i j) :
(Vector.mk a h).swap! i j = Vector.mk (a.swap! i j) (by simp [h]) := rfl

@[simp] theorem swapN_mk (a : Array α) (h : a.size = n) (i j) (hi : i < n) (hj : j < n) :
(Vector.mk a h).swapN i j = Vector.mk (a.swapN i j) (by simp [h]) := rfl

@[simp] theorem swapAt_mk (a : Array α) (h : a.size = n) (i x) : (Vector.mk a h).swapAt i x =
((a.swapAt (i.cast h.symm) x).fst, Vector.mk (a.swapAt (i.cast h.symm) x).snd (by simp [h])) :=
rfl

@[simp] theorem pop_mk {data : Array α} {size : data.size = n} :
(Vector.mk data size).pop = Vector.mk data.pop (by simp [size]) := rfl
@[simp] theorem swapAt!_mk (a : Array α) (h : a.size = n) (i x) : (Vector.mk a h).swapAt! i x =
((a.swapAt! i x).fst, Vector.mk (a.swapAt! i x).snd (by simp [h])) := rfl

@[simp] theorem swapAtN_mk (a : Array α) (h : a.size = n) (i x) (hi : i < n) :
(Vector.mk a h).swapAtN i x =
((a.swapAtN i x).fst, Vector.mk (a.swapAtN i x).snd (by simp [h])) := rfl

@[simp] theorem take_mk (a : Array α) (h : a.size = n) (m) :
(Vector.mk a h).take m = Vector.mk (a.take m) (by simp [h]) := rfl

@[simp] theorem mk_zipWith_mk (f : α → β → γ) (a : Array α) (b : Array β)
(ha : a.size = n) (hb : b.size = n) : zipWith (Vector.mk a ha) (Vector.mk b hb) f =
Vector.mk (Array.zipWith a b f) (by simp [ha, hb]) := rfl

@[simp] theorem getElem_toArray {α n} (xs : Vector α n) (i : Nat) (h : i < xs.toArray.size) :
xs.toArray[i] = xs[i]'(by simpa using h) := by
cases xs; simp

/-! ### toArray lemmas -/

@[simp] theorem toArray_append (a : Vector α m) (b : Vector α n) :
(a ++ b).toArray = a.toArray ++ b.toArray := rfl

@[simp] theorem toArray_drop (a : Vector α n) (m) :
(a.drop m).toArray = a.toArray.extract m a.size := rfl

@[simp] theorem toArray_empty : (Vector.empty (α := α)).toArray = #[] := rfl

@[simp] theorem toArray_mkEmpty (cap) :
(Vector.mkEmpty (α := α) cap).toArray = Array.mkEmpty cap := rfl

@[simp] theorem toArray_eraseIdx! (a : Vector α n) (i) (hi : i < n) :
(a.eraseIdx! i).toArray = a.toArray.eraseIdx! i := by
cases a; simp [hi]

@[simp] theorem toArray_eraseIdxN (a : Vector α n) (i) (hi : i < n) :
(a.eraseIdxN i).toArray = a.toArray.eraseIdxN i (by simp [hi]) := rfl

@[simp] theorem toArray_feraseIdx (a : Vector α n) (i) :
(a.feraseIdx i).toArray = a.toArray.feraseIdx (i.cast a.size_toArray.symm) := rfl

@[simp] theorem toArray_extract (a : Vector α n) (start stop) :
(a.extract start stop).toArray = a.toArray.extract start stop := rfl

@[simp] theorem toArray_map (f : α → β) (a : Vector α n) :
(a.map f).toArray = a.toArray.map f := rfl

@[simp] theorem toArray_ofFn (f : Fin n → α) : (Vector.ofFn f).toArray = Array.ofFn f := rfl

@[simp] theorem toArray_pop (a : Vector α n) : a.pop.toArray = a.toArray.pop := rfl

@[simp] theorem toArray_push (a : Vector α n) (x) : (a.push x).toArray = a.toArray.push x := rfl

@[simp] theorem toArray_range : (Vector.range n).toArray = Array.range n := rfl

@[simp] theorem toArray_reverse (a : Vector α n) : a.reverse.toArray = a.toArray.reverse := rfl

@[simp] theorem toArray_set (a : Vector α n) (i x) :
(a.set i x).toArray = a.toArray.set (i.cast a.size_toArray.symm) x := rfl

@[simp] theorem toArray_set! (a : Vector α n) (i x) :
(a.set! i x).toArray = a.toArray.set! i x := rfl

@[simp] theorem toArray_setD (a : Vector α n) (i x) :
(a.setD i x).toArray = a.toArray.setD i x := rfl

@[simp] theorem toArray_setN (a : Vector α n) (i x) (hi : i < n) :
(a.setN i x).toArray = a.toArray.setN i x (by simp [hi]) := rfl

@[simp] theorem toArray_singleton (x : α) : (Vector.singleton x).toArray = #[x] := rfl

@[simp] theorem toArray_swap (a : Vector α n) (i j) : (a.swap i j).toArray =
a.toArray.swap (i.cast a.size_toArray.symm) (j.cast a.size_toArray.symm) := rfl

@[simp] theorem toArray_swap! (a : Vector α n) (i j) :
(a.swap! i j).toArray = a.toArray.swap! i j := rfl

@[simp] theorem toArray_swapN (a : Vector α n) (i j) (hi : i < n) (hj : j < n) :
(a.swapN i j).toArray = a.toArray.swapN i j (by simp [hi]) (by simp [hj]) := rfl

@[simp] theorem toArray_swapAt (a : Vector α n) (i x) :
((a.swapAt i x).fst, (a.swapAt i x).snd.toArray) =
((a.toArray.swapAt (i.cast a.size_toArray.symm) x).fst,
(a.toArray.swapAt (i.cast a.size_toArray.symm) x).snd) := rfl

@[simp] theorem toArray_swapAt! (a : Vector α n) (i x) :
((a.swapAt! i x).fst, (a.swapAt! i x).snd.toArray) =
((a.toArray.swapAt! i x).fst, (a.toArray.swapAt! i x).snd) := rfl

@[simp] theorem toArray_swapAtN (a : Vector α n) (i x) (hi : i < n) :
((a.swapAtN i x).fst, (a.swapAtN i x).snd.toArray) =
((a.toArray.swapAtN i x (by simp [hi])).fst,
(a.toArray.swapAtN i x (by simp [hi])).snd) := rfl

@[simp] theorem toArray_take (a : Vector α n) (m) : (a.take m).toArray = a.toArray.take m := rfl

@[simp] theorem toArray_zipWith (f : α → β → γ) (a : Vector α n) (b : Vector β n) :
(Vector.zipWith a b f).toArray = Array.zipWith a.toArray b.toArray f := rfl

/-! ### toList lemmas -/

theorem length_toList {α n} (xs : Vector α n) : xs.toList.length = n := by simp

theorem getElem_toList {α n} (xs : Vector α n) (i : Nat) (h : i < xs.toList.length) :
xs.toList[i] = xs[i]'(by simpa using h) := by simp

/-! ### getElem lemmas -/

@[simp] theorem getElem_ofFn {α n} (f : Fin n → α) (i : Nat) (h : i < n) :
(Vector.ofFn f)[i] = f ⟨i, by simpa using h⟩ := by
simp [ofFn]

@[simp] theorem getElem_push_last {v : Vector α n} {x : α} : (v.push x)[n] = x := by
rcases v with ⟨data, rfl⟩
simp
rcases v with ⟨_, rfl⟩; simp

-- The `simpNF` linter incorrectly claims that this lemma can not be applied by `simp`.
@[simp, nolint simpNF] theorem getElem_push_lt {v : Vector α n} {x : α} {i : Nat} (h : i < n) :
(v.push x)[i] = v[i] := by
rcases v with ⟨data, rfl⟩
simp [Array.getElem_push_lt, h]
rcases v with ⟨_, rfl⟩; simp [Array.getElem_push_lt, h]

@[simp] theorem getElem_pop {v : Vector α n} {i : Nat} (h : i < n - 1) : (v.pop)[i] = v[i] := by
rcases v with ⟨data, rfl⟩
simp
rcases v with ⟨_, rfl⟩; simp

/--
Variant of `getElem_pop` that will sometimes fire when `getElem_pop` gets stuck because of
Expand All @@ -100,6 +254,15 @@ defeq issues in the implicit size argument.
subst h
simp [pop, back, back!, ← Array.eq_push_pop_back!_of_size_ne_zero]

/-! ### empty lemmas -/

@[simp] theorem map_empty (f : α → β) : map f empty = empty := by
apply toArray_injective; simp

protected theorem eq_empty (v : Vector α 0) : v = empty := by
apply toArray_injective
apply Array.eq_empty_of_size_eq_zero v.2

/-! ### Decidable quantifiers. -/

theorem forall_zero_iff {P : Vector α 0Prop} :
Expand Down

0 comments on commit 44e2d2e

Please sign in to comment.