Skip to content

Commit

Permalink
Pull out iteration writing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Mar 28, 2024
1 parent f6b742b commit a8c8a43
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 49 deletions.
8 changes: 1 addition & 7 deletions destinations/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,6 @@ import (
"github.com/artie-labs/reader/lib"
)

// RawMessageIterator yields raw messages one batch at a time.
type RawMessageIterator interface {
// HasNext reports whether another batch can be requested with Next.
HasNext() bool
// Next returns the next batch of raw messages, or an error if the batch
// could not be produced. A batch may be empty.
Next() ([]lib.RawMessage, error)
}

type DestinationWriter interface {
WriteIterator(ctx context.Context, iter RawMessageIterator) (int, error)
type Destination interface {
WriteRawMessages(ctx context.Context, rawMsgs []lib.RawMessage) error
}
20 changes: 0 additions & 20 deletions lib/kafkalib/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/segmentio/kafka-go"

"github.com/artie-labs/reader/config"
"github.com/artie-labs/reader/destinations"
"github.com/artie-labs/reader/lib"
"github.com/artie-labs/reader/lib/iterator"
"github.com/artie-labs/reader/lib/mtr"
Expand Down Expand Up @@ -134,22 +133,3 @@ func (b *BatchWriter) WriteMessages(ctx context.Context, msgs []kafka.Message) e
}
return nil
}

// WriteIterator drains iter and publishes every non-empty batch through
// WriteRawMessages, returning the total number of messages written.
//
// Iteration stops at the first error from either the iterator or the write.
// In that case the returned count is 0, even if earlier batches were already
// written. Progress is logged after each non-empty batch.
func (b *BatchWriter) WriteIterator(ctx context.Context, iter destinations.RawMessageIterator) (int, error) {
	start := time.Now()
	var count int
	for iter.HasNext() {
		msgs, err := iter.Next()
		if err != nil {
			return 0, fmt.Errorf("failed to iterate over messages: %w", err)
		}
		if len(msgs) == 0 {
			// Nothing to publish for an empty batch; move on to the next one.
			continue
		}

		if err = b.WriteRawMessages(ctx, msgs); err != nil {
			return 0, fmt.Errorf("failed to write messages to kafka: %w", err)
		}
		count += len(msgs)
		slog.Info("Scanning progress", slog.Duration("timing", time.Since(start)), slog.Int("count", count))
	}
	return count, nil
}
48 changes: 48 additions & 0 deletions lib/writer/writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package writer

import (
"context"
"fmt"
"log/slog"
"time"

"github.com/artie-labs/reader/destinations"
"github.com/artie-labs/reader/lib"
)

// RawMessageIterator yields raw messages one batch at a time.
type RawMessageIterator interface {
// HasNext reports whether another batch can be requested with Next.
HasNext() bool
// Next returns the next batch of raw messages, or an error if the batch
// could not be produced. A batch may be empty.
Next() ([]lib.RawMessage, error)
}

// Writer copies batches of raw messages from an iterator into a destination.
type Writer struct {
// destination receives every non-empty batch via WriteRawMessages.
destination destinations.Destination
}

// New returns a Writer that sends messages to the given destination.
func New(destination destinations.Destination) Writer {
	return Writer{destination: destination}
}

// Write writes all the messages from an iterator to the destination and
// returns the total number of messages written.
//
// Empty batches are not forwarded to the destination. Iteration stops at the
// first error from either the iterator or the destination; in that case the
// returned count is 0, even if earlier batches were already written.
// Progress is logged once per batch, including empty ones.
func (w *Writer) Write(ctx context.Context, iter RawMessageIterator) (int, error) {
	start := time.Now()
	var count int
	for iter.HasNext() {
		msgs, err := iter.Next()
		if err != nil {
			return 0, fmt.Errorf("failed to iterate over messages: %w", err)
		}

		if len(msgs) > 0 {
			if err = w.destination.WriteRawMessages(ctx, msgs); err != nil {
				return 0, fmt.Errorf("failed to write messages: %w", err)
			}
			count += len(msgs)
		}

		slog.Info("Write progress",
			slog.Duration("timing", time.Since(start)),
			slog.Int("batchSize", len(msgs)),
			slog.Int("total", count),
		)
	}
	return count, nil
}
97 changes: 97 additions & 0 deletions lib/writer/writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package writer

