From 232cbd1b049b99c410cc173e1f9748923f3ba953 Mon Sep 17 00:00:00 2001
From: Robin Tang
Date: Mon, 25 Mar 2024 14:17:32 -0400
Subject: [PATCH] Minor Reader Improvements (#315)

---
 lib/kafkalib/errors.go                       |  6 +--
 lib/kafkalib/errors_test.go                  |  2 +-
 lib/kafkalib/kafka.go                        |  2 +-
 lib/kafkalib/message.go                      |  2 +-
 lib/kafkalib/message_test.go                 |  2 +-
 lib/kafkalib/writer.go                       | 51 +++++++++++---------
 sources/mysql/adapter/adapter.go             | 24 ++++-----
 sources/postgres/adapter/adapter.go          | 22 ++++-----
 sources/postgres/adapter/adapter_test.go     |  4 +-
 sources/postgres/adapter/transformer_test.go |  8 +--
 10 files changed, 63 insertions(+), 60 deletions(-)

diff --git a/lib/kafkalib/errors.go b/lib/kafkalib/errors.go
index 696d132f..ff8fc1d8 100644
--- a/lib/kafkalib/errors.go
+++ b/lib/kafkalib/errors.go
@@ -2,13 +2,13 @@ package kafkalib
 
 import "strings"
 
-func IsExceedMaxMessageBytesErr(err error) bool {
+func isExceedMaxMessageBytesErr(err error) bool {
     return err != nil && strings.Contains(err.Error(), "Message Size Too Large: the server has a configurable maximum message size to avoid unbounded memory allocation and the client attempted to produce a message larger than this maximum")
 }
 
-// RetryableError - returns true if the error is retryable
+// isRetryableError - returns true if the error is retryable
 // If it's retryable, you need to reload the Kafka client.
-func RetryableError(err error) bool {
+func isRetryableError(err error) bool {
     return err != nil && strings.Contains(err.Error(), "Topic Authorization Failed: the client is not authorized to access the requested topic")
 }
diff --git a/lib/kafkalib/errors_test.go b/lib/kafkalib/errors_test.go
index 00ff8f54..53ee7973 100644
--- a/lib/kafkalib/errors_test.go
+++ b/lib/kafkalib/errors_test.go
@@ -27,7 +27,7 @@ func TestIsExceedMaxMessageBytesErr(t *testing.T) {
     }
 
     for _, tc := range tcs {
-        actual := IsExceedMaxMessageBytesErr(tc.err)
+        actual := isExceedMaxMessageBytesErr(tc.err)
         assert.Equal(t, tc.expected, actual, tc.err)
     }
 }
diff --git a/lib/kafkalib/kafka.go b/lib/kafkalib/kafka.go
index 321c0a92..a8695033 100644
--- a/lib/kafkalib/kafka.go
+++ b/lib/kafkalib/kafka.go
@@ -14,7 +14,7 @@ import (
     "github.com/artie-labs/reader/config"
 )
 
-func NewWriter(ctx context.Context, cfg config.Kafka) (*kafka.Writer, error) {
+func newWriter(ctx context.Context, cfg config.Kafka) (*kafka.Writer, error) {
     slog.Info("Setting kafka bootstrap URLs", slog.Any("urls", cfg.BootstrapAddresses()))
     writer := &kafka.Writer{
         Addr: kafka.TCP(cfg.BootstrapAddresses()...),
diff --git a/lib/kafkalib/message.go b/lib/kafkalib/message.go
index 43974b68..51acf258 100644
--- a/lib/kafkalib/message.go
+++ b/lib/kafkalib/message.go
@@ -6,7 +6,7 @@ import (
     "github.com/segmentio/kafka-go"
 )
 
-func NewMessage(topic string, partitionKey map[string]any, value any) (kafka.Message, error) {
+func newMessage(topic string, partitionKey map[string]any, value any) (kafka.Message, error) {
     valueBytes, err := json.Marshal(value)
     if err != nil {
         return kafka.Message{}, err
diff --git a/lib/kafkalib/message_test.go b/lib/kafkalib/message_test.go
index 2d51b054..94ebbe8f 100644
--- a/lib/kafkalib/message_test.go
+++ b/lib/kafkalib/message_test.go
@@ -19,7 +19,7 @@ func TestNewMessage(t *testing.T) {
         },
     }
 
-    msg, err := NewMessage("topic", map[string]any{"key": "value"}, payload)
+    msg, err := newMessage("topic", map[string]any{"key": "value"}, payload)
     assert.NoError(t, err)
     assert.Equal(t, "topic", msg.Topic)
     assert.Equal(t, `{"key":"value"}`, string(msg.Key))
diff --git a/lib/kafkalib/writer.go b/lib/kafkalib/writer.go
index 65661da1..1a4f344d 100644
--- a/lib/kafkalib/writer.go
+++ b/lib/kafkalib/writer.go
@@ -32,51 +32,54 @@ func NewBatchWriter(ctx context.Context, cfg config.Kafka, statsD mtr.Client) (*
         return nil, fmt.Errorf("kafka topic prefix cannot be empty")
     }
 
-    writer, err := NewWriter(ctx, cfg)
+    writer, err := newWriter(ctx, cfg)
     if err != nil {
         return nil, err
     }
     return &BatchWriter{writer, cfg, statsD}, nil
 }
 
-func (w *BatchWriter) reload(ctx context.Context) error {
+func (b *BatchWriter) reload(ctx context.Context) error {
     slog.Info("Reloading kafka writer")
-    if err := w.writer.Close(); err != nil {
+    if err := b.writer.Close(); err != nil {
         return err
     }
 
-    writer, err := NewWriter(ctx, w.cfg)
+    writer, err := newWriter(ctx, b.cfg)
     if err != nil {
         return err
     }
 
-    w.writer = writer
+    b.writer = writer
     return nil
 }
 
-func buildKafkaMessages(cfg *config.Kafka, msgs []lib.RawMessage) ([]kafka.Message, error) {
-    result := make([]kafka.Message, len(msgs))
-    for i, msg := range msgs {
-        topic := fmt.Sprintf("%s.%s", cfg.TopicPrefix, msg.TopicSuffix)
-        kMsg, err := NewMessage(topic, msg.PartitionKey, msg.GetPayload())
+func (b *BatchWriter) buildKafkaMessages(rawMsgs []lib.RawMessage) ([]kafka.Message, error) {
+    var kafkaMsgs []kafka.Message
+    for _, rawMsg := range rawMsgs {
+        topic := fmt.Sprintf("%s.%s", b.cfg.TopicPrefix, rawMsg.TopicSuffix)
+        kafkaMsg, err := newMessage(topic, rawMsg.PartitionKey, rawMsg.GetPayload())
         if err != nil {
             return nil, err
         }
-        result[i] = kMsg
+
+        kafkaMsgs = append(kafkaMsgs, kafkaMsg)
     }
-    return result, nil
+
+    return kafkaMsgs, nil
 }
 
-func (w *BatchWriter) WriteRawMessages(ctx context.Context, rawMsgs []lib.RawMessage) error {
-    msgs, err := buildKafkaMessages(&w.cfg, rawMsgs)
+func (b *BatchWriter) WriteRawMessages(ctx context.Context, rawMsgs []lib.RawMessage) error {
+    kafkaMsgs, err := b.buildKafkaMessages(rawMsgs)
     if err != nil {
         return fmt.Errorf("failed to encode kafka messages: %w", err)
     }
-    return w.WriteMessages(ctx, msgs)
+
+    return b.WriteMessages(ctx, kafkaMsgs)
 }
 
-func (w *BatchWriter) WriteMessages(ctx context.Context, msgs []kafka.Message) error {
-    chunkSize := w.cfg.GetPublishSize()
+func (b *BatchWriter) WriteMessages(ctx context.Context, msgs []kafka.Message) error {
+    chunkSize := b.cfg.GetPublishSize()
     if chunkSize < 1 {
         return fmt.Errorf("chunk size is too small")
     }
@@ -103,27 +106,27 @@ func (w *BatchWriter) WriteMessages(ctx context.Context, msgs []kafka.Message) e
             )
             time.Sleep(sleepDuration)
 
-            if RetryableError(kafkaErr) {
-                if reloadErr := w.reload(ctx); reloadErr != nil {
+            if isRetryableError(kafkaErr) {
+                if reloadErr := b.reload(ctx); reloadErr != nil {
                     slog.Warn("Failed to reload kafka writer", slog.Any("err", reloadErr))
                 }
             }
         }
 
-        kafkaErr = w.writer.WriteMessages(ctx, chunk...)
+        kafkaErr = b.writer.WriteMessages(ctx, chunk...)
         if kafkaErr == nil {
             tags["what"] = "success"
             break
         }
 
-        if IsExceedMaxMessageBytesErr(kafkaErr) {
+        if isExceedMaxMessageBytesErr(kafkaErr) {
             slog.Info("Skipping this chunk since the batch exceeded the server")
             kafkaErr = nil
             break
         }
     }
 
-    w.statsD.Count("kafka.publish", int64(len(chunk)), tags)
+    b.statsD.Count("kafka.publish", int64(len(chunk)), tags)
     if kafkaErr != nil {
         return fmt.Errorf("failed to write message: %w, approxSize: %d", kafkaErr, size.GetApproxSize(chunk))
     }
@@ -136,7 +139,7 @@ type messageIterator interface {
     Next() ([]lib.RawMessage, error)
 }
 
-func (w *BatchWriter) WriteIterator(ctx context.Context, iter messageIterator) (int, error) {
+func (b *BatchWriter) WriteIterator(ctx context.Context, iter messageIterator) (int, error) {
     start := time.Now()
     var count int
     for iter.HasNext() {
@@ -145,7 +148,7 @@ func (w *BatchWriter) WriteIterator(ctx context.Context, iter messageIterator) (
             return 0, fmt.Errorf("failed to iterate over messages: %w", err)
         } else if len(msgs) > 0 {
-            if err = w.WriteRawMessages(ctx, msgs); err != nil {
+            if err = b.WriteRawMessages(ctx, msgs); err != nil {
                 return 0, fmt.Errorf("failed to write messages to kafka: %w", err)
             }
             count += len(msgs)
diff --git a/sources/mysql/adapter/adapter.go b/sources/mysql/adapter/adapter.go
index 55963b96..124da1ac 100644
--- a/sources/mysql/adapter/adapter.go
+++ b/sources/mysql/adapter/adapter.go
@@ -18,7 +18,7 @@ import (
 
 const defaultErrorRetries = 10
 
-type mysqlAdapter struct {
+type MySQLAdapter struct {
     db     *sql.DB
     dbName string
     table  mysql.Table
@@ -27,32 +27,32 @@ type mysqlAdapter struct {
     scannerCfg scan.ScannerConfig
 }
 
-func NewMySQLAdapter(db *sql.DB, dbName string, tableCfg config.MySQLTable) (mysqlAdapter, error) {
+func NewMySQLAdapter(db *sql.DB, dbName string, tableCfg config.MySQLTable) (MySQLAdapter, error) {
     slog.Info("Loading metadata for table")
     table, err := mysql.LoadTable(db, tableCfg.Name)
     if err != nil {
-        return mysqlAdapter{}, fmt.Errorf("failed to load metadata for table %s: %w", tableCfg.Name, err)
+        return MySQLAdapter{}, fmt.Errorf("failed to load metadata for table %s: %w", tableCfg.Name, err)
     }
 
     columns, err := column.FilterOutExcludedColumns(table.Columns, tableCfg.ExcludeColumns, table.PrimaryKeys)
     if err != nil {
-        return mysqlAdapter{}, err
+        return MySQLAdapter{}, err
     }
 
     return newMySQLAdapter(db, dbName, *table, columns, tableCfg.ToScannerConfig(defaultErrorRetries))
 }
 
-func newMySQLAdapter(db *sql.DB, dbName string, table mysql.Table, columns []schema.Column, scannerCfg scan.ScannerConfig) (mysqlAdapter, error) {
+func newMySQLAdapter(db *sql.DB, dbName string, table mysql.Table, columns []schema.Column, scannerCfg scan.ScannerConfig) (MySQLAdapter, error) {
     fieldConverters := make([]transformer.FieldConverter, len(columns))
     for i, col := range columns {
         converter, err := valueConverterForType(col.Type, col.Opts)
         if err != nil {
-            return mysqlAdapter{}, fmt.Errorf("failed to build value converter for column %s: %w", col.Name, err)
+            return MySQLAdapter{}, fmt.Errorf("failed to build value converter for column %s: %w", col.Name, err)
         }
         fieldConverters[i] = transformer.FieldConverter{Name: col.Name, ValueConverter: converter}
     }
 
-    return mysqlAdapter{
+    return MySQLAdapter{
         db:     db,
         dbName: dbName,
         table:  table,
@@ -62,23 +62,23 @@ func newMySQLAdapter(db *sql.DB, dbName string, table mysql.Table, columns []sch
     }, nil
 }
 
-func (m mysqlAdapter) TableName() string {
+func (m MySQLAdapter) TableName() string {
     return m.table.Name
 }
 
-func (m mysqlAdapter) TopicSuffix() string {
+func (m MySQLAdapter) TopicSuffix() string {
     return fmt.Sprintf("%s.%s", m.dbName, strings.ReplaceAll(m.table.Name, `"`, ``))
 }
 
-func (m mysqlAdapter) FieldConverters() []transformer.FieldConverter {
+func (m MySQLAdapter) FieldConverters() []transformer.FieldConverter {
     return m.fieldConverters
 }
 
-func (m mysqlAdapter) NewIterator() (transformer.RowsIterator, error) {
+func (m MySQLAdapter) NewIterator() (transformer.RowsIterator, error) {
     return scanner.NewScanner(m.db, m.table, m.columns, m.scannerCfg)
 }
 
-func (m mysqlAdapter) PartitionKeys() []string {
+func (m MySQLAdapter) PartitionKeys() []string {
     return m.table.PrimaryKeys
 }
diff --git a/sources/postgres/adapter/adapter.go b/sources/postgres/adapter/adapter.go
index 80b13028..ec32d5df 100644
--- a/sources/postgres/adapter/adapter.go
+++ b/sources/postgres/adapter/adapter.go
@@ -17,7 +17,7 @@ import (
 
 const defaultErrorRetries = 10
 
-type postgresAdapter struct {
+type PostgresAdapter struct {
     db      *sql.DB
     table   postgres.Table
     columns []schema.Column
@@ -25,28 +25,28 @@ type postgresAdapter struct {
     scannerCfg scan.ScannerConfig
 }
 
-func NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (postgresAdapter, error) {
+func NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (PostgresAdapter, error) {
     slog.Info("Loading metadata for table")
     table, err := postgres.LoadTable(db, tableCfg.Schema, tableCfg.Name)
     if err != nil {
-        return postgresAdapter{}, fmt.Errorf("failed to load metadata for table %s.%s: %w", tableCfg.Schema, tableCfg.Name, err)
+        return PostgresAdapter{}, fmt.Errorf("failed to load metadata for table %s.%s: %w", tableCfg.Schema, tableCfg.Name, err)
     }
 
     columns, err := column.FilterOutExcludedColumns(table.Columns, tableCfg.ExcludeColumns, table.PrimaryKeys)
     if err != nil {
-        return postgresAdapter{}, err
+        return PostgresAdapter{}, err
     }
 
     fieldConverters := make([]transformer.FieldConverter, len(columns))
     for i, col := range columns {
         converter, err := valueConverterForType(col.Type, col.Opts)
         if err != nil {
-            return postgresAdapter{}, fmt.Errorf("failed to build value converter for column %s: %w", col.Name, err)
+            return PostgresAdapter{}, fmt.Errorf("failed to build value converter for column %s: %w", col.Name, err)
         }
         fieldConverters[i] = transformer.FieldConverter{Name: col.Name, ValueConverter: converter}
     }
 
-    return postgresAdapter{
+    return PostgresAdapter{
         db:      db,
         table:   *table,
         columns: columns,
@@ -55,23 +55,23 @@ func NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (postgresAd
     }, nil
 }
 
-func (p postgresAdapter) TableName() string {
+func (p PostgresAdapter) TableName() string {
     return p.table.Name
 }
 
-func (p postgresAdapter) TopicSuffix() string {
+func (p PostgresAdapter) TopicSuffix() string {
     return fmt.Sprintf("%s.%s", p.table.Schema, strings.ReplaceAll(p.table.Name, `"`, ``))
 }
 
-func (p postgresAdapter) FieldConverters() []transformer.FieldConverter {
+func (p PostgresAdapter) FieldConverters() []transformer.FieldConverter {
     return p.fieldConverters
 }
 
-func (p postgresAdapter) NewIterator() (transformer.RowsIterator, error) {
+func (p PostgresAdapter) NewIterator() (transformer.RowsIterator, error) {
     return postgres.NewScanner(p.db, p.table, p.columns, p.scannerCfg)
 }
 
-func (p postgresAdapter) PartitionKeys() []string {
+func (p PostgresAdapter) PartitionKeys() []string {
     return p.table.PrimaryKeys
 }
diff --git a/sources/postgres/adapter/adapter_test.go b/sources/postgres/adapter/adapter_test.go
index 4498b0a2..19dce01a 100644
--- a/sources/postgres/adapter/adapter_test.go
+++ b/sources/postgres/adapter/adapter_test.go
@@ -18,7 +18,7 @@ func TestPostgresAdapter_TableName(t *testing.T) {
         Schema: "schema",
         Name:   "table1",
     }
-    assert.Equal(t, "table1", postgresAdapter{table: table}.TableName())
+    assert.Equal(t, "table1", PostgresAdapter{table: table}.TableName())
 }
 
 func TestPostgresAdapter_TopicSuffix(t *testing.T) {
@@ -45,7 +45,7 @@ func TestPostgresAdapter_TopicSuffix(t *testing.T) {
     }
 
     for _, tc := range tcs {
-        adapter := postgresAdapter{table: tc.table}
+        adapter := PostgresAdapter{table: tc.table}
         assert.Equal(t, tc.expectedTopicName, adapter.TopicSuffix())
     }
 }
diff --git a/sources/postgres/adapter/transformer_test.go b/sources/postgres/adapter/transformer_test.go
index 070d407d..89fd0b12 100644
--- a/sources/postgres/adapter/transformer_test.go
+++ b/sources/postgres/adapter/transformer_test.go
@@ -52,7 +52,7 @@ func TestDebeziumTransformer(t *testing.T) {
     // test zero batches
     {
         dbzTransformer := transformer.NewDebeziumTransformerWithIterator(
-            postgresAdapter{table: table},
+            PostgresAdapter{table: table},
             &MockRowIterator{batches: [][]map[string]any{}},
         )
         assert.False(t, dbzTransformer.HasNext())
@@ -61,7 +61,7 @@
     // test an iterator that returns an error
     {
         dbzTransformer := transformer.NewDebeziumTransformerWithIterator(
-            postgresAdapter{table: table},
+            PostgresAdapter{table: table},
             &ErrorRowIterator{},
         )
 
@@ -73,7 +73,7 @@
     // test two batches each with two rows
     {
         dbzTransformer := transformer.NewDebeziumTransformerWithIterator(
-            postgresAdapter{
+            PostgresAdapter{
                 table: table,
                 fieldConverters: []transformer.FieldConverter{
                     {Name: "a", ValueConverter: converters.StringPassthrough{}},
@@ -130,7 +130,7 @@ func TestDebeziumTransformer_NilOptionalSchema(t *testing.T) {
     }
 
     dbzTransformer := transformer.NewDebeziumTransformerWithIterator(
-        postgresAdapter{
+        PostgresAdapter{
             table: table,
             fieldConverters: []transformer.FieldConverter{
                 {Name: "user_id", ValueConverter: converters.Int16Passthrough{}},
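Not part of the patch: a rough sketch of how the renamed pieces could be wired together after this change. The only import path confirmed by the diff is github.com/artie-labs/reader/config; the kafkalib, transformer, and postgres adapter paths below are assumptions, the function name publishTable is hypothetical, and it assumes the Debezium transformer exposes HasNext/Next so it satisfies BatchWriter.WriteIterator's iterator parameter. Treat it as illustrative only.

// Illustrative sketch only; not taken from the patch.
package main

import (
	"context"
	"database/sql"
	"log"

	"github.com/artie-labs/reader/config"
	"github.com/artie-labs/reader/lib/kafkalib"                        // assumed path
	"github.com/artie-labs/reader/lib/transformer"                     // assumed path
	pgadapter "github.com/artie-labs/reader/sources/postgres/adapter" // assumed path
)

// publishTable (hypothetical) snapshots one Postgres table into Kafka.
func publishTable(ctx context.Context, db *sql.DB, kafkaCfg config.Kafka, tableCfg config.PostgreSQLTable) error {
	// NewPostgresAdapter now returns the exported PostgresAdapter type.
	adapter, err := pgadapter.NewPostgresAdapter(db, tableCfg)
	if err != nil {
		return err
	}

	// newWriter/newMessage are unexported after this patch, so callers go through
	// NewBatchWriter. nil is a placeholder; a real mtr.Client should be supplied.
	writer, err := kafkalib.NewBatchWriter(ctx, kafkaCfg, nil)
	if err != nil {
		return err
	}

	rows, err := adapter.NewIterator()
	if err != nil {
		return err
	}

	// Assumption: the Debezium transformer adapts row batches into raw messages
	// and implements HasNext/Next, so it can feed WriteIterator directly.
	iter := transformer.NewDebeziumTransformerWithIterator(adapter, rows)
	count, err := writer.WriteIterator(ctx, iter)
	if err != nil {
		return err
	}

	log.Printf("published %d messages for table %s", count, adapter.TableName())
	return nil
}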