diff --git a/lib/kafkalib/writer.go b/lib/kafkalib/writer.go
index 666443e0..d5c917f6 100644
--- a/lib/kafkalib/writer.go
+++ b/lib/kafkalib/writer.go
@@ -4,13 +4,12 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"github.com/artie-labs/transfer/lib/batch"
+	"github.com/artie-labs/transfer/lib/retry"
 	"log/slog"
-	"slices"
 	"time"
 
-	"github.com/artie-labs/transfer/lib/jitter"
 	"github.com/artie-labs/transfer/lib/kafkalib"
-	"github.com/artie-labs/transfer/lib/size"
 	"github.com/segmentio/kafka-go"
 
 	"github.com/artie-labs/reader/config"
@@ -84,88 +83,109 @@ func (b *BatchWriter) reload(ctx context.Context) error {
 	return nil
 }
 
-func (b *BatchWriter) Write(ctx context.Context, rawMsgs []lib.RawMessage) error {
-	if len(rawMsgs) == 0 {
-		return nil
+func buildKafkaMessage(topicPrefix string, rawMessage lib.RawMessage) (KafkaMessage, error) {
+	valueBytes, err := json.Marshal(rawMessage.Event())
+	if err != nil {
+		return KafkaMessage{}, err
 	}
 
-	var msgs []kafka.Message
-	var sampleExecutionTime time.Time
-	for _, rawMsg := range rawMsgs {
-		sampleExecutionTime = rawMsg.Event().GetExecutionTime()
-		kafkaMsg, err := newMessage(b.cfg.TopicPrefix, rawMsg)
-		if err != nil {
-			return fmt.Errorf("failed to encode kafka message: %w", err)
-		}
-		msgs = append(msgs, kafkaMsg)
+	keyBytes, err := json.Marshal(rawMessage.PartitionKey())
+	if err != nil {
+		return KafkaMessage{}, err
+	}
+
+	return KafkaMessage{
+		Topic: fmt.Sprintf("%s.%s", topicPrefix, rawMessage.TopicSuffix()),
+		Key:   keyBytes,
+		Value: valueBytes,
+	}, nil
+}
+
+type KafkaMessage struct {
+	Topic string `json:"topic"`
+	Key   []byte `json:"key"`
+	Value []byte `json:"value"`
+}
+
+func (k KafkaMessage) toKafkaMessage() kafka.Message {
+	return kafka.Message{
+		Topic: k.Topic,
+		Key:   k.Key,
+		Value: k.Value,
+	}
+}
+
+var encoder = func(msg KafkaMessage) ([]byte, error) {
+	return json.Marshal(msg)
+}
+
+func (b *BatchWriter) write(ctx context.Context, messages []KafkaMessage, sampleExecutionTime time.Time) error {
+	retryCfg, err := retry.NewJitterRetryConfig(100, 5000, 10, retry.AlwaysRetry)
+	if err != nil {
+		return err
 	}
 
-	for batch := range slices.Chunk(msgs, int(b.cfg.GetPublishSize())) {
+	return batch.BySize[KafkaMessage](messages, int(b.writer.BatchBytes), encoder, func(chunk [][]byte) error {
 		tags := map[string]string{
 			"what": "error",
 		}
 
-		var kafkaErr error
-		for attempts := range 10 {
-			if attempts > 0 {
-				sleepDuration := jitter.Jitter(baseJitterMs, maxJitterMs, attempts-1)
-				slog.Info("Failed to publish to kafka",
-					slog.Any("err", kafkaErr),
-					slog.Int("attempts", attempts),
-					slog.Duration("sleep", sleepDuration),
-				)
-				time.Sleep(sleepDuration)
-
-				if isRetryableError(kafkaErr) {
-					if reloadErr := b.reload(ctx); reloadErr != nil {
-						slog.Warn("Failed to reload kafka writer", slog.Any("err", reloadErr))
-					}
-				}
+		defer func() {
+			if b.statsD != nil {
+				b.statsD.Count("kafka.publish", int64(len(chunk)), tags)
+				b.statsD.Gauge("kafka.lag_ms", float64(time.Since(sampleExecutionTime).Milliseconds()), tags)
 			}
+		}()
 
-			kafkaErr = b.writer.WriteMessages(ctx, batch...)
-			if kafkaErr == nil {
-				tags["what"] = "success"
-				break
+		var kafkaMessages []kafka.Message
+		for _, bytes := range chunk {
+			var msg KafkaMessage
+			if err := json.Unmarshal(bytes, &msg); err != nil {
+				return fmt.Errorf("failed to unmarshal message: %w", err)
 			}
 
-			if isExceedMaxMessageBytesErr(kafkaErr) {
+			kafkaMessages = append(kafkaMessages, msg.toKafkaMessage())
+		}
+
+		err = retry.WithRetries(retryCfg, func(_ int, _ error) error {
+			publishErr := b.writer.WriteMessages(ctx, kafkaMessages...)
+			if isExceedMaxMessageBytesErr(publishErr) {
 				slog.Info("Skipping this batch since the message size exceeded the server max")
-				kafkaErr = nil
-				break
+				return nil
 			}
-		}
-
-		if b.statsD != nil {
-			b.statsD.Count("kafka.publish", int64(len(batch)), tags)
-			b.statsD.Gauge("kafka.lag_ms", float64(time.Since(sampleExecutionTime).Milliseconds()), tags)
-		}
+
+			return publishErr
+		})
 
-		if kafkaErr != nil {
-			return fmt.Errorf("failed to write message: %w, approxSize: %d", kafkaErr, size.GetApproxSize(batch))
+		if err != nil {
+			return fmt.Errorf("failed to write messages: %w", err)
 		}
-	}
-	return nil
-}
-
-func (b *BatchWriter) OnComplete(_ context.Context) error {
-	return nil
+
+		tags["what"] = "success"
+		return nil
+	})
 }
 
-func newMessage(topicPrefix string, rawMessage lib.RawMessage) (kafka.Message, error) {
-	valueBytes, err := json.Marshal(rawMessage.Event())
-	if err != nil {
-		return kafka.Message{}, err
+func (b *BatchWriter) Write(ctx context.Context, rawMsgs []lib.RawMessage) error {
+	if len(rawMsgs) == 0 {
+		return nil
 	}
 
-	keyBytes, err := json.Marshal(rawMessage.PartitionKey())
-	if err != nil {
-		return kafka.Message{}, err
+	var msgs []KafkaMessage
+	var sampleExecutionTime time.Time
+	for _, rawMsg := range rawMsgs {
+		sampleExecutionTime = rawMsg.Event().GetExecutionTime()
+		msg, err := buildKafkaMessage(b.cfg.TopicPrefix, rawMsg)
+		if err != nil {
+			return fmt.Errorf("failed to build kafka message: %w", err)
+		}
+
+		msgs = append(msgs, msg)
 	}
 
-	return kafka.Message{
-		Topic: fmt.Sprintf("%s.%s", topicPrefix, rawMessage.TopicSuffix()),
-		Key:   keyBytes,
-		Value: valueBytes,
-	}, nil
+	return b.write(ctx, msgs, sampleExecutionTime)
+}
+
+func (b *BatchWriter) OnComplete(_ context.Context) error {
+	return nil
 }