diff --git a/concurrent_map.go b/concurrent_map.go index 0e2319c..9365ec9 100644 --- a/concurrent_map.go +++ b/concurrent_map.go @@ -5,11 +5,13 @@ import ( "sync" ) -var SHARD_COUNT = 32 +var DEFAULT_SHARD_COUNT = 32 // A "thread" safe map of type string:Anything. -// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards. -type ConcurrentMap []*ConcurrentMapShared +// To avoid lock bottlenecks this map is dived to several (m.numShards()) map shards. +type ConcurrentMap struct { + maps []*ConcurrentMapShared +} // A "thread" safe string to anything map. type ConcurrentMapShared struct { @@ -19,16 +21,33 @@ type ConcurrentMapShared struct { // Creates a new concurrent map. func New() ConcurrentMap { - m := make(ConcurrentMap, SHARD_COUNT) - for i := 0; i < SHARD_COUNT; i++ { - m[i] = &ConcurrentMapShared{items: make(map[string]interface{})} + m := ConcurrentMap{ + maps: make([]*ConcurrentMapShared, DEFAULT_SHARD_COUNT), + } + for i := 0; i < DEFAULT_SHARD_COUNT; i++ { + m.maps[i] = &ConcurrentMapShared{items: make(map[string]interface{})} } return m } +// Creates a new concurrent map with size shards. +func NewWithSize(size int) ConcurrentMap { + m := ConcurrentMap{ + maps: make([]*ConcurrentMapShared, size), + } + for i := 0; i < size; i++ { + m.maps[i] = &ConcurrentMapShared{items: make(map[string]interface{})} + } + return m +} + +func (m ConcurrentMap) numShards() int { + return len(m.maps) +} + // Returns shard under given key func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared { - return m[uint(fnv32(key))%uint(SHARD_COUNT)] + return m.maps[uint(fnv32(key))%uint(m.numShards())] } func (m ConcurrentMap) MSet(data map[string]interface{}) { @@ -93,8 +112,8 @@ func (m ConcurrentMap) Get(key string) (interface{}, bool) { // Returns the number of elements within the map. func (m ConcurrentMap) Count() int { count := 0 - for i := 0; i < SHARD_COUNT; i++ { - shard := m[i] + for i := 0; i < m.numShards(); i++ { + shard := m.maps[i] shard.RLock() count += len(shard.items) shard.RUnlock() @@ -191,11 +210,11 @@ func (m ConcurrentMap) IterBuffered() <-chan Tuple { // It returns once the size of each buffered channel is determined, // before all the channels are populated using goroutines. func snapshot(m ConcurrentMap) (chans []chan Tuple) { - chans = make([]chan Tuple, SHARD_COUNT) + chans = make([]chan Tuple, m.numShards()) wg := sync.WaitGroup{} - wg.Add(SHARD_COUNT) + wg.Add(m.numShards()) // Foreach shard. - for index, shard := range m { + for index, shard := range m.maps { go func(index int, shard *ConcurrentMapShared) { // Foreach key, value pair. shard.RLock() @@ -249,8 +268,8 @@ type IterCb func(key string, v interface{}) // Callback based iterator, cheapest way to read // all elements in a map. func (m ConcurrentMap) IterCb(fn IterCb) { - for idx := range m { - shard := (m)[idx] + for idx := range m.maps { + shard := (m).maps[idx] shard.RLock() for key, value := range shard.items { fn(key, value) @@ -266,8 +285,8 @@ func (m ConcurrentMap) Keys() []string { go func() { // Foreach shard. wg := sync.WaitGroup{} - wg.Add(SHARD_COUNT) - for _, shard := range m { + wg.Add(m.numShards()) + for _, shard := range m.maps { go func(shard *ConcurrentMapShared) { // Foreach key, value pair. shard.RLock() diff --git a/concurrent_map_bench_test.go b/concurrent_map_bench_test.go index 47cb8d8..b5d258c 100644 --- a/concurrent_map_bench_test.go +++ b/concurrent_map_bench_test.go @@ -177,10 +177,10 @@ func GetSet(m ConcurrentMap, finished chan struct{}) (set func(key, value string } func runWithShards(bench func(b *testing.B), b *testing.B, shardsCount int) { - oldShardsCount := SHARD_COUNT - SHARD_COUNT = shardsCount + oldShardsCount := DEFAULT_SHARD_COUNT + DEFAULT_SHARD_COUNT = shardsCount bench(b) - SHARD_COUNT = oldShardsCount + DEFAULT_SHARD_COUNT = oldShardsCount } func BenchmarkKeys(b *testing.B) { diff --git a/concurrent_map_test.go b/concurrent_map_test.go index c3d5f05..0bf0fd9 100644 --- a/concurrent_map_test.go +++ b/concurrent_map_test.go @@ -14,7 +14,7 @@ type Animal struct { func TestMapCreation(t *testing.T) { m := New() - if m == nil { + if m.maps == nil { t.Error("map is null.") } @@ -426,9 +426,9 @@ func TestConcurrent(t *testing.T) { } func TestJsonMarshal(t *testing.T) { - SHARD_COUNT = 2 + DEFAULT_SHARD_COUNT = 2 defer func() { - SHARD_COUNT = 32 + DEFAULT_SHARD_COUNT = 32 }() expected := "{\"a\":1,\"b\":2}" m := New()