import (
"context"
"fmt"
"testing"

"github.com/artie-labs/reader/lib"
"github.com/stretchr/testify/assert"
)

// mockDestination is a test double that records the messages it is asked to write.
type mockDestination struct {
// messages accumulates the messages from every successful WriteRawMessages call.
messages []lib.RawMessage
// emitError makes WriteRawMessages fail without recording anything.
emitError bool
}

// WriteRawMessages appends msgs to the recorded messages, or returns a fixed
// test error (recording nothing) when emitError is set.
func (m *mockDestination) WriteRawMessages(ctx context.Context, msgs []lib.RawMessage) error {
	if !m.emitError {
		m.messages = append(m.messages, msgs...)
		return nil
	}
	return fmt.Errorf("test write raw messages error")
}

// mockIterator is a test double that replays a fixed list of batches.
type mockIterator struct {
// emitError makes every call to Next fail with a fixed test error.
emitError bool
// index is the position in batches of the next batch to return.
index int
// batches holds the batches to return, in order.
batches [][]lib.RawMessage
}

// HasNext reports whether any batches have not been returned yet.
func (m *mockIterator) HasNext() bool {
	return len(m.batches) > m.index
}

// Next returns the batch at the current index and advances the iterator.
// It fails with a fixed test error when emitError is set, and with "done"
// when no batches remain.
func (m *mockIterator) Next() ([]lib.RawMessage, error) {
	switch {
	case m.emitError:
		return nil, fmt.Errorf("test iteration error")
	case !m.HasNext():
		return nil, fmt.Errorf("done")
	}

	batch := m.batches[m.index]
	m.index++
	return batch, nil
}

// TestWriter_Write exercises Writer.Write with mock iterators and destinations,
// covering an empty iterator, an iterator error, empty batches, a mix of empty
// and non-empty batches, and a destination error.
func TestWriter_Write(t *testing.T) {
{
// Empty iterator: nothing is written and the count is zero.
destination := &mockDestination{}
writer := New(destination)
iterator := &mockIterator{}
count, err := writer.Write(context.Background(), iterator)
assert.NoError(t, err)
assert.Equal(t, 0, count)
assert.Empty(t, destination.messages)
}
{
// Iteration error: the wrapped iterator error is returned and nothing is written.
destination := &mockDestination{}
writer := New(destination)
iterator := &mockIterator{emitError: true, batches: [][]lib.RawMessage{{{TopicSuffix: "a"}}}}
_, err := writer.Write(context.Background(), iterator)
assert.ErrorContains(t, err, "failed to iterate over messages: test iteration error")
assert.Empty(t, destination.messages)
}
{
// Two empty batches: both are consumed but nothing reaches the destination.
destination := &mockDestination{}
writer := New(destination)
iterator := &mockIterator{batches: [][]lib.RawMessage{{}, {}}}
count, err := writer.Write(context.Background(), iterator)
assert.NoError(t, err)
assert.Equal(t, 0, count)
assert.Empty(t, destination.messages)
}
{
// Three batches, two non-empty: all three messages are written and counted.
destination := &mockDestination{}
writer := New(destination)
iterator := &mockIterator{batches: [][]lib.RawMessage{{{TopicSuffix: "a"}}, {}, {{TopicSuffix: "b"}, {TopicSuffix: "c"}}}}
count, err := writer.Write(context.Background(), iterator)
assert.NoError(t, err)
assert.Equal(t, 3, count)
assert.Len(t, destination.messages, 3)
}
{
// Destination error: the wrapped destination error is returned and nothing is recorded.
destination := &mockDestination{emitError: true}
writer := New(destination)
iterator := &mockIterator{batches: [][]lib.RawMessage{{{TopicSuffix: "a"}}}}
_, err := writer.Write(context.Background(), iterator)
assert.ErrorContains(t, err, "failed to write messages: test write raw messages error")
assert.Empty(t, destination.messages)
}
}
6 changes: 3 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func buildSource(cfg *config.Settings) (sources.Source, error) {
}
}

