diff --git a/internal/unionstore/art/art_arena.go b/internal/unionstore/art/art_arena.go
index f196bf0cb1..d977d91ab4 100644
--- a/internal/unionstore/art/art_arena.go
+++ b/internal/unionstore/art/art_arena.go
@@ -15,6 +15,7 @@
 package art
 
 import (
+	"sync/atomic"
 	"unsafe"
 
 	"github.com/tikv/client-go/v2/internal/unionstore/arena"
@@ -25,11 +26,23 @@ import (
 // reusing blocks reduces the memory pieces.
 type nodeArena struct {
 	arena.MemdbArena
+
 	// The ART node will expand to a higher capacity, and the address of the freed node will be stored in the free list for reuse.
 	// By reusing the freed node, memory usage and fragmentation can be reduced.
 	freeNode4  []arena.MemdbArenaAddr
 	freeNode16 []arena.MemdbArenaAddr
 	freeNode48 []arena.MemdbArenaAddr
+
+	// While a snapshot iterator is open, ART must keep the old node versions
+	// available; reusing a freed node then can produce incorrect snapshot
+	// reads. Freed nodes are therefore parked in the unused slices until every
+	// snapshot iterator is closed. blockedSnapshotCnt counts the open iterators.
+	blockedSnapshotCnt atomic.Int64
+	// isUnusedNodeFreeing protects the freeing of unused nodes from data races.
+	isUnusedNodeFreeing atomic.Bool
+	unusedNode4         []arena.MemdbArenaAddr
+	unusedNode16        []arena.MemdbArenaAddr
+	unusedNode48        []arena.MemdbArenaAddr
 }
 
 type artAllocator struct {
@@ -62,7 +75,11 @@ func (f *artAllocator) allocNode4() (arena.MemdbArenaAddr, *node4) {
 }
 
 func (f *artAllocator) freeNode4(addr arena.MemdbArenaAddr) {
-	f.nodeAllocator.freeNode4 = append(f.nodeAllocator.freeNode4, addr)
+	if f.nodeAllocator.blockedSnapshotCnt.Load() == 0 {
+		f.nodeAllocator.freeNode4 = append(f.nodeAllocator.freeNode4, addr)
+		return
+	}
+	f.nodeAllocator.unusedNode4 = append(f.nodeAllocator.unusedNode4, addr)
 }
 
 func (f *artAllocator) getNode4(addr arena.MemdbArenaAddr) *node4 {
@@ -88,7 +105,11 @@ func (f *artAllocator) allocNode16() (arena.MemdbArenaAddr, *node16) {
 }
 
 func (f *artAllocator) freeNode16(addr arena.MemdbArenaAddr) {
-	f.nodeAllocator.freeNode16 = append(f.nodeAllocator.freeNode16, addr)
+	if f.nodeAllocator.blockedSnapshotCnt.Load() == 0 {
+		f.nodeAllocator.freeNode16 = append(f.nodeAllocator.freeNode16, addr)
+		return
+	}
+	f.nodeAllocator.unusedNode16 = append(f.nodeAllocator.unusedNode16, addr)
 }
 
 func (f *artAllocator) getNode16(addr arena.MemdbArenaAddr) *node16 {
@@ -114,7 +135,11 @@ func (f *artAllocator) allocNode48() (arena.MemdbArenaAddr, *node48) {
 }
 
 func (f *artAllocator) freeNode48(addr arena.MemdbArenaAddr) {
-	f.nodeAllocator.freeNode48 = append(f.nodeAllocator.freeNode48, addr)
+	if f.nodeAllocator.blockedSnapshotCnt.Load() == 0 {
+		f.nodeAllocator.freeNode48 = append(f.nodeAllocator.freeNode48, addr)
+		return
+	}
+	f.nodeAllocator.unusedNode48 = append(f.nodeAllocator.unusedNode48, addr)
 }
 
 func (f *artAllocator) getNode48(addr arena.MemdbArenaAddr) *node48 {
@@ -156,3 +181,31 @@ func (f *artAllocator) getLeaf(addr arena.MemdbArenaAddr) *artLeaf {
 	data := f.nodeAllocator.GetData(addr)
 	return (*artLeaf)(unsafe.Pointer(&data[0]))
 }
+
+func (f *artAllocator) snapshotInc() {
+	f.nodeAllocator.blockedSnapshotCnt.Add(1)
+}
+
+// snapshotDec decrements the open snapshot iterator count; when it reaches zero,
+// the nodes parked in the unused slices are moved back into the free lists.
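+// It is called when a snapshot iterator is closed or exhausted. Since read
+// iterators can run concurrently, the move is guarded by isUnusedNodeFreeing.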
+func (f *artAllocator) snapshotDec() { + if f.nodeAllocator.blockedSnapshotCnt.Add(-1) != 0 { + return + } + if !f.nodeAllocator.isUnusedNodeFreeing.CompareAndSwap(false, true) { + return + } + if len(f.nodeAllocator.unusedNode4) > 0 { + f.nodeAllocator.freeNode4 = append(f.nodeAllocator.freeNode4, f.nodeAllocator.unusedNode4...) + f.nodeAllocator.unusedNode4 = f.nodeAllocator.unusedNode4[:0] + } + if len(f.nodeAllocator.unusedNode16) > 0 { + f.nodeAllocator.freeNode16 = append(f.nodeAllocator.freeNode16, f.nodeAllocator.unusedNode16...) + f.nodeAllocator.unusedNode16 = f.nodeAllocator.unusedNode16[:0] + } + if len(f.nodeAllocator.unusedNode48) > 0 { + f.nodeAllocator.freeNode48 = append(f.nodeAllocator.freeNode48, f.nodeAllocator.unusedNode48...) + f.nodeAllocator.unusedNode48 = f.nodeAllocator.unusedNode48[:0] + } + f.nodeAllocator.isUnusedNodeFreeing.Store(false) +} diff --git a/internal/unionstore/art/art_snapshot.go b/internal/unionstore/art/art_snapshot.go index 6f470dda43..e9cb4fba17 100644 --- a/internal/unionstore/art/art_snapshot.go +++ b/internal/unionstore/art/art_snapshot.go @@ -43,12 +43,16 @@ func (t *ART) SnapshotIter(start, end []byte) *SnapIter { panic(err) } it := &SnapIter{ - Iterator: inner, - cp: t.getSnapshot(), + Iterator: inner, + cp: t.getSnapshot(), + isCounting: false, } for !it.setValue() && it.Valid() { _ = it.Next() } + if it.Valid() { + it.counterInc() + } return it } @@ -62,9 +66,12 @@ func (t *ART) SnapshotIterReverse(k, lowerBound []byte) *SnapIter { Iterator: inner, cp: t.getSnapshot(), } - for !it.setValue() && it.valid { + for !it.setValue() && it.Valid() { _ = it.Next() } + if it.Valid() { + it.counterInc() + } return it } @@ -91,8 +98,9 @@ func (snap *SnapGetter) Get(ctx context.Context, key []byte) ([]byte, error) { type SnapIter struct { *Iterator - value []byte - cp arena.MemDBCheckpoint + value []byte + cp arena.MemDBCheckpoint + isCounting bool } func (i *SnapIter) Value() []byte { @@ -100,6 +108,11 @@ func (i *SnapIter) Value() []byte { } func (i *SnapIter) Next() error { + defer func() { + if !i.Valid() { + i.counterDec() + } + }() i.value = nil for i.Valid() { if err := i.Iterator.Next(); err != nil { @@ -112,6 +125,11 @@ func (i *SnapIter) Next() error { return nil } +func (i *SnapIter) Close() { + i.Iterator.Close() + i.counterDec() +} + func (i *SnapIter) setValue() bool { if !i.Valid() { return false @@ -122,3 +140,16 @@ func (i *SnapIter) setValue() bool { } return false } + +func (i *SnapIter) counterInc() { + i.isCounting = true + i.tree.allocator.snapshotInc() +} + +func (i *SnapIter) counterDec() { + if !i.isCounting { + return + } + i.isCounting = false + i.tree.allocator.snapshotDec() +} diff --git a/internal/unionstore/art/art_snapshot_test.go b/internal/unionstore/art/art_snapshot_test.go new file mode 100644 index 0000000000..85a871069c --- /dev/null +++ b/internal/unionstore/art/art_snapshot_test.go @@ -0,0 +1,90 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
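+
+// The tests below verify that an open snapshot iterator parks freed ART nodes
+// in the unused lists and that closing the last iterator recycles them into
+// the free lists, even when iterators are closed concurrently.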
+
+package art
+
+import (
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"github.com/tikv/client-go/v2/internal/unionstore/arena"
+)
+
+func TestSnapshotIteratorPreventFreeNode(t *testing.T) {
+	check := func(num int) {
+		tree := New()
+		for i := 0; i < num; i++ {
+			tree.Set([]byte{0, byte(i)}, []byte{0, byte(i)})
+		}
+		var unusedNodeSlice *[]arena.MemdbArenaAddr
+		switch num {
+		case 4:
+			unusedNodeSlice = &tree.allocator.nodeAllocator.unusedNode4
+		case 16:
+			unusedNodeSlice = &tree.allocator.nodeAllocator.unusedNode16
+		case 48:
+			unusedNodeSlice = &tree.allocator.nodeAllocator.unusedNode48
+		default:
+			panic("unsupported num")
+		}
+		it := tree.SnapshotIter(nil, nil)
+		require.Equal(t, 0, len(*unusedNodeSlice))
+		tree.Set([]byte{0, byte(num)}, []byte{0, byte(num)})
+		require.Equal(t, 1, len(*unusedNodeSlice))
+		it.Close()
+		require.Equal(t, 0, len(*unusedNodeSlice))
+	}
+
+	check(4)
+	check(16)
+	check(48)
+}
+
+func TestConcurrentSnapshotIterNoRace(t *testing.T) {
+	check := func(num int) {
+		tree := New()
+		for i := 0; i < num; i++ {
+			tree.Set([]byte{0, byte(i)}, []byte{0, byte(i)})
+		}
+
+		const concurrency = 100
+		it := tree.SnapshotIter(nil, nil)
+
+		tree.Set([]byte{0, byte(num)}, []byte{0, byte(num)})
+
+		var wg sync.WaitGroup
+		wg.Add(concurrency)
+		go func() {
+			it.Close()
+			wg.Done()
+		}()
+		for i := 1; i < concurrency; i++ {
+			go func() {
+				concurrentIt := tree.SnapshotIter(nil, nil)
+				concurrentIt.Close()
+				wg.Done()
+			}()
+		}
+		wg.Wait()
+
+		require.Empty(t, tree.allocator.nodeAllocator.unusedNode4)
+		require.Empty(t, tree.allocator.nodeAllocator.unusedNode16)
+		require.Empty(t, tree.allocator.nodeAllocator.unusedNode48)
+	}
+
+	check(4)
+	check(16)
+	check(48)
+}
diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go
index c10ee0c8dc..e65357247b 100644
--- a/internal/unionstore/memdb_test.go
+++ b/internal/unionstore/memdb_test.go
@@ -1329,36 +1329,37 @@ func TestSelectValueHistory(t *testing.T) {
 }
 
 func TestSnapshotReaderWithWrite(t *testing.T) {
-	check := func(db MemBuffer) {
-		db.Set([]byte{0, 1}, []byte{0, 1})
-		db.Set([]byte{0, 2}, []byte{0, 2})
-		db.Set([]byte{0, 3}, []byte{0, 3})
-		db.Set([]byte{0, 4}, []byte{0, 4})
+	check := func(db MemBuffer, num int) {
+		for i := 0; i < num; i++ {
+			db.Set([]byte{0, byte(i)}, []byte{0, byte(i)})
+		}
 
 		h := db.Staging()
 		defer db.Release(h)
 
-		iter := db.SnapshotIter([]byte{0, 1}, []byte{0, 255})
-		assert.Equal(t, iter.Key(), []byte{0, 1})
-
-		db.Set([]byte{0, 5}, []byte{0, 5}) // node4 is freed and wait to be reused.
-
-		// reuse the node4
-		db.Set([]byte{1, 1}, []byte{1, 1})
-		db.Set([]byte{1, 2}, []byte{1, 2})
-		db.Set([]byte{1, 3}, []byte{1, 3})
-		db.Set([]byte{1, 4}, []byte{1, 4})
-
-		assert.Equal(t, iter.Key(), []byte{0, 1})
-		iter.Next()
-		assert.Equal(t, iter.Key(), []byte{0, 2})
-		iter.Next()
-		assert.Equal(t, iter.Key(), []byte{0, 3})
-		iter.Next()
-		assert.Equal(t, iter.Key(), []byte{0, 4})
-		iter.Next()
+		iter := db.SnapshotIter([]byte{0, 0}, []byte{0, 255})
+		assert.Equal(t, iter.Key(), []byte{0, 0})
+
+		db.Set([]byte{0, byte(num)}, []byte{0, byte(num)}) // ART: a node4/node16/node48 is freed and waits to be reused.
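+		// While iter is open, the freed node must not be recycled; ART parks
+		// it in the unused lists until the iterator is closed.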
+
+		// ART: these writes previously reused the freed node, corrupting the snapshot.
+		for i := 0; i < num; i++ {
+			db.Set([]byte{1, byte(i)}, []byte{1, byte(i)})
+		}
+
+		for i := 0; i < num; i++ {
+			assert.True(t, iter.Valid())
+			assert.Equal(t, iter.Key(), []byte{0, byte(i)})
+			assert.Nil(t, iter.Next())
+		}
 		assert.False(t, iter.Valid())
 	}
 
-	check(newRbtDBWithContext())
-	check(newArtDBWithContext())
+	check(newRbtDBWithContext(), 4)
+	check(newArtDBWithContext(), 4)
+
+	check(newRbtDBWithContext(), 16)
+	check(newArtDBWithContext(), 16)
+
+	check(newRbtDBWithContext(), 48)
+	check(newArtDBWithContext(), 48)
 }
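
Note: the sketch below is not part of the patch. It is a minimal walkthrough of the lifecycle the patch enforces, using only identifiers that appear above (New, Set, SnapshotIter, Valid, Key, Next, Close); the helper name exampleSnapshotLifecycle and the key layout are hypothetical.

package art

// exampleSnapshotLifecycle shows the bookkeeping: nodes freed while a snapshot
// iterator is open are parked in unusedNode4/16/48 and only move to the free
// lists after the last iterator is closed.
func exampleSnapshotLifecycle() {
	tree := New()
	for i := 0; i < 4; i++ {
		tree.Set([]byte{0, byte(i)}, []byte{0, byte(i)})
	}

	it := tree.SnapshotIter(nil, nil) // counterInc: blockedSnapshotCnt becomes 1

	// The fifth key grows the node4 into a node16. freeNode4 sees
	// blockedSnapshotCnt != 0 and parks the old node in unusedNode4, so the
	// iterator's snapshot view stays intact.
	tree.Set([]byte{0, 4}, []byte{0, 4})

	for it.Valid() {
		_ = it.Key() // observes only the four keys written before the snapshot
		_ = it.Next()
	}

	// Exhausting the iterator already ran counterDec via the defer in Next;
	// Close is then a no-op because isCounting is false. With
	// blockedSnapshotCnt back at zero, snapshotDec has recycled unusedNode4
	// into freeNode4.
	it.Close()
}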