Skip to content

Commit

Permalink
refactor: implement BinaryHeap using Vector (#850)
Browse files Browse the repository at this point in the history
Co-authored-by: Mario Carneiro <[email protected]>
  • Loading branch information
fgdorais and digama0 authored Nov 11, 2024
1 parent 325d06c commit 1bfaec4
Showing 1 changed file with 91 additions and 83 deletions.
174 changes: 91 additions & 83 deletions Batteries/Data/BinaryHeap.lean
Original file line number Diff line number Diff line change
@@ -1,78 +1,76 @@
/-
Copyright (c) 2021 Mario Carneiro. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro
Authors: Mario Carneiro, François G. Dorais
-/

import Batteries.Data.Vector.Basic

namespace Batteries

/-- A max-heap data structure. -/
structure BinaryHeap (α) (lt : α → α → Bool) where
/-- Backing array for `BinaryHeap`. -/
/-- `O(1)`. Get data array for a `BinaryHeap`. -/
arr : Array α

namespace BinaryHeap

/-- Core operation for binary heaps, expressed directly on arrays.
Given an array which is a max-heap, push item `i` down to restore the max-heap property. -/
def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
{a' : Array α // a'.size = a.size} :=
private def maxChild (lt : α → α → Bool) (a : Vector α sz) (i : Fin sz) : Option (Fin sz) :=
let left := 2 * i.1 + 1
let right := left + 1
have left_le : i ≤ left := Nat.le_trans
(by rw [Nat.succ_mul, Nat.one_mul]; exact Nat.le_add_left i i)
(Nat.le_add_right ..)
have right_le : i ≤ right := Nat.le_trans left_le (Nat.le_add_right ..)
have i_le : i ≤ i := Nat.le_refl _
have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then
if lt (a.get i) (a.get ⟨left, h⟩) then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩
have j := if h : right < a.size then
if lt (a.get j) (a.get ⟨right, h⟩) then ⟨⟨right, h⟩, right_le⟩ else j else j
if h : i.1 = j then ⟨a, rfl⟩ else
let a' := a.swap i j
let j' := ⟨j, by rw [a.size_swap i j]; exact j.1.2
have : a'.size - j < a.size - i := by
rw [a.size_swap i j]; exact Nat.sub_lt_sub_left i.2 <| Nat.lt_of_le_of_ne j.2 h
let ⟨a₂, h₂⟩ := heapifyDown lt a' j'
⟨a₂, h₂.trans (a.size_swap i j)⟩
termination_by a.size - i

@[simp] theorem size_heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
(heapifyDown lt a i).1.size = a.size := (heapifyDown lt a i).2
if hleft : left < sz then
if hright : right < sz then
if lt a[left] a[right] then
some ⟨right, hright⟩
else
some ⟨left, hleft⟩
else
some ⟨left, hleft⟩
else none

/-- Core operation for binary heaps, expressed directly on arrays.
Given an array which is a max-heap, push item `i` down to restore the max-heap property. -/
def heapifyDown (lt : α → α → Bool) (a : Vector α sz) (i : Fin sz) :
Vector α sz :=
match h : maxChild lt a i with
| none => a
| some j =>
have : i < j := by
cases i; cases j
simp only [maxChild] at h
split at h
· split at h
· split at h <;> (cases h; simp_arith)
· cases h; simp_arith
· contradiction
if lt a[i] a[j] then
heapifyDown lt (a.swap i j) j
else a
termination_by sz - i

/-- Core operation for binary heaps, expressed directly on arrays.
Construct a heap from an unsorted array, by heapifying all the elements. -/
def mkHeap (lt : α → α → Bool) (a : Array α) : {a' : Array α // a'.size = a.size} :=
loop (a.size / 2) a (Nat.div_le_self ..)
def mkHeap (lt : α → α → Bool) (a : Vector α sz) : Vector α sz :=
loop (sz / 2) a (Nat.div_le_self ..)
where
/-- Inner loop for `mkHeap`. -/
loop : (i : Nat) → (a : Array α) → i ≤ a.size{a' : Array α // a'.size = a.size}
| 0, a, _ => ⟨a, rfl⟩
loop : (i : Nat) → (a : Vector α sz) → i ≤ szVector α sz
| 0, a, _ => a
| i+1, a, h =>
let h := Nat.lt_of_succ_le h
let a' := heapifyDown lt a ⟨i, h⟩
let ⟨a₂, h₂⟩ := loop i a' ((heapifyDown ..).2.symm ▸ Nat.le_of_lt h)
⟨a₂, h₂.trans a'.2

@[simp] theorem size_mkHeap (lt : α → α → Bool) (a : Array α) :
(mkHeap lt a).1.size = a.size := (mkHeap lt a).2
let a' := heapifyDown lt a ⟨i, Nat.lt_of_succ_le h⟩
loop i a' (Nat.le_trans (Nat.le_succ _) h)

/-- Core operation for binary heaps, expressed directly on arrays.
Given an array which is a max-heap, push item `i` up to restore the max-heap property. -/
def heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
{a' : Array α // a'.size = a.size} :=
if i0 : i.1 = 0 then ⟨a, rfl⟩ else
have : (i.1 - 1) / 2 < i := Nat.lt_of_le_of_lt (Nat.div_le_self ..) <|
Nat.sub_lt (Nat.pos_of_ne_zero i0) Nat.zero_lt_one
let j := ⟨(i.1 - 1) / 2, Nat.lt_trans this i.2
if lt (a.get j) (a.get i) then
let a' := a.swap i j
let ⟨a₂, h₂⟩ := heapifyUp lt a' ⟨j.1, by rw [a.size_swap i j]; exact j.2
⟨a₂, h₂.trans (a.size_swap i j)⟩
else ⟨a, rfl⟩

@[simp] theorem size_heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
(heapifyUp lt a i).1.size = a.size := (heapifyUp lt a i).2
def heapifyUp (lt : α → α → Bool) (a : Vector α sz) (i : Fin sz) :
Vector α sz :=
match i with
| ⟨0, _⟩ => a
| ⟨i'+1, hi⟩ =>
let j := ⟨i'/2, by get_elem_tactic⟩
if lt a[j] a[i] then
heapifyUp lt (a.swap i j) j
else a

/-- `O(1)`. Build a new empty heap. -/
def empty (lt) : BinaryHeap α lt := ⟨#[]⟩
Expand All @@ -86,81 +84,91 @@ def singleton (lt) (x : α) : BinaryHeap α lt := ⟨#[x]⟩
/-- `O(1)`. Get the number of elements in a `BinaryHeap`. -/
def size (self : BinaryHeap α lt) : Nat := self.1.size

/-- `O(1)`. Get data vector of a `BinaryHeap`. -/
def vector (self : BinaryHeap α lt) : Vector α self.size := ⟨self.1, rfl⟩

/-- `O(1)`. Get an element in the heap by index. -/
def get (self : BinaryHeap α lt) (i : Fin self.size) : α := self.1.get i

/-- `O(log n)`. Insert an element into a `BinaryHeap`, preserving the max-heap property. -/
def insert (self : BinaryHeap α lt) (x : α) : BinaryHeap α lt where
arr := let n := self.size;
heapifyUp lt (self.1.push x) ⟨n, by rw [Array.size_push]; apply Nat.lt_succ_self⟩
arr := heapifyUp lt (self.vector.push x) ⟨_, Nat.lt_succ_self _⟩ |>.toArray

@[simp] theorem size_insert (self : BinaryHeap α lt) (x : α) :
(self.insert x).size = self.size + 1 := by
simp [insert, size, size_heapifyUp]
simp [size, insert]

/-- `O(1)`. Get the maximum element in a `BinaryHeap`. -/
def max (self : BinaryHeap α lt) : Option α := self.1.get? 0

/-- Auxiliary for `popMax`. -/
def popMaxAux (self : BinaryHeap α lt) : {a' : BinaryHeap α lt // a'.size = self.size - 1} :=
match e: self.1.size with
| 0 => ⟨self, by simp [size, e]⟩
| n+1 =>
have h0 := by rw [e]; apply Nat.succ_pos
have hn := by rw [e]; apply Nat.lt_succ_self
if hn0 : 0 < n then
let a := self.1.swap ⟨0, h0⟩ ⟨n, hn⟩ |>.pop
⟨⟨heapifyDown lt a ⟨0, by rwa [Array.size_pop, Array.size_swap, e]⟩⟩,
by simp [size, a]⟩
else
⟨⟨self.1.pop⟩, by simp [size]⟩
def max (self : BinaryHeap α lt) : Option α := self.1[0]?

/-- `O(log n)`. Remove the maximum element from a `BinaryHeap`.
Call `max` first to actually retrieve the maximum element. -/
def popMax (self : BinaryHeap α lt) : BinaryHeap α lt := self.popMaxAux
def popMax (self : BinaryHeap α lt) : BinaryHeap α lt :=
if h0 : self.size = 0 then self else
have hs : self.size - 1 < self.size := Nat.pred_lt h0
have h0 : 0 < self.size := Nat.zero_lt_of_ne_zero h0
let v := self.vector.swap ⟨_, h0⟩ ⟨_, hs⟩ |>.pop
if h : 0 < self.size - 1 then
⟨heapifyDown lt v ⟨0, h⟩ |>.toArray⟩
else
⟨v.toArray⟩

@[simp] theorem size_popMax (self : BinaryHeap α lt) :
self.popMax.size = self.size - 1 := self.popMaxAux.2
self.popMax.size = self.size - 1 := by
simp only [popMax, size]
split
· simp_arith [*]
· split <;> simp_arith [*]

/-- `O(log n)`. Return and remove the maximum element from a `BinaryHeap`. -/
def extractMax (self : BinaryHeap α lt) : Option α × BinaryHeap α lt :=
(self.max, self.popMax)

theorem size_pos_of_max {self : BinaryHeap α lt} (e : self.max = some x) : 0 < self.size :=
Decidable.of_not_not fun h : ¬ 0 < self.1.size => by simp [BinaryHeap.max, Array.get?, h] at e
theorem size_pos_of_max {self : BinaryHeap α lt} (h : self.max = some x) : 0 < self.size := by
simp only [max, getElem?_def] at h
split at h
· assumption
· contradiction

/-- `O(log n)`. Equivalent to `extractMax (self.insert x)`, except that extraction cannot fail. -/
def insertExtractMax (self : BinaryHeap α lt) (x : α) : α × BinaryHeap α lt :=
match e: self.max with
match e : self.max with
| none => (x, self)
| some m =>
if lt x m then
let a := self.1.set ⟨0, size_pos_of_max e⟩ x
(m, ⟨heapifyDown lt a0, by simp only [Array.size_set, a]; exact size_pos_of_max e⟩⟩)
let v := self.vector.set ⟨0, size_pos_of_max e⟩ x
(m, ⟨heapifyDown lt v0, size_pos_of_max e⟩ |>.toArray⟩)
else (x, self)

/-- `O(log n)`. Equivalent to `(self.max, self.popMax.insert x)`. -/
def replaceMax (self : BinaryHeap α lt) (x : α) : Option α × BinaryHeap α lt :=
match e: self.max with
| none => (none, ⟨self.1.push x⟩)
match e : self.max with
| none => (none, ⟨self.vector.push x |>.toArray⟩)
| some m =>
let a := self.1.set ⟨0, size_pos_of_max e⟩ x
(some m, ⟨heapifyDown lt a0, by simp only [Array.size_set, a]; exact size_pos_of_max e⟩⟩)
let v := self.vector.set ⟨0, size_pos_of_max e⟩ x
(some m, ⟨heapifyDown lt v0, size_pos_of_max e⟩ |>.toArray⟩)

/-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `x ≤ self.get i`. -/
def decreaseKey (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where
arr := heapifyDown lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2
arr := heapifyDown lt (self.vector.set i x) i |>.toArray

/-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `self.get i ≤ x`. -/
def increaseKey (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where
arr := heapifyUp lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2
arr := heapifyUp lt (self.vector.set i x) i |>.toArray

end Batteries.BinaryHeap

/-- `O(n)`. Convert an unsorted vector to a `BinaryHeap`. -/
def Batteries.Vector.toBinaryHeap (lt : α → α → Bool) (v : Vector α n) :
Batteries.BinaryHeap α lt where
arr := BinaryHeap.mkHeap lt v |>.toArray

open Batteries in
/-- `O(n)`. Convert an unsorted array to a `BinaryHeap`. -/
def Array.toBinaryHeap (lt : α → α → Bool) (a : Array α) : Batteries.BinaryHeap α lt where
arr := Batteries.BinaryHeap.mkHeap lt a
arr := BinaryHeap.mkHeap lt ⟨a, rfl⟩ |>.toArray

open Batteries in
/-- `O(n log n)`. Sort an array using a `BinaryHeap`. -/
@[specialize] def Array.heapSort (a : Array α) (lt : α → α → Bool) : Array α :=
loop (a.toBinaryHeap (flip lt)) #[]
Expand Down

0 comments on commit 1bfaec4

Please sign in to comment.