Minor Reader Improvements (#315)
Tang8330 authored Mar 25, 2024
1 parent 77e0014 commit 232cbd1
Showing 10 changed files with 63 additions and 60 deletions.
6 changes: 3 additions & 3 deletions lib/kafkalib/errors.go
@@ -2,13 +2,13 @@ package kafkalib

 import "strings"

-func IsExceedMaxMessageBytesErr(err error) bool {
+func isExceedMaxMessageBytesErr(err error) bool {
 	return err != nil && strings.Contains(err.Error(),
 		"Message Size Too Large: the server has a configurable maximum message size to avoid unbounded memory allocation and the client attempted to produce a message larger than this maximum")
 }

-// RetryableError - returns true if the error is retryable
+// isRetryableError - returns true if the error is retryable
 // If it's retryable, you need to reload the Kafka client.
-func RetryableError(err error) bool {
+func isRetryableError(err error) bool {
 	return err != nil && strings.Contains(err.Error(), "Topic Authorization Failed: the client is not authorized to access the requested topic")
 }
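
For context, the writer further down consumes these now-unexported predicates in its retry loop. A condensed sketch of that pattern, simplified from the writer.go hunks below:

kafkaErr := writer.WriteMessages(ctx, chunk...)
if isRetryableError(kafkaErr) {
	// The Kafka client must be reloaded before the next retry attempt.
} else if isExceedMaxMessageBytesErr(kafkaErr) {
	// The broker will never accept a batch this large; skip the chunk.
}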
2 changes: 1 addition & 1 deletion lib/kafkalib/errors_test.go
@@ -27,7 +27,7 @@ func TestIsExceedMaxMessageBytesErr(t *testing.T) {
 	}

 	for _, tc := range tcs {
-		actual := IsExceedMaxMessageBytesErr(tc.err)
+		actual := isExceedMaxMessageBytesErr(tc.err)
 		assert.Equal(t, tc.expected, actual, tc.err)
 	}
 }
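
A companion table-driven test for isRetryableError would follow the same shape. A minimal sketch, assuming the same package plus testify's assert and the standard fmt import:

func TestIsRetryableError(t *testing.T) {
	tcs := []struct {
		err      error
		expected bool
	}{
		{err: nil, expected: false},
		{err: fmt.Errorf("Topic Authorization Failed: the client is not authorized to access the requested topic"), expected: true},
		{err: fmt.Errorf("some other kafka error"), expected: false},
	}

	for _, tc := range tcs {
		assert.Equal(t, tc.expected, isRetryableError(tc.err), tc.err)
	}
}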
2 changes: 1 addition & 1 deletion lib/kafkalib/kafka.go
@@ -14,7 +14,7 @@ import (
 	"github.com/artie-labs/reader/config"
 )

-func NewWriter(ctx context.Context, cfg config.Kafka) (*kafka.Writer, error) {
+func newWriter(ctx context.Context, cfg config.Kafka) (*kafka.Writer, error) {
 	slog.Info("Setting kafka bootstrap URLs", slog.Any("urls", cfg.BootstrapAddresses()))
 	writer := &kafka.Writer{
 		Addr: kafka.TCP(cfg.BootstrapAddresses()...),
2 changes: 1 addition & 1 deletion lib/kafkalib/message.go
@@ -6,7 +6,7 @@ import (
 	"github.com/segmentio/kafka-go"
 )

-func NewMessage(topic string, partitionKey map[string]any, value any) (kafka.Message, error) {
+func newMessage(topic string, partitionKey map[string]any, value any) (kafka.Message, error) {
 	valueBytes, err := json.Marshal(value)
 	if err != nil {
 		return kafka.Message{}, err
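
The hunk cuts off before the function returns. Judging from the test in the next file, which expects a JSON-encoded key and value, the elided remainder plausibly marshals the partition key the same way; the following is a sketch, not the actual source:

	// Hypothetical continuation of newMessage:
	keyBytes, err := json.Marshal(partitionKey)
	if err != nil {
		return kafka.Message{}, err
	}

	return kafka.Message{
		Topic: topic,
		Key:   keyBytes,
		Value: valueBytes,
	}, nil
}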
2 changes: 1 addition & 1 deletion lib/kafkalib/message_test.go
@@ -19,7 +19,7 @@ func TestNewMessage(t *testing.T) {
 		},
 	}

-	msg, err := NewMessage("topic", map[string]any{"key": "value"}, payload)
+	msg, err := newMessage("topic", map[string]any{"key": "value"}, payload)
 	assert.NoError(t, err)
 	assert.Equal(t, "topic", msg.Topic)
 	assert.Equal(t, `{"key":"value"}`, string(msg.Key))
51 changes: 27 additions & 24 deletions lib/kafkalib/writer.go
@@ -32,51 +32,54 @@ func NewBatchWriter(ctx context.Context, cfg config.Kafka, statsD mtr.Client) (*
 		return nil, fmt.Errorf("kafka topic prefix cannot be empty")
 	}

-	writer, err := NewWriter(ctx, cfg)
+	writer, err := newWriter(ctx, cfg)
 	if err != nil {
 		return nil, err
 	}
 	return &BatchWriter{writer, cfg, statsD}, nil
 }

-func (w *BatchWriter) reload(ctx context.Context) error {
+func (b *BatchWriter) reload(ctx context.Context) error {
 	slog.Info("Reloading kafka writer")
-	if err := w.writer.Close(); err != nil {
+	if err := b.writer.Close(); err != nil {
 		return err
 	}

-	writer, err := NewWriter(ctx, w.cfg)
+	writer, err := newWriter(ctx, b.cfg)
 	if err != nil {
 		return err
 	}

-	w.writer = writer
+	b.writer = writer
 	return nil
 }

-func buildKafkaMessages(cfg *config.Kafka, msgs []lib.RawMessage) ([]kafka.Message, error) {
-	result := make([]kafka.Message, len(msgs))
-	for i, msg := range msgs {
-		topic := fmt.Sprintf("%s.%s", cfg.TopicPrefix, msg.TopicSuffix)
-		kMsg, err := NewMessage(topic, msg.PartitionKey, msg.GetPayload())
+func (b *BatchWriter) buildKafkaMessages(rawMsgs []lib.RawMessage) ([]kafka.Message, error) {
+	var kafkaMsgs []kafka.Message
+	for _, rawMsg := range rawMsgs {
+		topic := fmt.Sprintf("%s.%s", b.cfg.TopicPrefix, rawMsg.TopicSuffix)
+		kafkaMsg, err := newMessage(topic, rawMsg.PartitionKey, rawMsg.GetPayload())
 		if err != nil {
 			return nil, err
 		}
-		result[i] = kMsg
+
+		kafkaMsgs = append(kafkaMsgs, kafkaMsg)
 	}
-	return result, nil
+
+	return kafkaMsgs, nil
 }

-func (w *BatchWriter) WriteRawMessages(ctx context.Context, rawMsgs []lib.RawMessage) error {
-	msgs, err := buildKafkaMessages(&w.cfg, rawMsgs)
+func (b *BatchWriter) WriteRawMessages(ctx context.Context, rawMsgs []lib.RawMessage) error {
+	kafkaMsgs, err := b.buildKafkaMessages(rawMsgs)
 	if err != nil {
 		return fmt.Errorf("failed to encode kafka messages: %w", err)
 	}
-	return w.WriteMessages(ctx, msgs)
+
+	return b.WriteMessages(ctx, kafkaMsgs)
 }

-func (w *BatchWriter) WriteMessages(ctx context.Context, msgs []kafka.Message) error {
-	chunkSize := w.cfg.GetPublishSize()
+func (b *BatchWriter) WriteMessages(ctx context.Context, msgs []kafka.Message) error {
+	chunkSize := b.cfg.GetPublishSize()
 	if chunkSize < 1 {
 		return fmt.Errorf("chunk size is too small")
 	}
@@ -103,27 +106,27 @@ func (w *BatchWriter) WriteMessages(ctx context.Context, msgs []kafka.Message) e
 			)
 			time.Sleep(sleepDuration)

-			if RetryableError(kafkaErr) {
-				if reloadErr := w.reload(ctx); reloadErr != nil {
+			if isRetryableError(kafkaErr) {
+				if reloadErr := b.reload(ctx); reloadErr != nil {
 					slog.Warn("Failed to reload kafka writer", slog.Any("err", reloadErr))
 				}
 			}
 		}

-		kafkaErr = w.writer.WriteMessages(ctx, chunk...)
+		kafkaErr = b.writer.WriteMessages(ctx, chunk...)
 		if kafkaErr == nil {
 			tags["what"] = "success"
 			break
 		}

-		if IsExceedMaxMessageBytesErr(kafkaErr) {
+		if isExceedMaxMessageBytesErr(kafkaErr) {
 			slog.Info("Skipping this chunk since the batch exceeded the server")
 			kafkaErr = nil
 			break
 		}
 	}

-	w.statsD.Count("kafka.publish", int64(len(chunk)), tags)
+	b.statsD.Count("kafka.publish", int64(len(chunk)), tags)
 	if kafkaErr != nil {
 		return fmt.Errorf("failed to write message: %w, approxSize: %d", kafkaErr, size.GetApproxSize(chunk))
 	}
@@ -136,7 +139,7 @@ type messageIterator interface {
 	Next() ([]lib.RawMessage, error)
 }

-func (w *BatchWriter) WriteIterator(ctx context.Context, iter messageIterator) (int, error) {
+func (b *BatchWriter) WriteIterator(ctx context.Context, iter messageIterator) (int, error) {
 	start := time.Now()
 	var count int
 	for iter.HasNext() {
@@ -145,7 +148,7 @@ func (w *BatchWriter) WriteIterator(ctx context.Context, iter messageIterator) (
 			return 0, fmt.Errorf("failed to iterate over messages: %w", err)

 		} else if len(msgs) > 0 {
-			if err = w.WriteRawMessages(ctx, msgs); err != nil {
+			if err = b.WriteRawMessages(ctx, msgs); err != nil {
 				return 0, fmt.Errorf("failed to write messages to kafka: %w", err)
 			}
 			count += len(msgs)
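
The chunk loop itself sits in the collapsed region between the two WriteMessages hunks. A generic sketch of that fixed-size chunking pattern (variable names hypothetical):

for i := 0; i < len(msgs); i += chunkSize {
	end := i + chunkSize
	if end > len(msgs) {
		end = len(msgs)
	}
	chunk := msgs[i:end]
	// Publish chunk with the retry/reload logic shown in the hunks above.
	_ = chunk
}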
24 changes: 12 additions & 12 deletions sources/mysql/adapter/adapter.go
@@ -18,7 +18,7 @@ import (

 const defaultErrorRetries = 10

-type mysqlAdapter struct {
+type MySQLAdapter struct {
 	db              *sql.DB
 	dbName          string
 	table           mysql.Table
@@ -27,32 +27,32 @@ type mysqlAdapter struct {
 	scannerCfg      scan.ScannerConfig
 }

-func NewMySQLAdapter(db *sql.DB, dbName string, tableCfg config.MySQLTable) (mysqlAdapter, error) {
+func NewMySQLAdapter(db *sql.DB, dbName string, tableCfg config.MySQLTable) (MySQLAdapter, error) {
 	slog.Info("Loading metadata for table")
 	table, err := mysql.LoadTable(db, tableCfg.Name)
 	if err != nil {
-		return mysqlAdapter{}, fmt.Errorf("failed to load metadata for table %s: %w", tableCfg.Name, err)
+		return MySQLAdapter{}, fmt.Errorf("failed to load metadata for table %s: %w", tableCfg.Name, err)
 	}

 	columns, err := column.FilterOutExcludedColumns(table.Columns, tableCfg.ExcludeColumns, table.PrimaryKeys)
 	if err != nil {
-		return mysqlAdapter{}, err
+		return MySQLAdapter{}, err
 	}

 	return newMySQLAdapter(db, dbName, *table, columns, tableCfg.ToScannerConfig(defaultErrorRetries))
 }

-func newMySQLAdapter(db *sql.DB, dbName string, table mysql.Table, columns []schema.Column, scannerCfg scan.ScannerConfig) (mysqlAdapter, error) {
+func newMySQLAdapter(db *sql.DB, dbName string, table mysql.Table, columns []schema.Column, scannerCfg scan.ScannerConfig) (MySQLAdapter, error) {
 	fieldConverters := make([]transformer.FieldConverter, len(columns))
 	for i, col := range columns {
 		converter, err := valueConverterForType(col.Type, col.Opts)
 		if err != nil {
-			return mysqlAdapter{}, fmt.Errorf("failed to build value converter for column %s: %w", col.Name, err)
+			return MySQLAdapter{}, fmt.Errorf("failed to build value converter for column %s: %w", col.Name, err)
 		}
 		fieldConverters[i] = transformer.FieldConverter{Name: col.Name, ValueConverter: converter}
 	}

-	return mysqlAdapter{
+	return MySQLAdapter{
 		db:     db,
 		dbName: dbName,
 		table:  table,
@@ -62,23 +62,23 @@ func newMySQLAdapter(db *sql.DB, dbName string, table mysql.Table, columns []sch
 	}, nil
 }

-func (m mysqlAdapter) TableName() string {
+func (m MySQLAdapter) TableName() string {
 	return m.table.Name
 }

-func (m mysqlAdapter) TopicSuffix() string {
+func (m MySQLAdapter) TopicSuffix() string {
 	return fmt.Sprintf("%s.%s", m.dbName, strings.ReplaceAll(m.table.Name, `"`, ``))
 }

-func (m mysqlAdapter) FieldConverters() []transformer.FieldConverter {
+func (m MySQLAdapter) FieldConverters() []transformer.FieldConverter {
 	return m.fieldConverters
 }

-func (m mysqlAdapter) NewIterator() (transformer.RowsIterator, error) {
+func (m MySQLAdapter) NewIterator() (transformer.RowsIterator, error) {
 	return scanner.NewScanner(m.db, m.table, m.columns, m.scannerCfg)
 }

-func (m mysqlAdapter) PartitionKeys() []string {
+func (m MySQLAdapter) PartitionKeys() []string {
 	return m.table.PrimaryKeys
 }
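
With the type now exported, callers outside the package can hold a MySQLAdapter directly. A usage sketch, where db, dbName, and tableCfg are assumed to come from the caller's setup:

a, err := adapter.NewMySQLAdapter(db, dbName, tableCfg)
if err != nil {
	return err
}

topicSuffix := a.TopicSuffix() // "<dbName>.<tableName>", per the method above
iter, err := a.NewIterator()   // row iterator backed by the configured scanner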
22 changes: 11 additions & 11 deletions sources/postgres/adapter/adapter.go
@@ -17,36 +17,36 @@ import (

 const defaultErrorRetries = 10

-type postgresAdapter struct {
+type PostgresAdapter struct {
 	db              *sql.DB
 	table           postgres.Table
 	columns         []schema.Column
 	fieldConverters []transformer.FieldConverter
 	scannerCfg      scan.ScannerConfig
 }

-func NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (postgresAdapter, error) {
+func NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (PostgresAdapter, error) {
 	slog.Info("Loading metadata for table")
 	table, err := postgres.LoadTable(db, tableCfg.Schema, tableCfg.Name)
 	if err != nil {
-		return postgresAdapter{}, fmt.Errorf("failed to load metadata for table %s.%s: %w", tableCfg.Schema, tableCfg.Name, err)
+		return PostgresAdapter{}, fmt.Errorf("failed to load metadata for table %s.%s: %w", tableCfg.Schema, tableCfg.Name, err)
 	}

 	columns, err := column.FilterOutExcludedColumns(table.Columns, tableCfg.ExcludeColumns, table.PrimaryKeys)
 	if err != nil {
-		return postgresAdapter{}, err
+		return PostgresAdapter{}, err
 	}

 	fieldConverters := make([]transformer.FieldConverter, len(columns))
 	for i, col := range columns {
 		converter, err := valueConverterForType(col.Type, col.Opts)
 		if err != nil {
-			return postgresAdapter{}, fmt.Errorf("failed to build value converter for column %s: %w", col.Name, err)
+			return PostgresAdapter{}, fmt.Errorf("failed to build value converter for column %s: %w", col.Name, err)
 		}
 		fieldConverters[i] = transformer.FieldConverter{Name: col.Name, ValueConverter: converter}
 	}

-	return postgresAdapter{
+	return PostgresAdapter{
 		db:      db,
 		table:   *table,
 		columns: columns,
@@ -55,23 +55,23 @@ func NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (postgresAd
 	}, nil
 }

-func (p postgresAdapter) TableName() string {
+func (p PostgresAdapter) TableName() string {
 	return p.table.Name
 }

-func (p postgresAdapter) TopicSuffix() string {
+func (p PostgresAdapter) TopicSuffix() string {
 	return fmt.Sprintf("%s.%s", p.table.Schema, strings.ReplaceAll(p.table.Name, `"`, ``))
 }

-func (p postgresAdapter) FieldConverters() []transformer.FieldConverter {
+func (p PostgresAdapter) FieldConverters() []transformer.FieldConverter {
 	return p.fieldConverters
 }

-func (p postgresAdapter) NewIterator() (transformer.RowsIterator, error) {
+func (p PostgresAdapter) NewIterator() (transformer.RowsIterator, error) {
 	return postgres.NewScanner(p.db, p.table, p.columns, p.scannerCfg)
 }

-func (p postgresAdapter) PartitionKeys() []string {
+func (p PostgresAdapter) PartitionKeys() []string {
 	return p.table.PrimaryKeys
 }
4 changes: 2 additions & 2 deletions sources/postgres/adapter/adapter_test.go
@@ -18,7 +18,7 @@ func TestPostgresAdapter_TableName(t *testing.T) {
 		Schema: "schema",
 		Name:   "table1",
 	}
-	assert.Equal(t, "table1", postgresAdapter{table: table}.TableName())
+	assert.Equal(t, "table1", PostgresAdapter{table: table}.TableName())
 }

 func TestPostgresAdapter_TopicSuffix(t *testing.T) {
@@ -45,7 +45,7 @@ func TestPostgresAdapter_TopicSuffix(t *testing.T) {
 	}

 	for _, tc := range tcs {
-		adapter := postgresAdapter{table: tc.table}
+		adapter := PostgresAdapter{table: tc.table}
 		assert.Equal(t, tc.expectedTopicName, adapter.TopicSuffix())
 	}
 }
8 changes: 4 additions & 4 deletions sources/postgres/adapter/transformer_test.go
@@ -52,7 +52,7 @@ func TestDebeziumTransformer(t *testing.T) {
 	// test zero batches
 	{
 		dbzTransformer := transformer.NewDebeziumTransformerWithIterator(
-			postgresAdapter{table: table},
+			PostgresAdapter{table: table},
 			&MockRowIterator{batches: [][]map[string]any{}},
 		)
 		assert.False(t, dbzTransformer.HasNext())
Expand All @@ -61,7 +61,7 @@ func TestDebeziumTransformer(t *testing.T) {
// test an iterator that returns an error
{
dbzTransformer := transformer.NewDebeziumTransformerWithIterator(
postgresAdapter{table: table},
PostgresAdapter{table: table},
&ErrorRowIterator{},
)

@@ -73,7 +73,7 @@ func TestDebeziumTransformer(t *testing.T) {
 	// test two batches each with two rows
 	{
 		dbzTransformer := transformer.NewDebeziumTransformerWithIterator(
-			postgresAdapter{
+			PostgresAdapter{
 				table: table,
 				fieldConverters: []transformer.FieldConverter{
 					{Name: "a", ValueConverter: converters.StringPassthrough{}},
@@ -130,7 +130,7 @@ func TestDebeziumTransformer_NilOptionalSchema(t *testing.T) {
 	}

 	dbzTransformer := transformer.NewDebeziumTransformerWithIterator(
-		postgresAdapter{
+		PostgresAdapter{
 			table: table,
 			fieldConverters: []transformer.FieldConverter{
 				{Name: "user_id", ValueConverter: converters.Int16Passthrough{}},
