diff --git a/lib/persistedmap/persistedmap.go b/lib/persistedmap/persistedmap.go new file mode 100644 index 00000000..e2f83688 --- /dev/null +++ b/lib/persistedmap/persistedmap.go @@ -0,0 +1,120 @@ +package persistedmap + +import ( + "fmt" + "gopkg.in/yaml.v3" + "io" + "log/slog" + "os" + "sync" + "time" + + "github.com/artie-labs/reader/lib/logger" +) + +type PersistedMap struct { + filePath string + shouldSave bool + mu sync.RWMutex + data map[string]any + flushTicker *time.Ticker +} + +func NewPersistedMap(filePath string) *PersistedMap { + persistedMap := &PersistedMap{ + filePath: filePath, + data: make(map[string]any), + } + + if err := persistedMap.loadFromFile(); err != nil { + slog.Warn("Failed to load persisted map from filepath, starting a new one...", slog.Any("err", err)) + } + + persistedMap.flushTicker = time.NewTicker(30 * time.Second) + go persistedMap.flushRoutine() + + return persistedMap +} + +func (p *PersistedMap) Set(key string, value any) { + p.mu.Lock() + defer p.mu.Unlock() + + p.data[key] = value + p.shouldSave = true +} + +func (p *PersistedMap) Get(key string) (any, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + + value, isOk := p.data[key] + return value, isOk +} + +func (p *PersistedMap) flushRoutine() { + for { + select { + case <-p.flushTicker.C: + if err := p.flush(); err != nil { + logger.Panic("failed to flush data", slog.Any("err", err)) + } + } + } +} + +func (p *PersistedMap) flush() error { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.shouldSave { + return nil + } + + file, err := os.Create(p.filePath) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + + yamlBytes, err := yaml.Marshal(p.data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + + if _, err = file.Write(yamlBytes); err != nil { + return fmt.Errorf("failed to write to file: %w", err) + } + + defer file.Close() + + p.shouldSave = false + return nil +} + +func (p *PersistedMap) loadFromFile() error { + file, err := os.Open(p.filePath) + if err != nil { + return err + } + + defer file.Close() + + readBytes, err := io.ReadAll(file) + if err != nil { + return fmt.Errorf("failed to read file: %w", err) + } + + var data map[string]any + if err = yaml.Unmarshal(readBytes, &data); err != nil { + return fmt.Errorf("failed to unmarshal data: %w", err) + } + + if data == nil { + data = make(map[string]any) + } + + p.mu.Lock() + p.data = data + p.mu.Unlock() + return nil +} diff --git a/lib/persistedmap/persistedmap_test.go b/lib/persistedmap/persistedmap_test.go new file mode 100644 index 00000000..56c6b033 --- /dev/null +++ b/lib/persistedmap/persistedmap_test.go @@ -0,0 +1,60 @@ +package persistedmap + +import ( + "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" + "os" + "testing" +) + +func TestPersistedMap_LoadFromFile(t *testing.T) { + tmpFile, err := os.CreateTemp("", "persistedmap_test") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // Write initial data to the file + initialData := map[string]any{"key1": "value1", "key2": 2} + yamlBytes, err := yaml.Marshal(initialData) + assert.NoError(t, err) + _, err = tmpFile.Write(yamlBytes) + assert.NoError(t, err) + tmpFile.Close() + + // Load the data from the file into PersistedMap + pMap := NewPersistedMap(tmpFile.Name()) + pMap.mu.Lock() + defer pMap.mu.Unlock() + assert.Equal(t, initialData, pMap.data) +} + +func TestPersistedMap_Flush(t *testing.T) { + tmpFile, err := os.CreateTemp("", "persistedmap_test") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + pMap := NewPersistedMap(tmpFile.Name()) + pMap.Set("key1", "value1") + pMap.Set("key2", 2) + + assert.NoError(t, pMap.flush()) + + // Does the data exist? + + val, isOk := pMap.Get("key1") + assert.True(t, isOk) + assert.Equal(t, "value1", val) + + val, isOk = pMap.Get("key2") + assert.Equal(t, 2, val) + assert.True(t, isOk) + + // If I load a new persisted map, does it come back? + pMap2 := NewPersistedMap(tmpFile.Name()) + val, isOk = pMap2.Get("key1") + assert.True(t, isOk) + assert.Equal(t, "value1", val) + + val, isOk = pMap2.Get("key2") + assert.Equal(t, 2, val) + assert.True(t, isOk) +}