func buildDestination(ctx context.Context, cfg *config.Settings, statsD mtr.Client) (destinations.DestinationWriter, error) {
func buildDestination(ctx context.Context, cfg *config.Settings, statsD mtr.Client) (destinations.Destination, error) {
switch cfg.Destination {
case config.DestinationKafka:
kafkaCfg := cfg.Kafka
Expand Down Expand Up @@ -88,7 +88,7 @@ func main() {
logger.Fatal("Failed to set up metrics", slog.Any("err", err))
}

writer, err := buildDestination(ctx, cfg, statsD)
destination, err := buildDestination(ctx, cfg, statsD)
if err != nil {
logger.Fatal(fmt.Sprintf("Failed to init '%s' destination", cfg.Destination), slog.Any("err", err))
}
Expand All @@ -99,7 +99,7 @@ func main() {
}
defer source.Close()

if err = source.Run(ctx, writer); err != nil {
if err = source.Run(ctx, destination); err != nil {
logger.Fatal("Failed to run",
slog.Any("err", err),
slog.String("source", string(cfg.Source)),
Expand Down
8 changes: 4 additions & 4 deletions sources/dynamodb/shard.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ import (
"github.com/artie-labs/reader/lib/logger"
)

func (s *StreamStore) ListenToChannel(ctx context.Context, writer destinations.DestinationWriter) {
func (s *StreamStore) ListenToChannel(ctx context.Context, destination destinations.Destination) {
for shard := range s.shardChan {
go s.processShard(ctx, shard, writer)
go s.processShard(ctx, shard, destination)
}
}

func (s *StreamStore) processShard(ctx context.Context, shard *dynamodbstreams.Shard, writer destinations.DestinationWriter) {
func (s *StreamStore) processShard(ctx context.Context, shard *dynamodbstreams.Shard, destination destinations.Destination) {
var attempts int

// Is there another go-routine processing this shard?
Expand Down Expand Up @@ -97,7 +97,7 @@ func (s *StreamStore) processShard(ctx context.Context, shard *dynamodbstreams.S
messages = append(messages, msg.RawMessage())
}

if err = writer.WriteRawMessages(ctx, messages); err != nil {
if err = destination.WriteRawMessages(ctx, messages); err != nil {
logger.Panic("Failed to publish messages, exiting...", slog.Any("err", err))
}

Expand Down
8 changes: 4 additions & 4 deletions sources/dynamodb/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ func (s *SnapshotStore) Close() error {
return nil
}

func (s *SnapshotStore) Run(ctx context.Context, writer destinations.DestinationWriter) error {
func (s *SnapshotStore) Run(ctx context.Context, destination destinations.Destination) error {
if err := s.scanFilesOverBucket(); err != nil {
return fmt.Errorf("scanning files over bucket failed: %w", err)
}

if err := s.streamAndPublish(ctx, writer); err != nil {
if err := s.streamAndPublish(ctx, destination); err != nil {
return fmt.Errorf("stream and publish failed: %w", err)
}

Expand Down Expand Up @@ -64,7 +64,7 @@ func (s *SnapshotStore) scanFilesOverBucket() error {
return nil
}

func (s *SnapshotStore) streamAndPublish(ctx context.Context, writer destinations.DestinationWriter) error {
func (s *SnapshotStore) streamAndPublish(ctx context.Context, destination destinations.Destination) error {
keys, err := s.retrievePrimaryKeys()
if err != nil {
return fmt.Errorf("failed to retrieve primary keys: %w", err)
Expand Down Expand Up @@ -92,7 +92,7 @@ func (s *SnapshotStore) streamAndPublish(ctx context.Context, writer destination
messages = append(messages, dynamoMsg.RawMessage())
}

if err = writer.WriteRawMessages(ctx, messages); err != nil {
if err = destination.WriteRawMessages(ctx, messages); err != nil {
return fmt.Errorf("failed to publish messages: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions sources/dynamodb/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ func (s *StreamStore) Close() error {
return nil
}

func (s *StreamStore) Run(ctx context.Context, writer destinations.DestinationWriter) error {
func (s *StreamStore) Run(ctx context.Context, destination destinations.Destination) error {
ticker := time.NewTicker(shardScannerInterval)

// Start to subscribe to the channel
go s.ListenToChannel(ctx, writer)
go s.ListenToChannel(ctx, destination)

// Scan it for the first time manually, so we don't have to wait 5 mins
if err := s.scanForNewShards(); err != nil {
Expand Down
7 changes: 5 additions & 2 deletions sources/mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/artie-labs/reader/config"
"github.com/artie-labs/reader/destinations"
"github.com/artie-labs/reader/lib/writer"
)

type Source struct {
Expand Down Expand Up @@ -49,7 +50,9 @@ func (s *Source) Close() error {
return nil
}

func (s *Source) Run(ctx context.Context, writer destinations.DestinationWriter) error {
func (s *Source) Run(ctx context.Context, destination destinations.Destination) error {
_writer := writer.New(destination)

for _, collection := range s.cfg.Collections {
snapshotStartTime := time.Now()

Expand All @@ -60,7 +63,7 @@ func (s *Source) Run(ctx context.Context, writer destinations.DestinationWriter)
)

iterator := newIterator(s.db, collection, s.cfg)
count, err := writer.WriteIterator(ctx, iterator)
count, err := _writer.Write(ctx, iterator)
if err != nil {
return fmt.Errorf("failed to snapshot for collection %s: %w", collection.Name, err)
}
Expand Down
11 changes: 7 additions & 4 deletions sources/mysql/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/artie-labs/reader/destinations"
"github.com/artie-labs/reader/lib/debezium/transformer"
"github.com/artie-labs/reader/lib/rdbms"
"github.com/artie-labs/reader/lib/writer"
"github.com/artie-labs/reader/sources/mysql/adapter"
)

Expand All @@ -37,16 +38,18 @@ func (s Source) Close() error {
return s.db.Close()
}

func (s *Source) Run(ctx context.Context, writer destinations.DestinationWriter) error {
func (s *Source) Run(ctx context.Context, destination destinations.Destination) error {
_writer := writer.New(destination)

for _, tableCfg := range s.cfg.Tables {
if err := s.snapshotTable(ctx, writer, *tableCfg); err != nil {
if err := s.snapshotTable(ctx, _writer, *tableCfg); err != nil {
return err
}
}
return nil
}

func (s Source) snapshotTable(ctx context.Context, writer destinations.DestinationWriter, tableCfg config.MySQLTable) error {
func (s Source) snapshotTable(ctx context.Context, _writer writer.Writer, tableCfg config.MySQLTable) error {
logger := slog.With(slog.String("table", tableCfg.Name), slog.String("database", s.cfg.Database))
snapshotStartTime := time.Now()

Expand All @@ -66,7 +69,7 @@ func (s Source) snapshotTable(ctx context.Context, writer destinations.Destinati
}

logger.Info("Scanning table...", slog.Any("batchSize", tableCfg.GetBatchSize()))
count, err := writer.WriteIterator(ctx, dbzTransformer)
count, err := _writer.Write(ctx, dbzTransformer)
if err != nil {
return fmt.Errorf("failed to snapshot for table %s: %w", tableCfg.Name, err)
}
Expand Down
7 changes: 5 additions & 2 deletions sources/postgres/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/artie-labs/reader/destinations"
"github.com/artie-labs/reader/lib/debezium/transformer"
"github.com/artie-labs/reader/lib/rdbms"
"github.com/artie-labs/reader/lib/writer"
"github.com/artie-labs/reader/sources/postgres/adapter"
)

Expand All @@ -38,7 +39,9 @@ func (s *Source) Close() error {
return s.db.Close()
}

func (s *Source) Run(ctx context.Context, writer destinations.DestinationWriter) error {
func (s *Source) Run(ctx context.Context, destination destinations.Destination) error {
_writer := writer.New(destination)

for _, tableCfg := range s.cfg.Tables {
logger := slog.With(slog.String("schema", tableCfg.Schema), slog.String("table", tableCfg.Name))
snapshotStartTime := time.Now()
Expand All @@ -59,7 +62,7 @@ func (s *Source) Run(ctx context.Context, writer destinations.DestinationWriter)
}

logger.Info("Scanning table...", slog.Any("batchSize", tableCfg.GetBatchSize()))
count, err := writer.WriteIterator(ctx, dbzTransformer)
count, err := _writer.Write(ctx, dbzTransformer)
if err != nil {
return fmt.Errorf("failed to snapshot for table %s: %w", tableCfg.Name, err)
}
Expand Down
2 changes: 1 addition & 1 deletion sources/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import (

type Source interface {
Close() error
Run(ctx context.Context, writer destinations.DestinationWriter) error
Run(ctx context.Context, destination destinations.Destination) error
}

0 comments on commit a8c8a43

Please sign in to comment.