diff --git a/integration_tests/mongo/main.go b/integration_tests/mongo/main.go index 1f2c948ce..4f1377edd 100644 --- a/integration_tests/mongo/main.go +++ b/integration_tests/mongo/main.go @@ -13,6 +13,7 @@ import ( mongoLib "github.com/artie-labs/reader/sources/mongo" xferMongo "github.com/artie-labs/transfer/lib/cdc/mongo" + "github.com/artie-labs/transfer/lib/debezium" "github.com/artie-labs/transfer/lib/kafkalib" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" @@ -123,20 +124,18 @@ func testTypes(ctx context.Context, db *mongo.Database, mongoCfg config.MongoDB) } row := rows[0] - // This should not include the payload field in here. The payload field gets injected in [kafkalib.buildKafkaMessageWrapper] - expectedPartitionKey := map[string]any{"id": `{"$oid":"66a95fae3776c2f21f0ff568"}`} - expectedPkBytes, err := json.Marshal(expectedPartitionKey) - if err != nil { - return fmt.Errorf("failed to marshal expected partition key: %w", err) + expectedPartitionKey := debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: map[string]any{"id": `{"$oid":"66a95fae3776c2f21f0ff568"}`}, } - actualPkBytes, err := json.Marshal(row.PartitionKey()) + equal, err := utils.CheckPartitionKeyDifference(expectedPartitionKey, row.PartitionKey()) if err != nil { - return fmt.Errorf("failed to marshal actual partition key: %w", err) + return fmt.Errorf("failed to check partition key difference: %w", err) } - if string(expectedPkBytes) != string(actualPkBytes) { - return fmt.Errorf("partition key %s does not match %s", actualPkBytes, expectedPkBytes) + if !equal { + return fmt.Errorf("partition key %v does not match %v", row.PartitionKey(), expectedPartitionKey) } mongoEvt := utils.GetMongoEvent(row) @@ -163,7 +162,12 @@ func testTypes(ctx context.Context, db *mongo.Database, mongoCfg config.MongoDB) return fmt.Errorf("failed to get event from bytes: %w", err) } - pkMap, err := dbz.GetPrimaryKey(actualPkBytes, kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) + actualPartitionKeyBytes, err := json.Marshal(row.PartitionKey()) + if err != nil { + return fmt.Errorf("failed to marshal partition key: %w", err) + } + + pkMap, err := dbz.GetPrimaryKey(actualPartitionKeyBytes, kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) if err != nil { return fmt.Errorf("failed to get primary key: %w", err) } diff --git a/integration_tests/mssql/main.go b/integration_tests/mssql/main.go index 6a652586e..d95d3d986 100644 --- a/integration_tests/mssql/main.go +++ b/integration_tests/mssql/main.go @@ -7,9 +7,9 @@ import ( "errors" "fmt" "log/slog" - "maps" "os" + "github.com/artie-labs/transfer/lib/debezium" "github.com/lmittmann/tint" _ "github.com/microsoft/go-mssqldb" @@ -199,8 +199,17 @@ func testTypes(db *sql.DB, dbName string) error { } row := rows[0] - expectedPartitionKey := map[string]any{"pk": int64(1)} - if !maps.Equal(row.PartitionKey(), expectedPartitionKey) { + expectedPartitionKey := debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: map[string]any{"pk": int64(1)}, + } + + equal, err := utils.CheckPartitionKeyDifference(expectedPartitionKey, row.PartitionKey()) + if err != nil { + return fmt.Errorf("failed to check partition key difference: %w", err) + } + + if !equal { return fmt.Errorf("partition key %v does not match %v", row.PartitionKey(), expectedPartitionKey) } diff --git a/integration_tests/mysql/main.go b/integration_tests/mysql/main.go index 7ac269c6b..07c2105be 100644 --- a/integration_tests/mysql/main.go +++ b/integration_tests/mysql/main.go @@ -6,9 +6,9 @@ import ( "errors" "fmt" "log/slog" - "maps" "os" + "github.com/artie-labs/transfer/lib/debezium" "github.com/lmittmann/tint" "github.com/artie-labs/reader/config" @@ -649,8 +649,17 @@ func testTypes(db *sql.DB, dbName string) error { } row := rows[0] - expectedPartitionKey := map[string]any{"pk": int64(1)} - if !maps.Equal(row.PartitionKey(), expectedPartitionKey) { + expectedPartitionKey := debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: map[string]any{"pk": int64(1)}, + } + + equal, err := utils.CheckPartitionKeyDifference(expectedPartitionKey, row.PartitionKey()) + if err != nil { + return fmt.Errorf("failed to check partition key difference: %w", err) + } + + if !equal { return fmt.Errorf("partition key %v does not match %v", row.PartitionKey(), expectedPartitionKey) } @@ -781,7 +790,17 @@ func testScan(db *sql.DB, dbName string) error { return fmt.Errorf("expected %d rows, got %d, batch size %d", len(expectedPartitionKeys), len(rows), batchSize) } for i, row := range rows { - if !maps.Equal(row.PartitionKey(), expectedPartitionKeys[i]) { + expectedPartitionKey := debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: expectedPartitionKeys[i], + } + + equal, err := utils.CheckPartitionKeyDifference(expectedPartitionKey, row.PartitionKey()) + if err != nil { + return fmt.Errorf("failed to check partition key difference: %w", err) + } + + if !equal { return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey(), expectedPartitionKeys[i]) } textValue := utils.GetEvent(row).Payload.After["c_text_value"] diff --git a/integration_tests/postgres/main.go b/integration_tests/postgres/main.go index 7e3889a42..b94576de0 100644 --- a/integration_tests/postgres/main.go +++ b/integration_tests/postgres/main.go @@ -7,9 +7,9 @@ import ( "errors" "fmt" "log/slog" - "maps" "os" + "github.com/artie-labs/transfer/lib/debezium" _ "github.com/jackc/pgx/v5/stdlib" "github.com/lmittmann/tint" @@ -846,10 +846,19 @@ func testTypes(db *sql.DB) error { if len(rows) != 1 { return fmt.Errorf("expected one row, got %d", len(rows)) } + row := rows[0] + expectedPartitionKey := debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: map[string]any{"pk": int64(1)}, + } + + equal, err := utils.CheckPartitionKeyDifference(expectedPartitionKey, row.PartitionKey()) + if err != nil { + return fmt.Errorf("failed to check partition key difference: %w", err) + } - expectedPartitionKey := map[string]any{"pk": int64(1)} - if !maps.Equal(row.PartitionKey(), expectedPartitionKey) { + if !equal { return fmt.Errorf("partition key %v does not match %v", row.PartitionKey(), expectedPartitionKey) } @@ -980,9 +989,20 @@ func testScan(db *sql.DB) error { return fmt.Errorf("expected %d rows, got %d, batch size %d", len(expectedPartitionKeys), len(rows), batchSize) } for i, row := range rows { - if !maps.Equal(row.PartitionKey(), expectedPartitionKeys[i]) { + expectedPartitionKey := debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: expectedPartitionKeys[i], + } + + equal, err := utils.CheckPartitionKeyDifference(expectedPartitionKey, row.PartitionKey()) + if err != nil { + return fmt.Errorf("failed to check partition key difference: %w", err) + } + + if !equal { return fmt.Errorf("partition keys are different for row %d, batch size %d, %v != %v", i, batchSize, row.PartitionKey(), expectedPartitionKeys[i]) } + textValue := utils.GetEvent(row).Payload.After["c_text_value"] if textValue != expectedValues[i] { return fmt.Errorf("row values are different for row %d, batch size %d, %v != %v", i, batchSize, textValue, expectedValues[i]) diff --git a/integration_tests/utils/utils.go b/integration_tests/utils/utils.go index e509888ea..12f2cff40 100644 --- a/integration_tests/utils/utils.go +++ b/integration_tests/utils/utils.go @@ -1,7 +1,9 @@ package utils import ( + "bytes" "database/sql" + "encoding/json" "fmt" "log/slog" "math/rand/v2" @@ -9,6 +11,7 @@ import ( "github.com/artie-labs/transfer/lib/cdc/mongo" "github.com/artie-labs/transfer/lib/cdc/util" + "github.com/artie-labs/transfer/lib/debezium" "github.com/artie-labs/reader/lib/debezium/transformer" "github.com/artie-labs/reader/lib/kafkalib" @@ -87,3 +90,17 @@ func CheckDifference(name, expected, actual string) bool { fmt.Println("--------------------------------------------------------------------------------") return true } + +func CheckPartitionKeyDifference(expected, actual debezium.PrimaryKeyPayload) (bool, error) { + expectedBytes, err := json.Marshal(expected) + if err != nil { + return false, fmt.Errorf("failed to marshal expected: %w", err) + } + + actualBytes, err := json.Marshal(actual) + if err != nil { + return false, fmt.Errorf("failed to marshal actual: %w", err) + } + + return bytes.Equal(expectedBytes, actualBytes), nil +} diff --git a/lib/debezium/transformer/transformer_test.go b/lib/debezium/transformer/transformer_test.go index e0b09dbe7..9f7d49bab 100644 --- a/lib/debezium/transformer/transformer_test.go +++ b/lib/debezium/transformer/transformer_test.go @@ -245,8 +245,14 @@ func TestDebeziumTransformer_Next(t *testing.T) { rows := results[0] assert.Len(t, rows, 1) rawMessage := rows[0] - assert.Equal(t, Row{"foo": "bar", "qux": 12}, rawMessage.PartitionKey()) - assert.Equal(t, "im-a-little-topic-suffix", rawMessage.TopicSuffix()) + assert.Equal(t, + debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: Row{"foo": "bar", "qux": 12}, + }, + rawMessage.PartitionKey(), + ) + assert.Equal(t, "im-a-little-topic-suffix", rawMessage.Topic("")) payload, isOk := rawMessage.Event().(*util.SchemaEventPayload) assert.True(t, isOk) payload.Payload.Source.TsMs = 12345 // Modify source time since it'll be ~now diff --git a/lib/kafkalib/message.go b/lib/kafkalib/message.go index eff5c0a8a..695c93ca0 100644 --- a/lib/kafkalib/message.go +++ b/lib/kafkalib/message.go @@ -1,6 +1,7 @@ package kafkalib import ( + "fmt" "github.com/artie-labs/transfer/lib/cdc" "github.com/artie-labs/transfer/lib/debezium" ) @@ -8,29 +9,36 @@ import ( type Message struct { topicSuffix string partitionKeySchema debezium.FieldsObject - partitionKey map[string]any + partitionKeyValues map[string]any event cdc.Event } -func NewMessage(topicSuffix string, partitionKeySchema debezium.FieldsObject, partitionKey map[string]any, event cdc.Event) Message { +func NewMessage(topicSuffix string, partitionKeySchema debezium.FieldsObject, partitionKeyValues map[string]any, event cdc.Event) Message { return Message{ topicSuffix: topicSuffix, partitionKeySchema: partitionKeySchema, - partitionKey: partitionKey, + partitionKeyValues: partitionKeyValues, event: event, } } -func (r Message) TopicSuffix() string { - return r.topicSuffix +func (r Message) Topic(prefix string) string { + if prefix == "" { + return r.topicSuffix + } + + return fmt.Sprintf("%s.%s", prefix, r.topicSuffix) } -func (r Message) PartitionKey() map[string]any { - return r.partitionKey +func (r Message) PartitionKey() debezium.PrimaryKeyPayload { + return debezium.PrimaryKeyPayload{ + Schema: r.partitionKeySchema, + Payload: r.partitionKeyValues, + } } -func (r Message) PartitionKeySchema() debezium.FieldsObject { - return r.partitionKeySchema +func (r Message) PartitionKeyValues() map[string]any { + return r.partitionKeyValues } func (r Message) Event() cdc.Event { diff --git a/lib/kafkalib/message_test.go b/lib/kafkalib/message_test.go new file mode 100644 index 000000000..088c2c0e1 --- /dev/null +++ b/lib/kafkalib/message_test.go @@ -0,0 +1,13 @@ +package kafkalib + +import ( + "github.com/artie-labs/transfer/lib/debezium" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestMessagePartitionKey(t *testing.T) { + msg := NewMessage("suffix", debezium.FieldsObject{}, nil, nil) + assert.Equal(t, "suffix", msg.Topic(""), "no prefix") + assert.Equal(t, "prefix.suffix", msg.Topic("prefix"), "with prefix") +} diff --git a/lib/kafkalib/writer.go b/lib/kafkalib/writer.go index f96ac8b87..c1913fe93 100644 --- a/lib/kafkalib/writer.go +++ b/lib/kafkalib/writer.go @@ -9,7 +9,6 @@ import ( "time" "github.com/artie-labs/transfer/lib/batch" - "github.com/artie-labs/transfer/lib/debezium" "github.com/artie-labs/transfer/lib/kafkalib" "github.com/artie-labs/transfer/lib/retry" "github.com/artie-labs/transfer/lib/typing/columns" @@ -83,18 +82,13 @@ func buildKafkaMessageWrapper(topicPrefix string, rawMessage Message) (KafkaMess return KafkaMessageWrapper{}, err } - pk := debezium.PrimaryKeyPayload{ - Schema: rawMessage.PartitionKeySchema(), - Payload: rawMessage.PartitionKey(), - } - - keyBytes, err := json.Marshal(pk) + keyBytes, err := json.Marshal(rawMessage.PartitionKey()) if err != nil { return KafkaMessageWrapper{}, err } return KafkaMessageWrapper{ - Topic: fmt.Sprintf("%s.%s", topicPrefix, rawMessage.TopicSuffix()), + Topic: rawMessage.Topic(topicPrefix), MessageKey: keyBytes, MessageValue: valueBytes, }, nil diff --git a/sources/postgres/adapter/transformer_test.go b/sources/postgres/adapter/transformer_test.go index f31ca40ad..b673879a8 100644 --- a/sources/postgres/adapter/transformer_test.go +++ b/sources/postgres/adapter/transformer_test.go @@ -2,9 +2,10 @@ package adapter import ( "fmt" - "github.com/artie-labs/transfer/lib/cdc/util" "testing" + "github.com/artie-labs/transfer/lib/cdc/util" + "github.com/artie-labs/transfer/lib/debezium" "github.com/stretchr/testify/assert" "github.com/artie-labs/reader/lib/debezium/converters" @@ -78,20 +79,44 @@ func TestDebeziumTransformer(t *testing.T) { msgs1 := results[0] assert.Len(t, msgs1, 2) - assert.Equal(t, "schema.table", msgs1[0].TopicSuffix()) - assert.Equal(t, map[string]any{"a": "1"}, msgs1[0].PartitionKey()) + assert.Equal(t, "schema.table", msgs1[0].Topic("")) + assert.Equal(t, + debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: map[string]any{"a": "1"}, + }, + msgs1[0].PartitionKey(), + ) assert.Equal(t, map[string]any{"a": "1", "b": "11"}, msgs1[0].Event().(*util.SchemaEventPayload).Payload.After) - assert.Equal(t, "schema.table", msgs1[1].TopicSuffix()) - assert.Equal(t, map[string]any{"a": "2"}, msgs1[1].PartitionKey()) + assert.Equal(t, "schema.table", msgs1[1].Topic("")) + assert.Equal(t, + debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: map[string]any{"a": "2"}, + }, + msgs1[1].PartitionKey(), + ) assert.Equal(t, map[string]any{"a": "2", "b": "12"}, msgs1[1].Event().(*util.SchemaEventPayload).Payload.After) msgs2 := results[1] assert.Len(t, msgs2, 2) - assert.Equal(t, "schema.table", msgs2[0].TopicSuffix()) - assert.Equal(t, map[string]any{"a": "3"}, msgs2[0].PartitionKey()) + assert.Equal(t, "schema.table", msgs2[0].Topic("")) + assert.Equal(t, + debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: map[string]any{"a": "3"}, + }, + msgs2[0].PartitionKey(), + ) assert.Equal(t, map[string]any{"a": "3", "b": "13"}, msgs2[0].Event().(*util.SchemaEventPayload).Payload.After) - assert.Equal(t, "schema.table", msgs2[1].TopicSuffix()) - assert.Equal(t, map[string]any{"a": "4"}, msgs2[1].PartitionKey()) + assert.Equal(t, "schema.table", msgs2[1].Topic("")) + assert.Equal(t, + debezium.PrimaryKeyPayload{ + Schema: debezium.FieldsObject{}, + Payload: map[string]any{"a": "4"}, + }, + msgs2[1].PartitionKey(), + ) assert.Equal(t, map[string]any{"a": "4", "b": "14"}, msgs2[1].Event().(*util.SchemaEventPayload).Payload.After) } } diff --git a/writers/transfer/writer.go b/writers/transfer/writer.go index b7c201dff..43c4114e6 100644 --- a/writers/transfer/writer.go +++ b/writers/transfer/writer.go @@ -122,7 +122,7 @@ func (w *Writer) messageToEvent(message readerKafkaLib.Message) (event.Event, er return event.ToMemoryEvent(evt, partitionKey, w.tc, transferConfig.Replication) } - memoryEvent, err := event.ToMemoryEvent(evt, message.PartitionKey(), w.tc, transferConfig.Replication) + memoryEvent, err := event.ToMemoryEvent(evt, message.PartitionKeyValues(), w.tc, transferConfig.Replication) if err != nil { return event.Event{}, err } diff --git a/writers/writer_test.go b/writers/writer_test.go index 835267347..5bd05a77e 100644 --- a/writers/writer_test.go +++ b/writers/writer_test.go @@ -90,9 +90,9 @@ func TestWriter_Write(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 3, count) assert.Len(t, destination.messages, 3) - assert.Equal(t, destination.messages[0].TopicSuffix(), "a") - assert.Equal(t, destination.messages[1].TopicSuffix(), "b") - assert.Equal(t, destination.messages[2].TopicSuffix(), "c") + assert.Equal(t, destination.messages[0].Topic(""), "a") + assert.Equal(t, destination.messages[1].Topic(""), "b") + assert.Equal(t, destination.messages[2].Topic(""), "c") } { // Destination error