Skip to content

Commit

Permalink
Convert cmap shards to an rwmutex (#2944)
Browse files Browse the repository at this point in the history
* Change cmap to an rwmutex

* Add benchmark
  • Loading branch information
peterebden authored Nov 6, 2023
1 parent e1028b9 commit 36cecf6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
2 changes: 2 additions & 0 deletions src/cmap/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ go_test(
":cmap",
"///third_party/go/github.com_cespare_xxhash_v2//:v2",
"///third_party/go/github.com_stretchr_testify//assert",
"///third_party/go/golang.org_x_sync//errgroup",
],
)

Expand All @@ -28,6 +29,7 @@ go_benchmark(
deps = [
":cmap",
"///third_party/go/github.com_stretchr_testify//assert",
"///third_party/go/golang.org_x_sync//errgroup",
],
)

Expand Down
13 changes: 10 additions & 3 deletions src/cmap/cmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ type awaitableValue[V any] struct {
// A shard is one of the individual shards of a map.
type shard[K comparable, V any] struct {
m map[K]awaitableValue[V]
l sync.Mutex
l sync.RWMutex
}

// Set is the equivalent of `map[key] = val`.
Expand Down Expand Up @@ -136,6 +136,13 @@ func (s *shard[K, V]) Set(key K, val V, overwrite bool) (V, bool) {
// Exactly one of the target or channel will be returned.
// The third value is true if it is the first call that is waiting on this value.
func (s *shard[K, V]) Get(key K) (val V, wait <-chan struct{}, first bool) {
s.l.RLock()
if v, ok := s.m[key]; ok {
s.l.RUnlock()
return v.Val, v.Wait, false
}
s.l.RUnlock()

s.l.Lock()
defer s.l.Unlock()
if v, ok := s.m[key]; ok {
Expand All @@ -150,8 +157,8 @@ func (s *shard[K, V]) Get(key K) (val V, wait <-chan struct{}, first bool) {

// Values returns a copy of all the targets currently in the map.
func (s *shard[K, V]) Values() []V {
s.l.Lock()
defer s.l.Unlock()
s.l.RLock()
defer s.l.RUnlock()
ret := make([]V, 0, len(s.m))
for _, v := range s.m {
if v.Wait == nil {
Expand Down
33 changes: 33 additions & 0 deletions src/cmap/cmap_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package cmap

import (
"fmt"
"sort"
"strconv"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)

func hashInts(k int) uint64 {
Expand Down Expand Up @@ -84,3 +86,34 @@ func TestResize(t *testing.T) {
})
}
}

func BenchmarkMapInsertsAndGets(b *testing.B) {
// Attempts to mimic a vaguely realistic blend of writes and (more) reads.
m := New[int, int](DefaultShardCount, hashInts)
var wg, rg errgroup.Group
wg.SetLimit(3)
rg.SetLimit(12)
for i := 0; i < b.N; i++ {
x := i
for j := 0; j < 10; j++ {
wg.Go(func() error {
for k := 0; k < 1000; k++ {
m.Set(x, x)
}
return nil
})
}
for j := 0; j < 100; j++ {
rg.Go(func() error {
for k := 0; k < 1000; k++ {
if y := m.Get(x); y != x && y != 0 {
return fmt.Errorf("incorrect result, was %d, should be %d", y, x)
}
}
return nil
})
}
}
assert.NoError(b, wg.Wait())
assert.NoError(b, rg.Wait())
}

0 comments on commit 36cecf6

Please sign in to comment.