diff --git a/README.md b/README.md index 77e80e25..7628c9dd 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ The fields have to be one of the types that the sync package supports in order t - sync.Bool, allows for concurrent bool manipulation - sync.Secret, allows for concurrent secret manipulation. Secrets can only be strings - sync.TimeDuration, allows for concurrent time.duration manipulation. +- sync.StringMap, allows for concurrent map[string]string manipulation. For sensitive configuration (passwords, tokens, etc.) that shouldn't be printed in log, you can use the `Secret` flavor of `sync` types. If one of these is selected, then at harvester log instead of the real value the text `***` will be displayed. diff --git a/sync/sync.go b/sync/sync.go index 38eacdbd..4133e7a1 100644 --- a/sync/sync.go +++ b/sync/sync.go @@ -2,8 +2,10 @@ package sync import ( + "bytes" "fmt" "strconv" + "strings" "sync" "time" ) @@ -222,3 +224,55 @@ func (s *Secret) SetString(val string) error { s.Set(val) return nil } + +// StringMap is a map[string]string type with concurrent access support. +type StringMap struct { + rw sync.RWMutex + value map[string]string +} + +// Get returns the internal value. +func (s *StringMap) Get() map[string]string { + s.rw.RLock() + defer s.rw.RUnlock() + return s.value +} + +// Set a value. +func (s *StringMap) Set(value map[string]string) { + s.rw.Lock() + defer s.rw.Unlock() + s.value = value +} + +// String returns a string representation of the value. +func (s *StringMap) String() string { + s.rw.RLock() + defer s.rw.RUnlock() + b := new(bytes.Buffer) + firstChar := "" + for key, value := range s.value { + _, _ = fmt.Fprintf(b, "%s%s=%q", firstChar, key, value) + firstChar = "," + } + return b.String() +} + +// SetString parses and sets a value from string type. +func (s *StringMap) SetString(val string) error { + dict := make(map[string]string) + if val == "" || strings.TrimSpace(val) == "" { + s.Set(dict) + return nil + } + for _, pair := range strings.Split(val, ",") { + items := strings.SplitN(pair, ":", 2) + if len(items) != 2 { + return fmt.Errorf("map must be formatted as `key:value`, got %q", pair) + } + key, value := strings.TrimSpace(items[0]), strings.TrimSpace(items[1]) + dict[key] = value + } + s.Set(dict) + return nil +} diff --git a/sync/sync_test.go b/sync/sync_test.go index b1e527fe..d6fc0040 100644 --- a/sync/sync_test.go +++ b/sync/sync_test.go @@ -119,3 +119,55 @@ func TestTimeDuration_SetString(t *testing.T) { assert.NoError(t, f.SetString("3s")) assert.Equal(t, 3*time.Second, f.Get()) } + +func TestStringMap(t *testing.T) { + var sm StringMap + ch := make(chan struct{}) + go func() { + sm.Set(map[string]string{"key": "value"}) + ch <- struct{}{} + }() + <-ch + assert.Equal(t, map[string]string{"key": "value"}, sm.Get()) + assert.Equal(t, "key=\"value\"", sm.String()) +} + +func TestStringMap_SetString(t *testing.T) { + tests := []struct { + name string + input string + result map[string]string + throwsError bool + }{ + {"empty", "", map[string]string{}, false}, + {"empty with spaces", " ", map[string]string{}, false}, + {"single item", "key:value", map[string]string{"key": "value"}, false}, + {"single item with route as val", "key:http://thing", map[string]string{"key": "http://thing"}, false}, + {"key without value", "key", nil, true}, + {"multiple items", "key1:value,key2:value", map[string]string{"key1": "value", "key2": "value"}, false}, + {"multiple items with spaces", " key1 : value , key2 :value ", map[string]string{"key1": "value", "key2": "value"}, false}, + {"multiple urls", "key1:http://one,key2:https://two", map[string]string{"key1": "http://one", "key2": "https://two"}, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sm := StringMap{} + + err := sm.SetString(test.input) + if test.throwsError { + assert.Error(t, err) + } + + assert.Equal(t, test.result, sm.Get()) + }) + } +} + +func TestStringMap_SetString_DoesntOverrideValueIfError(t *testing.T) { + sm := StringMap{} + + assert.NoError(t, sm.SetString("k1:v1")) + assert.Equal(t, map[string]string{"k1": "v1"}, sm.Get()) + + assert.Error(t, sm.SetString("k1:v1,k2:v2,k3")) + assert.Equal(t, map[string]string{"k1": "v1"}, sm.Get()) +}