From 0a4e560181862aca9ec2af40c2eaa74e04ca4d7a Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Tue, 17 Nov 2020 12:23:34 -0700 Subject: [PATCH] Refactor builders This commit fixes a bug in builders that caused changes to leak when building on existing collections. The builders now only work on empty list & map types and are marked as invalid after the underlying collection has been retrieved from them. --- README.md | 8 ++-- immutable.go | 118 ++++++++++++++++++++++++++-------------------- immutable_test.go | 89 ++++++++++++++-------------------- 3 files changed, 107 insertions(+), 108 deletions(-) diff --git a/README.md b/README.md index ff568d7..5292549 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ a list in-place until you are ready to use it. This can improve bulk list building by 10x or more. ```go -b := immutable.NewListBuilder(immutable.NewList()) +b := immutable.NewListBuilder() b.Append("foo") b.Append("bar") b.Set(2, "baz") @@ -125,8 +125,7 @@ fmt.Println(l.Get(0)) // "foo" fmt.Println(l.Get(1)) // "baz" ``` -Lists are safe to use even after you continue to use the builder. It is also -safe to build on top of existing lists. +Builders are invalid after the call to `List()`. ## Map @@ -230,8 +229,7 @@ fmt.Println(m.Get("foo")) // "300" fmt.Println(m.Get("bar")) // "200" ``` -Maps are safe to use even after you continue to use the builder. You can -also build on top of existing maps too. +Builders are invalid after the call to `Map()`. ### Implementing a custom Hasher diff --git a/immutable.go b/immutable.go index 5adb01a..c8da568 100644 --- a/immutable.go +++ b/immutable.go @@ -224,37 +224,35 @@ func (l *List) Iterator() *ListIterator { return itr } -// ListBuilder represents an efficient builder for creating Lists. -// -// Lists returned from the builder are safe to use even after you continue to -// use the builder. However, for efficiency, you should only retrieve your list -// after you have completed building it. +// ListBuilder represents an efficient builder for creating new Lists. type ListBuilder struct { - list *List // current state - mutable bool // if true, next mutation will operate in-place. + list *List // current state } -// NewListBuilder returns a new instance of ListBuilder to build on a base list. -func NewListBuilder(list *List) *ListBuilder { - return &ListBuilder{list: list} +// NewListBuilder returns a new instance of ListBuilder. +func NewListBuilder() *ListBuilder { + return &ListBuilder{list: NewList()} } // List returns the current copy of the list. -// The returned list is safe to use even if after the builder continues to be used. +// The builder should not be used again after the list after this call. func (b *ListBuilder) List() *List { + assert(b.list != nil, "immutable.ListBuilder.List(): duplicate call to fetch list") list := b.list - b.mutable = false + b.list = nil return list } // Len returns the number of elements in the underlying list. func (b *ListBuilder) Len() int { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") return b.list.Len() } // Get returns the value at the given index. Similar to slices, this method will // panic if index is below zero or is greater than or equal to the list size. func (b *ListBuilder) Get(index int) interface{} { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") return b.list.Get(index) } @@ -262,27 +260,33 @@ func (b *ListBuilder) Get(index int) interface{} { // panic if index is below zero or if the index is greater than or equal to the // list size. func (b *ListBuilder) Set(index int, value interface{}) { - b.list = b.list.set(index, value, b.mutable) - b.mutable = true + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + b.list = b.list.set(index, value, true) } // Append adds value to the end of the list. func (b *ListBuilder) Append(value interface{}) { - b.list = b.list.append(value, b.mutable) - b.mutable = true + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + b.list = b.list.append(value, true) } // Prepend adds value to the beginning of the list. func (b *ListBuilder) Prepend(value interface{}) { - b.list = b.list.prepend(value, b.mutable) - b.mutable = true + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + b.list = b.list.prepend(value, true) } // Slice updates the list with a sublist of elements between start and end index. // See List.Slice() for more details. func (b *ListBuilder) Slice(start, end int) { - b.list = b.list.slice(start, end, b.mutable) - b.mutable = true + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + b.list = b.list.slice(start, end, true) +} + +// Iterator returns a new iterator for the underlying list. +func (b *ListBuilder) Iterator() *ListIterator { + assert(b.list != nil, "immutable.ListBuilder: builder invalid after List() invocation") + return b.list.Iterator() } // Constants for bit shifts used for levels in the List trie. @@ -788,48 +792,52 @@ func (m *Map) Iterator() *MapIterator { } // MapBuilder represents an efficient builder for creating Maps. -// -// Maps returned from the builder are safe to use even after you continue to -// use the builder. However, for efficiency, you should only retrieve your map -// after you have completed building it. type MapBuilder struct { - m *Map // current state - mutable bool // if true, next mutation will operate in-place. + m *Map // current state } -// NewMapBuilder returns a new instance of MapBuilder to build on a base map. -func NewMapBuilder(m *Map) *MapBuilder { - return &MapBuilder{m: m} +// NewMapBuilder returns a new instance of MapBuilder. +func NewMapBuilder(hasher Hasher) *MapBuilder { + return &MapBuilder{m: NewMap(hasher)} } -// Map returns the current copy of the map. -// The returned map is safe to use even if after the builder continues to be used. +// Map returns the underlying map. Only call once. +// Builder is invalid after call. Will panic on second invocation. func (b *MapBuilder) Map() *Map { + assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map") m := b.m - b.mutable = false + b.m = nil return m } // Len returns the number of elements in the underlying map. func (b *MapBuilder) Len() int { + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") return b.m.Len() } // Get returns the value for the given key. func (b *MapBuilder) Get(key interface{}) (value interface{}, ok bool) { + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") return b.m.Get(key) } // Set sets the value of the given key. See Map.Set() for additional details. func (b *MapBuilder) Set(key, value interface{}) { - b.m = b.m.set(key, value, b.mutable) - b.mutable = true + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") + b.m = b.m.set(key, value, true) } // Delete removes the given key. See Map.Delete() for additional details. func (b *MapBuilder) Delete(key interface{}) { - b.m = b.m.delete(key, b.mutable) - b.mutable = true + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") + b.m = b.m.delete(key, true) +} + +// Iterator returns a new iterator for the underlying map. +func (b *MapBuilder) Iterator() *MapIterator { + assert(b.m != nil, "immutable.MapBuilder: builder invalid after Map() invocation") + return b.m.Iterator() } // mapNode represents any node in the map tree. @@ -1660,48 +1668,52 @@ func (m *SortedMap) Iterator() *SortedMapIterator { } // SortedMapBuilder represents an efficient builder for creating sorted maps. -// -// Maps returned from the builder are safe to use even after you continue to -// use the builder. However, for efficiency, you should only retrieve your map -// after you have completed building it. type SortedMapBuilder struct { - m *SortedMap // current state - mutable bool // if true, next mutation will operate in-place. + m *SortedMap // current state } -// NewSortedMapBuilder returns a new instance of SortedMapBuilder to build on a base map. -func NewSortedMapBuilder(m *SortedMap) *SortedMapBuilder { - return &SortedMapBuilder{m: m} +// NewSortedMapBuilder returns a new instance of SortedMapBuilder. +func NewSortedMapBuilder(comparer Comparer) *SortedMapBuilder { + return &SortedMapBuilder{m: NewSortedMap(comparer)} } // SortedMap returns the current copy of the map. // The returned map is safe to use even if after the builder continues to be used. func (b *SortedMapBuilder) Map() *SortedMap { + assert(b.m != nil, "immutable.SortedMapBuilder.Map(): duplicate call to fetch map") m := b.m - b.mutable = false + b.m = nil return m } // Len returns the number of elements in the underlying map. func (b *SortedMapBuilder) Len() int { + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") return b.m.Len() } // Get returns the value for the given key. func (b *SortedMapBuilder) Get(key interface{}) (value interface{}, ok bool) { + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") return b.m.Get(key) } // Set sets the value of the given key. See SortedMap.Set() for additional details. func (b *SortedMapBuilder) Set(key, value interface{}) { - b.m = b.m.set(key, value, b.mutable) - b.mutable = true + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") + b.m = b.m.set(key, value, true) } // Delete removes the given key. See SortedMap.Delete() for additional details. func (b *SortedMapBuilder) Delete(key interface{}) { - b.m = b.m.delete(key, b.mutable) - b.mutable = true + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") + b.m = b.m.delete(key, true) +} + +// Iterator returns a new iterator for the underlying map positioned at the first key. +func (b *SortedMapBuilder) Iterator() *SortedMapIterator { + assert(b.m != nil, "immutable.SortedMapBuilder: builder invalid after Map() invocation") + return b.m.Iterator() } // sortedMapNode represents a branch or leaf node in the sorted map. @@ -2714,3 +2726,9 @@ type reflectStringComparer struct{} func (c *reflectStringComparer) Compare(a, b interface{}) int { return strings.Compare(reflect.ValueOf(a).String(), reflect.ValueOf(b).String()) } + +func assert(condition bool, message string) { + if !condition { + panic(message) + } +} diff --git a/immutable_test.go b/immutable_test.go index 9275f1d..130904f 100644 --- a/immutable_test.go +++ b/immutable_test.go @@ -213,7 +213,7 @@ type TList struct { func NewTList() *TList { return &TList{ im: NewList(), - builder: NewListBuilder(NewList()), + builder: NewListBuilder(), } } @@ -280,33 +280,29 @@ func (l *TList) Validate() error { return fmt.Errorf("Len()=%v, expected %d", got, exp) } - bl := l.builder.List() for i := range l.std { if got, exp := l.im.Get(i), l.std[i]; got != exp { return fmt.Errorf("Get(%d)=%v, expected %v", i, got, exp) } else if got, exp := l.builder.Get(i), l.std[i]; got != exp { - return fmt.Errorf("Builder.Get(%d)=%v, expected %v", i, got, exp) - } else if got, exp := bl.Get(i), l.std[i]; got != exp { return fmt.Errorf("Builder.List/Get(%d)=%v, expected %v", i, got, exp) } } - if err := l.validateForwardIterator("basic", l.im); err != nil { + if err := l.validateForwardIterator("basic", l.im.Iterator()); err != nil { return err - } else if err := l.validateBackwardIterator("basic", l.im); err != nil { + } else if err := l.validateBackwardIterator("basic", l.im.Iterator()); err != nil { return err } - if err := l.validateForwardIterator("builder", bl); err != nil { + if err := l.validateForwardIterator("builder", l.builder.Iterator()); err != nil { return err - } else if err := l.validateBackwardIterator("builder", bl); err != nil { + } else if err := l.validateBackwardIterator("builder", l.builder.Iterator()); err != nil { return err } return nil } -func (l *TList) validateForwardIterator(typ string, list *List) error { - itr := list.Iterator() +func (l *TList) validateForwardIterator(typ string, itr *ListIterator) error { for i := range l.std { if j, v := itr.Next(); i != j || l.std[i] != v { return fmt.Errorf("ListIterator.Next()=<%v,%v>, expected <%v,%v> [%s]", j, v, i, l.std[i], typ) @@ -323,8 +319,7 @@ func (l *TList) validateForwardIterator(typ string, list *List) error { return nil } -func (l *TList) validateBackwardIterator(typ string, list *List) error { - itr := list.Iterator() +func (l *TList) validateBackwardIterator(typ string, itr *ListIterator) error { itr.Last() for i := len(l.std) - 1; i >= 0; i-- { if j, v := itr.Prev(); i != j || l.std[i] != v { @@ -422,7 +417,7 @@ func BenchmarkBuiltinSlice_Append(b *testing.B) { func BenchmarkListBuilder_Append(b *testing.B) { b.ReportAllocs() - builder := NewListBuilder(NewList()) + builder := NewListBuilder() for i := 0; i < b.N; i++ { builder.Append(i) } @@ -430,7 +425,7 @@ func BenchmarkListBuilder_Append(b *testing.B) { func BenchmarkListBuilder_Prepend(b *testing.B) { b.ReportAllocs() - builder := NewListBuilder(NewList()) + builder := NewListBuilder() for i := 0; i < b.N; i++ { builder.Prepend(i) } @@ -439,7 +434,7 @@ func BenchmarkListBuilder_Prepend(b *testing.B) { func BenchmarkListBuilder_Set(b *testing.B) { const n = 10000 - builder := NewListBuilder(NewList()) + builder := NewListBuilder() for i := 0; i < 10000; i++ { builder.Append(i) } @@ -544,7 +539,7 @@ func ExampleList_Iterator_reverse() { } func ExampleListBuilder_Append() { - b := NewListBuilder(NewList()) + b := NewListBuilder() b.Append("foo") b.Append("bar") b.Append("baz") @@ -560,7 +555,7 @@ func ExampleListBuilder_Append() { } func ExampleListBuilder_Prepend() { - b := NewListBuilder(NewList()) + b := NewListBuilder() b.Prepend("foo") b.Prepend("bar") b.Prepend("baz") @@ -576,7 +571,7 @@ func ExampleListBuilder_Prepend() { } func ExampleListBuilder_Set() { - b := NewListBuilder(NewList()) + b := NewListBuilder() b.Append("foo") b.Append("bar") b.Set(1, "baz") @@ -590,7 +585,7 @@ func ExampleListBuilder_Set() { } func ExampleListBuilder_Slice() { - b := NewListBuilder(NewList()) + b := NewListBuilder() b.Append("foo") b.Append("bar") b.Append("baz") @@ -1126,9 +1121,6 @@ func TestMap_Delete(t *testing.T) { m.Set(m.NewKey(rand), rand.Intn(10000)) } } - if err := m.Validate(); err != nil { - t.Fatal(err) - } // Delete all and verify they are gone. keys := make([]int, len(m.keys)) @@ -1207,7 +1199,7 @@ func TestMap_LimitedHash(t *testing.T) { hash: func(value interface{}) uint32 { return hashUint64(uint64(value.(int))) % 0xFF }, equal: func(a, b interface{}) bool { return a.(int) == b.(int) }, } - b := NewMapBuilder(NewMap(&h)) + b := NewMapBuilder(&h) rand := rand.New(rand.NewSource(0)) keys := rand.Perm(100000) @@ -1229,7 +1221,7 @@ func TestMap_LimitedHash(t *testing.T) { } // Verify iteration. - itr := b.Map().Iterator() + itr := b.Iterator() for !itr.Done() { if k, v := itr.Next(); v != k.(int)*2 { t.Fatalf("MapIterator.Next()=<%v,%v>, expected value %v", k, v, k.(int)*2) @@ -1241,12 +1233,6 @@ func TestMap_LimitedHash(t *testing.T) { t.Fatal("expected no value") } - // Verify delete non-existent key works. - m := b.Map() - if b.Delete(10000000 + 1); m != b.Map() { - t.Fatal("expected no change") - } - // Remove all keys. for _, key := range keys { b.Delete(key) @@ -1268,7 +1254,7 @@ type TMap struct { func NewTestMap() *TMap { return &TMap{ im: NewMap(nil), - builder: NewMapBuilder(NewMap(nil)), + builder: NewMapBuilder(nil), std: make(map[int]int), } } @@ -1328,17 +1314,16 @@ func (m *TMap) Validate() error { return fmt.Errorf("builder key (%d) mismatch: immutable=%d, std=%d", k, v, m.std[k]) } } - if err := m.validateIterator(m.im); err != nil { + if err := m.validateIterator(m.im.Iterator()); err != nil { return fmt.Errorf("basic: %s", err) - } else if err := m.validateIterator(m.builder.Map()); err != nil { + } else if err := m.validateIterator(m.builder.Iterator()); err != nil { return fmt.Errorf("builder: %s", err) } return nil } -func (m *TMap) validateIterator(mm *Map) error { +func (m *TMap) validateIterator(itr *MapIterator) error { other := make(map[int]int) - itr := mm.Iterator() for !itr.Done() { k, v := itr.Next() other[k.(int)] = v.(int) @@ -1391,7 +1376,7 @@ func BenchmarkMap_Set(b *testing.B) { func BenchmarkMap_Delete(b *testing.B) { const n = 10000000 - builder := NewMapBuilder(NewMap(nil)) + builder := NewMapBuilder(nil) for i := 0; i < n; i++ { builder.Set(i, i) } @@ -1426,7 +1411,7 @@ func BenchmarkMap_Iterator(b *testing.B) { func BenchmarkMapBuilder_Set(b *testing.B) { b.ReportAllocs() - builder := NewMapBuilder(NewMap(nil)) + builder := NewMapBuilder(nil) for i := 0; i < b.N; i++ { builder.Set(i, i) } @@ -1435,7 +1420,7 @@ func BenchmarkMapBuilder_Set(b *testing.B) { func BenchmarkMapBuilder_Delete(b *testing.B) { const n = 10000000 - builder := NewMapBuilder(NewMap(nil)) + builder := NewMapBuilder(nil) for i := 0; i < n; i++ { builder.Set(i, i) } @@ -1512,7 +1497,7 @@ func ExampleMap_Iterator() { } func ExampleMapBuilder_Set() { - b := NewMapBuilder(NewMap(nil)) + b := NewMapBuilder(nil) b.Set("foo", "bar") b.Set("baz", 100) @@ -1532,7 +1517,7 @@ func ExampleMapBuilder_Set() { } func ExampleMapBuilder_Delete() { - b := NewMapBuilder(NewMap(nil)) + b := NewMapBuilder(nil) b.Set("foo", "bar") b.Set("baz", 100) b.Delete("baz") @@ -2170,7 +2155,7 @@ type TSortedMap struct { func NewTSortedMap() *TSortedMap { return &TSortedMap{ im: NewSortedMap(nil), - builder: NewSortedMapBuilder(NewSortedMap(nil)), + builder: NewSortedMapBuilder(nil), std: make(map[int]int), } } @@ -2236,22 +2221,21 @@ func (m *TSortedMap) Validate() error { } sort.Ints(m.keys) - if err := m.validateForwardIterator(m.im); err != nil { + if err := m.validateForwardIterator(m.im.Iterator()); err != nil { return fmt.Errorf("basic: %s", err) - } else if err := m.validateBackwardIterator(m.im); err != nil { + } else if err := m.validateBackwardIterator(m.im.Iterator()); err != nil { return fmt.Errorf("basic: %s", err) } - if err := m.validateForwardIterator(m.builder.Map()); err != nil { + if err := m.validateForwardIterator(m.builder.Iterator()); err != nil { return fmt.Errorf("basic: %s", err) - } else if err := m.validateBackwardIterator(m.builder.Map()); err != nil { + } else if err := m.validateBackwardIterator(m.builder.Iterator()); err != nil { return fmt.Errorf("basic: %s", err) } return nil } -func (m *TSortedMap) validateForwardIterator(mm *SortedMap) error { - itr := mm.Iterator() +func (m *TSortedMap) validateForwardIterator(itr *SortedMapIterator) error { for i, k0 := range m.keys { v0 := m.std[k0] if k1, v1 := itr.Next(); k0 != k1 || v0 != v1 { @@ -2269,8 +2253,7 @@ func (m *TSortedMap) validateForwardIterator(mm *SortedMap) error { return nil } -func (m *TSortedMap) validateBackwardIterator(mm *SortedMap) error { - itr := mm.Iterator() +func (m *TSortedMap) validateBackwardIterator(itr *SortedMapIterator) error { itr.Last() for i := len(m.keys) - 1; i >= 0; i-- { k0 := m.keys[i] @@ -2345,7 +2328,7 @@ func BenchmarkSortedMap_Iterator(b *testing.B) { func BenchmarkSortedMapBuilder_Set(b *testing.B) { b.ReportAllocs() - builder := NewSortedMapBuilder(NewSortedMap(nil)) + builder := NewSortedMapBuilder(nil) for i := 0; i < b.N; i++ { builder.Set(i, i) } @@ -2354,7 +2337,7 @@ func BenchmarkSortedMapBuilder_Set(b *testing.B) { func BenchmarkSortedMapBuilder_Delete(b *testing.B) { const n = 1000000 - builder := NewSortedMapBuilder(NewSortedMap(nil)) + builder := NewSortedMapBuilder(nil) for i := 0; i < n; i++ { builder.Set(i, i) } @@ -2431,7 +2414,7 @@ func ExampleSortedMap_Iterator() { } func ExampleSortedMapBuilder_Set() { - b := NewSortedMapBuilder(NewSortedMap(nil)) + b := NewSortedMapBuilder(nil) b.Set("foo", "bar") b.Set("baz", 100) @@ -2451,7 +2434,7 @@ func ExampleSortedMapBuilder_Set() { } func ExampleSortedMapBuilder_Delete() { - b := NewSortedMapBuilder(NewSortedMap(nil)) + b := NewSortedMapBuilder(nil) b.Set("foo", "bar") b.Set("baz", 100) b.Delete("baz")