From a9b4d5929c48f02fa21213f8e9d2f30f72cdf1dc Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Fri, 29 Mar 2024 12:44:33 -0700 Subject: [PATCH] [dynamodb] Use batch iterator for writing (#320) --- lib/debezium/transformer/transformer_test.go | 42 ++++------ lib/types.go | 30 ++++++++ lib/types_test.go | 81 ++++++++++++++++++++ lib/writer/writer_test.go | 33 +++----- sources/dynamodb/shard.go | 11 +-- sources/dynamodb/snapshot.go | 8 +- sources/dynamodb/stream.go | 3 +- sources/postgres/adapter/transformer_test.go | 32 ++------ 8 files changed, 157 insertions(+), 83 deletions(-) create mode 100644 lib/types_test.go diff --git a/lib/debezium/transformer/transformer_test.go b/lib/debezium/transformer/transformer_test.go index 7cae82c5..aba6fdc2 100644 --- a/lib/debezium/transformer/transformer_test.go +++ b/lib/debezium/transformer/transformer_test.go @@ -8,6 +8,7 @@ import ( "github.com/artie-labs/transfer/lib/debezium" "github.com/stretchr/testify/assert" + "github.com/artie-labs/reader/lib" "github.com/artie-labs/reader/lib/debezium/converters" ) @@ -57,33 +58,20 @@ func (m mockAdatper) NewIterator() (RowsIterator, error) { return m.iter, nil } -type mockIterator struct { - returnErr bool - index int - batches [][]Row -} +type errorIterator struct{} -func (m *mockIterator) HasNext() bool { - return m.index < len(m.batches) +func (m *errorIterator) HasNext() bool { + return true } -func (m *mockIterator) Next() ([]Row, error) { - if m.returnErr { - return nil, fmt.Errorf("test iteration error") - } - - if !m.HasNext() { - return nil, fmt.Errorf("done") - } - result := m.batches[m.index] - m.index++ - return result, nil +func (m *errorIterator) Next() ([]Row, error) { + return nil, fmt.Errorf("test iteration error") } func TestDebeziumTransformer_Iteration(t *testing.T) { { // Empty iterator - transformer, err := NewDebeziumTransformer(mockAdatper{iter: &mockIterator{}}) + transformer, err := NewDebeziumTransformer(mockAdatper{iter: lib.NewBatchIterator([][]Row{})}) assert.NoError(t, err) assert.False(t, transformer.HasNext()) rows, err := transformer.Next() @@ -93,7 +81,7 @@ func TestDebeziumTransformer_Iteration(t *testing.T) { { // One empty batch batches := [][]Row{{}} - transformer, err := NewDebeziumTransformer(mockAdatper{iter: &mockIterator{batches: batches}}) + transformer, err := NewDebeziumTransformer(mockAdatper{iter: lib.NewBatchIterator(batches)}) assert.NoError(t, err) assert.True(t, transformer.HasNext()) rows, err := transformer.Next() @@ -116,7 +104,7 @@ func TestDebeziumTransformer_Iteration(t *testing.T) { }} transformer, err := NewDebeziumTransformer(mockAdatper{ fieldConverters: fieldConverters, - iter: &mockIterator{batches: batches}, + iter: lib.NewBatchIterator(batches), }) assert.NoError(t, err) // First batch @@ -153,7 +141,7 @@ func TestDebeziumTransformer_Iteration(t *testing.T) { } transformer, err := NewDebeziumTransformer(mockAdatper{ fieldConverters: fieldConverters, - iter: &mockIterator{batches: batches}, + iter: lib.NewBatchIterator(batches), }) assert.NoError(t, err) // First batch @@ -192,7 +180,7 @@ func TestDebeziumTransformer_Next(t *testing.T) { mockAdatper{ fieldConverters: fieldConverters, partitionKeys: []string{"foo"}, - iter: &mockIterator{batches: [][]Row{{{"foo": "bar"}}}, returnErr: true}, + iter: &errorIterator{}, }, ) assert.NoError(t, err) @@ -208,7 +196,7 @@ func TestDebeziumTransformer_Next(t *testing.T) { transformer, err := NewDebeziumTransformer(mockAdatper{ fieldConverters: fieldConverters, partitionKeys: []string{"foo"}, - iter: &mockIterator{batches: [][]Row{{{"foo": "bar"}}}}, + iter: lib.NewSingleBatchIterator([]Row{{"foo": "bar"}}), }, ) assert.NoError(t, err) @@ -229,7 +217,7 @@ func TestDebeziumTransformer_Next(t *testing.T) { transformer, err := NewDebeziumTransformer(mockAdatper{ fieldConverters: fieldConverters, partitionKeys: []string{"foo", "qux"}, - iter: &mockIterator{batches: batches}, + iter: lib.NewBatchIterator(batches), }, ) assert.NoError(t, err) @@ -277,7 +265,7 @@ func TestDebeziumTransformer_CreatePayload(t *testing.T) { fieldConverters := []FieldConverter{ {Name: "qux", ValueConverter: testConverter{intField: true, returnErr: true}}, } - transformer, err := NewDebeziumTransformer(mockAdatper{fieldConverters: fieldConverters, iter: &mockIterator{}}) + transformer, err := NewDebeziumTransformer(mockAdatper{fieldConverters: fieldConverters, iter: lib.NewBatchIterator([][]Row{})}) assert.NoError(t, err) _, err = transformer.createPayload(Row{"qux": "quux"}) assert.ErrorContains(t, err, `failed to convert row value for key "qux": test error`) @@ -288,7 +276,7 @@ func TestDebeziumTransformer_CreatePayload(t *testing.T) { {Name: "foo", ValueConverter: testConverter{intField: false}}, {Name: "qux", ValueConverter: testConverter{intField: true}}, } - transformer, err := NewDebeziumTransformer(mockAdatper{fieldConverters: fieldConverters, iter: &mockIterator{}}) + transformer, err := NewDebeziumTransformer(mockAdatper{fieldConverters: fieldConverters, iter: lib.NewBatchIterator([][]Row{})}) assert.NoError(t, err) payload, err := transformer.createPayload(Row{"foo": "bar", "qux": "quux"}) assert.NoError(t, err) diff --git a/lib/types.go b/lib/types.go index d891ec40..3a2f73e9 100644 --- a/lib/types.go +++ b/lib/types.go @@ -1,6 +1,8 @@ package lib import ( + "fmt" + "github.com/artie-labs/transfer/lib/cdc/mongo" "github.com/artie-labs/transfer/lib/cdc/util" ) @@ -38,3 +40,31 @@ func (r RawMessage) GetPayload() any { return r.payload } + +type batchIterator[T any] struct { + index int + batches [][]T +} + +// Returns an iterator that produces multiple batches. +func NewBatchIterator[T any](batches [][]T) *batchIterator[T] { + return &batchIterator[T]{batches: batches} +} + +// Returns an iterator that produces a single batch. +func NewSingleBatchIterator[T any](batches []T) *batchIterator[T] { + return NewBatchIterator([][]T{batches}) +} + +func (bi *batchIterator[T]) HasNext() bool { + return bi.index < len(bi.batches) +} + +func (bi *batchIterator[T]) Next() ([]T, error) { + if !bi.HasNext() { + return nil, fmt.Errorf("iterator has finished") + } + result := bi.batches[bi.index] + bi.index++ + return result, nil +} diff --git a/lib/types_test.go b/lib/types_test.go new file mode 100644 index 00000000..3bcfe14d --- /dev/null +++ b/lib/types_test.go @@ -0,0 +1,81 @@ +package lib + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_BatchIterator(t *testing.T) { + { + // No batches + iter := NewBatchIterator([][]string{}) + assert.False(t, iter.HasNext()) + _, err := iter.Next() + assert.ErrorContains(t, err, "iterator has finished") + } + { + // One empty batch + iter := NewBatchIterator([][]string{{}}) + assert.True(t, iter.HasNext()) + batch, err := iter.Next() + assert.NoError(t, err) + assert.Empty(t, batch) + assert.False(t, iter.HasNext()) + _, err = iter.Next() + assert.ErrorContains(t, err, "iterator has finished") + } + { + // Two non-empty batches one empty batch + iter := NewBatchIterator([][]string{{"a", "b"}, {}, {"c", "d"}}) + assert.True(t, iter.HasNext()) + { + batch, err := iter.Next() + assert.NoError(t, err) + assert.Equal(t, []string{"a", "b"}, batch) + } + + assert.True(t, iter.HasNext()) + { + batch, err := iter.Next() + assert.NoError(t, err) + assert.Empty(t, batch) + } + + assert.True(t, iter.HasNext()) + { + batch, err := iter.Next() + assert.NoError(t, err) + assert.Equal(t, []string{"c", "d"}, batch) + } + + assert.False(t, iter.HasNext()) + _, err := iter.Next() + assert.ErrorContains(t, err, "iterator has finished") + } +} + +func Test_SingleBatchIterator(t *testing.T) { + { + // Empty batch + iter := NewSingleBatchIterator([]string{}) + assert.True(t, iter.HasNext()) + batch, err := iter.Next() + assert.NoError(t, err) + assert.Empty(t, batch) + assert.False(t, iter.HasNext()) + _, err = iter.Next() + assert.ErrorContains(t, err, "iterator has finished") + } + { + // Non-empty batch + iter := NewSingleBatchIterator([]string{"a", "b", "c", "d"}) + assert.True(t, iter.HasNext()) + batch, err := iter.Next() + assert.NoError(t, err) + assert.Equal(t, []string{"a", "b", "c", "d"}, batch) + assert.False(t, iter.HasNext()) + _, err = iter.Next() + assert.ErrorContains(t, err, "iterator has finished") + } +} diff --git a/lib/writer/writer_test.go b/lib/writer/writer_test.go index 2bf8c142..06ac58fb 100644 --- a/lib/writer/writer_test.go +++ b/lib/writer/writer_test.go @@ -22,27 +22,14 @@ func (m *mockDestination) WriteRawMessages(ctx context.Context, msgs []lib.RawMe return nil } -type mockIterator struct { - emitError bool - index int - batches [][]lib.RawMessage -} +type errorIterator struct{} -func (m *mockIterator) HasNext() bool { - return m.index < len(m.batches) +func (m *errorIterator) HasNext() bool { + return true } -func (m *mockIterator) Next() ([]lib.RawMessage, error) { - if m.emitError { - return nil, fmt.Errorf("test iteration error") - } - - if !m.HasNext() { - return nil, fmt.Errorf("done") - } - result := m.batches[m.index] - m.index++ - return result, nil +func (m *errorIterator) Next() ([]lib.RawMessage, error) { + return nil, fmt.Errorf("test iteration error") } func TestWriter_Write(t *testing.T) { @@ -50,7 +37,7 @@ func TestWriter_Write(t *testing.T) { // Empty iterator destination := &mockDestination{} writer := New(destination) - iterator := &mockIterator{} + iterator := lib.NewBatchIterator([][]lib.RawMessage{}) count, err := writer.Write(context.Background(), iterator) assert.NoError(t, err) assert.Equal(t, 0, count) @@ -60,7 +47,7 @@ func TestWriter_Write(t *testing.T) { // Iteration error destination := &mockDestination{} writer := New(destination) - iterator := &mockIterator{emitError: true, batches: [][]lib.RawMessage{{{TopicSuffix: "a"}}}} + iterator := &errorIterator{} _, err := writer.Write(context.Background(), iterator) assert.ErrorContains(t, err, "failed to iterate over messages: test iteration error") assert.Empty(t, destination.messages) @@ -69,7 +56,7 @@ func TestWriter_Write(t *testing.T) { // Two empty batches destination := &mockDestination{} writer := New(destination) - iterator := &mockIterator{batches: [][]lib.RawMessage{{}, {}}} + iterator := lib.NewBatchIterator([][]lib.RawMessage{{}, {}}) count, err := writer.Write(context.Background(), iterator) assert.NoError(t, err) assert.Equal(t, 0, count) @@ -79,7 +66,7 @@ func TestWriter_Write(t *testing.T) { // Three batches, two non-empty destination := &mockDestination{} writer := New(destination) - iterator := &mockIterator{batches: [][]lib.RawMessage{{{TopicSuffix: "a"}}, {}, {{TopicSuffix: "b"}, {TopicSuffix: "c"}}}} + iterator := lib.NewBatchIterator([][]lib.RawMessage{{{TopicSuffix: "a"}}, {}, {{TopicSuffix: "b"}, {TopicSuffix: "c"}}}) count, err := writer.Write(context.Background(), iterator) assert.NoError(t, err) assert.Equal(t, 3, count) @@ -92,7 +79,7 @@ func TestWriter_Write(t *testing.T) { // Destination error destination := &mockDestination{emitError: true} writer := New(destination) - iterator := &mockIterator{batches: [][]lib.RawMessage{{{TopicSuffix: "a"}}}} + iterator := lib.NewSingleBatchIterator([]lib.RawMessage{{TopicSuffix: "a"}}) _, err := writer.Write(context.Background(), iterator) assert.ErrorContains(t, err, "failed to write messages: test write-raw-messages error") assert.Empty(t, destination.messages) diff --git a/sources/dynamodb/shard.go b/sources/dynamodb/shard.go index 9d361b40..2d421c55 100644 --- a/sources/dynamodb/shard.go +++ b/sources/dynamodb/shard.go @@ -9,19 +9,19 @@ import ( "github.com/artie-labs/transfer/lib/ptr" "github.com/aws/aws-sdk-go/service/dynamodbstreams" - "github.com/artie-labs/reader/destinations" "github.com/artie-labs/reader/lib" "github.com/artie-labs/reader/lib/dynamo" "github.com/artie-labs/reader/lib/logger" + "github.com/artie-labs/reader/lib/writer" ) -func (s *StreamStore) ListenToChannel(ctx context.Context, destination destinations.Destination) { +func (s *StreamStore) ListenToChannel(ctx context.Context, _writer writer.Writer) { for shard := range s.shardChan { - go s.processShard(ctx, shard, destination) + go s.processShard(ctx, shard, _writer) } } -func (s *StreamStore) processShard(ctx context.Context, shard *dynamodbstreams.Shard, destination destinations.Destination) { +func (s *StreamStore) processShard(ctx context.Context, shard *dynamodbstreams.Shard, _writer writer.Writer) { var attempts int // Is there another go-routine processing this shard? @@ -97,7 +97,8 @@ func (s *StreamStore) processShard(ctx context.Context, shard *dynamodbstreams.S messages = append(messages, msg.RawMessage()) } - if err = destination.WriteRawMessages(ctx, messages); err != nil { + // TODO: Create an actual iterator over the shards that is passed to the writer. + if _, err = _writer.Write(ctx, lib.NewSingleBatchIterator(messages)); err != nil { logger.Panic("Failed to publish messages, exiting...", slog.Any("err", err)) } diff --git a/sources/dynamodb/snapshot.go b/sources/dynamodb/snapshot.go index 57fdccc1..f2cb701c 100644 --- a/sources/dynamodb/snapshot.go +++ b/sources/dynamodb/snapshot.go @@ -13,6 +13,7 @@ import ( "github.com/artie-labs/reader/lib/dynamo" "github.com/artie-labs/reader/lib/logger" "github.com/artie-labs/reader/lib/s3lib" + "github.com/artie-labs/reader/lib/writer" ) type SnapshotStore struct { @@ -33,7 +34,7 @@ func (s *SnapshotStore) Run(ctx context.Context, destination destinations.Destin return fmt.Errorf("scanning files over bucket failed: %w", err) } - if err := s.streamAndPublish(ctx, destination); err != nil { + if err := s.streamAndPublish(ctx, writer.New(destination)); err != nil { return fmt.Errorf("stream and publish failed: %w", err) } @@ -64,7 +65,7 @@ func (s *SnapshotStore) scanFilesOverBucket() error { return nil } -func (s *SnapshotStore) streamAndPublish(ctx context.Context, destination destinations.Destination) error { +func (s *SnapshotStore) streamAndPublish(ctx context.Context, _writer writer.Writer) error { keys, err := s.retrievePrimaryKeys() if err != nil { return fmt.Errorf("failed to retrieve primary keys: %w", err) @@ -92,7 +93,8 @@ func (s *SnapshotStore) streamAndPublish(ctx context.Context, destination destin messages = append(messages, dynamoMsg.RawMessage()) } - if err = destination.WriteRawMessages(ctx, messages); err != nil { + // TODO: Create an actual iterator over the files that is passed to the writer. + if _, err := _writer.Write(ctx, lib.NewSingleBatchIterator(messages)); err != nil { return fmt.Errorf("failed to publish messages: %w", err) } diff --git a/sources/dynamodb/stream.go b/sources/dynamodb/stream.go index bc0d08e0..b634aa0f 100644 --- a/sources/dynamodb/stream.go +++ b/sources/dynamodb/stream.go @@ -8,6 +8,7 @@ import ( "github.com/artie-labs/reader/config" "github.com/artie-labs/reader/destinations" + "github.com/artie-labs/reader/lib/writer" "github.com/artie-labs/reader/sources/dynamodb/offsets" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodbstreams" @@ -31,7 +32,7 @@ func (s *StreamStore) Run(ctx context.Context, destination destinations.Destinat ticker := time.NewTicker(shardScannerInterval) // Start to subscribe to the channel - go s.ListenToChannel(ctx, destination) + go s.ListenToChannel(ctx, writer.New(destination)) // Scan it for the first time manually, so we don't have to wait 5 mins if err := s.scanForNewShards(); err != nil { diff --git a/sources/postgres/adapter/transformer_test.go b/sources/postgres/adapter/transformer_test.go index 89fd0b12..c98c2b50 100644 --- a/sources/postgres/adapter/transformer_test.go +++ b/sources/postgres/adapter/transformer_test.go @@ -7,6 +7,7 @@ import ( "github.com/artie-labs/transfer/lib/cdc/util" "github.com/stretchr/testify/assert" + "github.com/artie-labs/reader/lib" "github.com/artie-labs/reader/lib/debezium/converters" "github.com/artie-labs/reader/lib/debezium/transformer" "github.com/artie-labs/reader/lib/postgres" @@ -23,21 +24,6 @@ func (m *ErrorRowIterator) Next() ([]map[string]any, error) { return nil, fmt.Errorf("mock error") } -type MockRowIterator struct { - batches [][]map[string]any - index int -} - -func (m *MockRowIterator) HasNext() bool { - return m.index < len(m.batches) -} - -func (m *MockRowIterator) Next() ([]map[string]any, error) { - result := m.batches[m.index] - m.index++ - return result, nil -} - func TestDebeziumTransformer(t *testing.T) { table := postgres.Table{ Schema: "schema", @@ -53,7 +39,7 @@ func TestDebeziumTransformer(t *testing.T) { { dbzTransformer := transformer.NewDebeziumTransformerWithIterator( PostgresAdapter{table: table}, - &MockRowIterator{batches: [][]map[string]any{}}, + lib.NewBatchIterator([][]transformer.Row{}), ) assert.False(t, dbzTransformer.HasNext()) } @@ -80,12 +66,10 @@ func TestDebeziumTransformer(t *testing.T) { {Name: "b", ValueConverter: converters.StringPassthrough{}}, }, }, - &MockRowIterator{ - batches: [][]map[string]any{ - {{"a": "1", "b": "11"}, {"a": "2", "b": "12"}}, - {{"a": "3", "b": "13"}, {"a": "4", "b": "14"}}, - }, - }, + lib.NewBatchIterator([][]transformer.Row{ + {{"a": "1", "b": "11"}, {"a": "2", "b": "12"}}, + {{"a": "3", "b": "13"}, {"a": "4", "b": "14"}}, + }), ) assert.True(t, dbzTransformer.HasNext()) @@ -124,7 +108,7 @@ func TestDebeziumTransformer_NilOptionalSchema(t *testing.T) { }, } - rowData := map[string]any{ + rowData := transformer.Row{ "user_id": int16(123), "name": "Robin", } @@ -137,7 +121,7 @@ func TestDebeziumTransformer_NilOptionalSchema(t *testing.T) { {Name: "name", ValueConverter: converters.StringPassthrough{}}, }, }, - &MockRowIterator{batches: [][]map[string]any{{rowData}}}, + lib.NewSingleBatchIterator([]transformer.Row{rowData}), ) rows, err := dbzTransformer.Next()