From 36cecf612550b13104a9c61165fba5e0e175bd26 Mon Sep 17 00:00:00 2001 From: Peter Ebden Date: Mon, 6 Nov 2023 16:46:16 +0000 Subject: [PATCH] Convert cmap shards to an rwmutex (#2944) * Change cmap to an rwmutex * Add benchmark --- src/cmap/BUILD | 2 ++ src/cmap/cmap.go | 13 ++++++++++--- src/cmap/cmap_test.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/src/cmap/BUILD b/src/cmap/BUILD index 1176510a31..0f7d583e7f 100644 --- a/src/cmap/BUILD +++ b/src/cmap/BUILD @@ -19,6 +19,7 @@ go_test( ":cmap", "///third_party/go/github.com_cespare_xxhash_v2//:v2", "///third_party/go/github.com_stretchr_testify//assert", + "///third_party/go/golang.org_x_sync//errgroup", ], ) @@ -28,6 +29,7 @@ go_benchmark( deps = [ ":cmap", "///third_party/go/github.com_stretchr_testify//assert", + "///third_party/go/golang.org_x_sync//errgroup", ], ) diff --git a/src/cmap/cmap.go b/src/cmap/cmap.go index 3485cbf884..3f6cfda767 100644 --- a/src/cmap/cmap.go +++ b/src/cmap/cmap.go @@ -101,7 +101,7 @@ type awaitableValue[V any] struct { // A shard is one of the individual shards of a map. type shard[K comparable, V any] struct { m map[K]awaitableValue[V] - l sync.Mutex + l sync.RWMutex } // Set is the equivalent of `map[key] = val`. @@ -136,6 +136,13 @@ func (s *shard[K, V]) Set(key K, val V, overwrite bool) (V, bool) { // Exactly one of the target or channel will be returned. // The third value is true if it is the first call that is waiting on this value. func (s *shard[K, V]) Get(key K) (val V, wait <-chan struct{}, first bool) { + s.l.RLock() + if v, ok := s.m[key]; ok { + s.l.RUnlock() + return v.Val, v.Wait, false + } + s.l.RUnlock() + s.l.Lock() defer s.l.Unlock() if v, ok := s.m[key]; ok { @@ -150,8 +157,8 @@ func (s *shard[K, V]) Get(key K) (val V, wait <-chan struct{}, first bool) { // Values returns a copy of all the targets currently in the map. func (s *shard[K, V]) Values() []V { - s.l.Lock() - defer s.l.Unlock() + s.l.RLock() + defer s.l.RUnlock() ret := make([]V, 0, len(s.m)) for _, v := range s.m { if v.Wait == nil { diff --git a/src/cmap/cmap_test.go b/src/cmap/cmap_test.go index fb6a3cf5b5..86b949a175 100644 --- a/src/cmap/cmap_test.go +++ b/src/cmap/cmap_test.go @@ -1,11 +1,13 @@ package cmap import ( + "fmt" "sort" "strconv" "testing" "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" ) func hashInts(k int) uint64 { @@ -84,3 +86,34 @@ func TestResize(t *testing.T) { }) } } + +func BenchmarkMapInsertsAndGets(b *testing.B) { + // Attempts to mimic a vaguely realistic blend of writes and (more) reads. + m := New[int, int](DefaultShardCount, hashInts) + var wg, rg errgroup.Group + wg.SetLimit(3) + rg.SetLimit(12) + for i := 0; i < b.N; i++ { + x := i + for j := 0; j < 10; j++ { + wg.Go(func() error { + for k := 0; k < 1000; k++ { + m.Set(x, x) + } + return nil + }) + } + for j := 0; j < 100; j++ { + rg.Go(func() error { + for k := 0; k < 1000; k++ { + if y := m.Get(x); y != x && y != 0 { + return fmt.Errorf("incorrect result, was %d, should be %d", y, x) + } + } + return nil + }) + } + } + assert.NoError(b, wg.Wait()) + assert.NoError(b, rg.Wait()) +}