diff --git a/.gitignore b/.gitignore index e479b4ad..2fe84b7c 100644 --- a/.gitignore +++ b/.gitignore @@ -57,4 +57,3 @@ go.work # GoReleaser dist/ - diff --git a/Makefile b/Makefile index 589417aa..8d8d6e4e 100644 --- a/Makefile +++ b/Makefile @@ -29,3 +29,9 @@ release: .PHONY: clean clean: go clean -testcache + +.PHONY: generate +generate: + go get github.com/maxbrunsfeld/counterfeiter/v6 + go generate ./... + go mod tidy diff --git a/go.mod b/go.mod index cde9af1d..3524b614 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.22 require ( github.com/DataDog/datadog-go v4.8.3+incompatible - github.com/artie-labs/transfer v1.24.2 + github.com/artie-labs/transfer v1.24.4 github.com/aws/aws-sdk-go v1.44.327 github.com/aws/aws-sdk-go-v2 v1.18.1 github.com/aws/aws-sdk-go-v2/config v1.18.19 diff --git a/go.sum b/go.sum index 09ac8565..93bac16c 100644 --- a/go.sum +++ b/go.sum @@ -91,8 +91,8 @@ github.com/apache/thrift v0.0.0-20181112125854-24918abba929/go.mod h1:cp2SuWMxlE github.com/apache/thrift v0.14.2/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.16.0 h1:qEy6UW60iVOlUy+b9ZR0d5WzUWYGOo4HfopoyBaNmoY= github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU= -github.com/artie-labs/transfer v1.24.2 h1:FbGxHbx7hEwxMN/X22QJztnrGr0W6plrEB7leb7e2Eo= -github.com/artie-labs/transfer v1.24.2/go.mod h1:mlDGYVa9CH93Rrcsh2bxrZW6HxSG65lnUjSOot8+oIc= +github.com/artie-labs/transfer v1.24.4 h1:JvHDV4g+MduJyeKHl7TGf6JzsoFejgktp1u2PqUoKIQ= +github.com/artie-labs/transfer v1.24.4/go.mod h1:mlDGYVa9CH93Rrcsh2bxrZW6HxSG65lnUjSOot8+oIc= github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go v1.44.327 h1:ZS8oO4+7MOBLhkdwIhgtVeDzCeWOlTfKJS7EgggbIEY= github.com/aws/aws-sdk-go v1.44.327/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= diff --git a/lib/mocks/client.mock.go b/lib/mocks/client.mock.go new file mode 100644 index 00000000..a7320cd1 --- /dev/null +++ b/lib/mocks/client.mock.go @@ -0,0 +1,208 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package mocks + +import ( + "sync" + "time" + + "github.com/artie-labs/reader/lib/mtr" +) + +type FakeClient struct { + CountStub func(string, int64, map[string]string) + countMutex sync.RWMutex + countArgsForCall []struct { + arg1 string + arg2 int64 + arg3 map[string]string + } + GaugeStub func(string, float64, map[string]string) + gaugeMutex sync.RWMutex + gaugeArgsForCall []struct { + arg1 string + arg2 float64 + arg3 map[string]string + } + IncrStub func(string, map[string]string) + incrMutex sync.RWMutex + incrArgsForCall []struct { + arg1 string + arg2 map[string]string + } + TimingStub func(string, time.Duration, map[string]string) + timingMutex sync.RWMutex + timingArgsForCall []struct { + arg1 string + arg2 time.Duration + arg3 map[string]string + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeClient) Count(arg1 string, arg2 int64, arg3 map[string]string) { + fake.countMutex.Lock() + fake.countArgsForCall = append(fake.countArgsForCall, struct { + arg1 string + arg2 int64 + arg3 map[string]string + }{arg1, arg2, arg3}) + stub := fake.CountStub + fake.recordInvocation("Count", []interface{}{arg1, arg2, arg3}) + fake.countMutex.Unlock() + if stub != nil { + fake.CountStub(arg1, arg2, arg3) + } +} + +func (fake *FakeClient) CountCallCount() int { + fake.countMutex.RLock() + defer fake.countMutex.RUnlock() + return len(fake.countArgsForCall) +} + +func (fake *FakeClient) CountCalls(stub func(string, int64, map[string]string)) { + fake.countMutex.Lock() + defer fake.countMutex.Unlock() + fake.CountStub = stub +} + +func (fake *FakeClient) CountArgsForCall(i int) (string, int64, map[string]string) { + fake.countMutex.RLock() + defer fake.countMutex.RUnlock() + argsForCall := fake.countArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeClient) Gauge(arg1 string, arg2 float64, arg3 map[string]string) { + fake.gaugeMutex.Lock() + fake.gaugeArgsForCall = append(fake.gaugeArgsForCall, struct { + arg1 string + arg2 float64 + arg3 map[string]string + }{arg1, arg2, arg3}) + stub := fake.GaugeStub + fake.recordInvocation("Gauge", []interface{}{arg1, arg2, arg3}) + fake.gaugeMutex.Unlock() + if stub != nil { + fake.GaugeStub(arg1, arg2, arg3) + } +} + +func (fake *FakeClient) GaugeCallCount() int { + fake.gaugeMutex.RLock() + defer fake.gaugeMutex.RUnlock() + return len(fake.gaugeArgsForCall) +} + +func (fake *FakeClient) GaugeCalls(stub func(string, float64, map[string]string)) { + fake.gaugeMutex.Lock() + defer fake.gaugeMutex.Unlock() + fake.GaugeStub = stub +} + +func (fake *FakeClient) GaugeArgsForCall(i int) (string, float64, map[string]string) { + fake.gaugeMutex.RLock() + defer fake.gaugeMutex.RUnlock() + argsForCall := fake.gaugeArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeClient) Incr(arg1 string, arg2 map[string]string) { + fake.incrMutex.Lock() + fake.incrArgsForCall = append(fake.incrArgsForCall, struct { + arg1 string + arg2 map[string]string + }{arg1, arg2}) + stub := fake.IncrStub + fake.recordInvocation("Incr", []interface{}{arg1, arg2}) + fake.incrMutex.Unlock() + if stub != nil { + fake.IncrStub(arg1, arg2) + } +} + +func (fake *FakeClient) IncrCallCount() int { + fake.incrMutex.RLock() + defer fake.incrMutex.RUnlock() + return len(fake.incrArgsForCall) +} + +func (fake *FakeClient) IncrCalls(stub func(string, map[string]string)) { + fake.incrMutex.Lock() + defer fake.incrMutex.Unlock() + fake.IncrStub = stub +} + +func (fake *FakeClient) IncrArgsForCall(i int) (string, map[string]string) { + fake.incrMutex.RLock() + defer fake.incrMutex.RUnlock() + argsForCall := fake.incrArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeClient) Timing(arg1 string, arg2 time.Duration, arg3 map[string]string) { + fake.timingMutex.Lock() + fake.timingArgsForCall = append(fake.timingArgsForCall, struct { + arg1 string + arg2 time.Duration + arg3 map[string]string + }{arg1, arg2, arg3}) + stub := fake.TimingStub + fake.recordInvocation("Timing", []interface{}{arg1, arg2, arg3}) + fake.timingMutex.Unlock() + if stub != nil { + fake.TimingStub(arg1, arg2, arg3) + } +} + +func (fake *FakeClient) TimingCallCount() int { + fake.timingMutex.RLock() + defer fake.timingMutex.RUnlock() + return len(fake.timingArgsForCall) +} + +func (fake *FakeClient) TimingCalls(stub func(string, time.Duration, map[string]string)) { + fake.timingMutex.Lock() + defer fake.timingMutex.Unlock() + fake.TimingStub = stub +} + +func (fake *FakeClient) TimingArgsForCall(i int) (string, time.Duration, map[string]string) { + fake.timingMutex.RLock() + defer fake.timingMutex.RUnlock() + argsForCall := fake.timingArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeClient) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.countMutex.RLock() + defer fake.countMutex.RUnlock() + fake.gaugeMutex.RLock() + defer fake.gaugeMutex.RUnlock() + fake.incrMutex.RLock() + defer fake.incrMutex.RUnlock() + fake.timingMutex.RLock() + defer fake.timingMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeClient) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ mtr.Client = new(FakeClient) diff --git a/lib/mocks/generate.go b/lib/mocks/generate.go new file mode 100644 index 00000000..31972167 --- /dev/null +++ b/lib/mocks/generate.go @@ -0,0 +1,4 @@ +package mocks + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate +//counterfeiter:generate -o=client.mock.go ../mtr Client diff --git a/writers/transfer/writer.go b/writers/transfer/writer.go index 0639b59a..c3e59535 100644 --- a/writers/transfer/writer.go +++ b/writers/transfer/writer.go @@ -27,6 +27,8 @@ type Writer struct { inMemDB *models.DatabaseData tc *kafkalib.TopicConfig destination destination.DataWarehouse + + primaryKeys []string } func NewWriter(cfg config.Config, statsD mtr.Client) (*Writer, error) { @@ -38,7 +40,7 @@ func NewWriter(cfg config.Config, statsD mtr.Client) (*Writer, error) { return nil, fmt.Errorf("kafka config should have exactly one topic config") } - destination, err := utils.LoadDataWarehouse(cfg, nil) + _destination, err := utils.LoadDataWarehouse(cfg, nil) if err != nil { return nil, err } @@ -48,7 +50,7 @@ func NewWriter(cfg config.Config, statsD mtr.Client) (*Writer, error) { statsD: statsD, inMemDB: models.NewMemoryDB(), tc: cfg.Kafka.TopicConfigs[0], - destination: destination, + destination: _destination, }, nil } @@ -98,13 +100,23 @@ func (w *Writer) Write(_ context.Context, messages []lib.RawMessage) error { }() for _, evt := range events { + // Set the primary keys if it's not set already. + if len(w.primaryKeys) == 0 { + var pks []string + for key := range evt.PrimaryKeyMap { + pks = append(pks, key) + } + + w.primaryKeys = pks + } + shouldFlush, flushReason, err := evt.Save(w.cfg, w.inMemDB, w.tc, artie.Message{}) if err != nil { return fmt.Errorf("failed to save event: %w", err) } if shouldFlush { - if err := w.flush(flushReason); err != nil { + if err = w.flush(flushReason); err != nil { return err } } @@ -154,7 +166,7 @@ func (w *Writer) flush(reason string) error { } tableData.ResetTempTableSuffix() - if err := w.destination.Append(tableData.TableData); err != nil { + if err = w.destination.Append(tableData.TableData); err != nil { tags["what"] = "merge_fail" tags["retryable"] = fmt.Sprint(w.destination.IsRetryableError(err)) return fmt.Errorf("failed to append data to destination: %w", err) @@ -164,6 +176,10 @@ func (w *Writer) flush(reason string) error { } func (w *Writer) OnComplete() error { + if len(w.primaryKeys) == 0 { + return fmt.Errorf("primary keys not set") + } + if err := w.flush("complete"); err != nil { return err } @@ -176,7 +192,7 @@ func (w *Writer) OnComplete() error { slog.Info("Running dedupe...", slog.String("table", tableName)) tableID := w.destination.IdentifierFor(*w.tc, tableName) start := time.Now() - if err = w.destination.Dedupe(tableID); err != nil { + if err = w.destination.Dedupe(tableID, w.primaryKeys, *w.tc); err != nil { return err } slog.Info("Dedupe complete", slog.String("table", tableName), slog.Duration("duration", time.Since(start))) diff --git a/writers/transfer/writer_test.go b/writers/transfer/writer_test.go index 636b49aa..af353c3c 100644 --- a/writers/transfer/writer_test.go +++ b/writers/transfer/writer_test.go @@ -1,6 +1,10 @@ package transfer import ( + "context" + "github.com/artie-labs/reader/lib" + "github.com/artie-labs/reader/lib/mocks" + "github.com/artie-labs/transfer/lib/cdc/util" "testing" transferCfg "github.com/artie-labs/transfer/lib/config" @@ -45,3 +49,41 @@ func TestWriter_MessageToEvent(t *testing.T) { "string": "Hello, world!", }, evtOut.Data) } + +func TestWriter_Write(t *testing.T) { + var rawMsgs []lib.RawMessage + for range 100 { + rawMsgs = append(rawMsgs, lib.NewRawMessage( + "topic-suffix", + map[string]any{"key": "value"}, + &util.SchemaEventPayload{ + Payload: util.Payload{ + After: map[string]any{"a": "b"}, + Source: util.Source{ + TsMs: 1000, + Table: "table", + }, + Operation: "c", + }, + }, + )) + } + + writer, err := NewWriter(transferCfg.Config{ + Mode: transferCfg.Replication, + Output: "test", + Kafka: &transferCfg.Kafka{ + TopicConfigs: []*kafkalib.TopicConfig{ + { + TableName: "table", + }, + }, + }, + }, &mocks.FakeClient{}) + assert.NoError(t, err) + + assert.Nil(t, writer.primaryKeys) + assert.NoError(t, writer.Write(context.Background(), rawMsgs)) + assert.Len(t, writer.primaryKeys, 1) + assert.Equal(t, "key", writer.primaryKeys[0]) +}