From 04b8f737faae58dbe2892fae48c56dad067e0592 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Tue, 18 Apr 2023 22:13:46 -0700 Subject: [PATCH] [Enhancement] Supporting composite primary keys (#77) --- Makefile | 3 +- clients/bigquery/merge.go | 19 ++-- clients/bigquery/merge_test.go | 152 +++++++++++++++++++++++++- clients/snowflake/merge.go | 11 +- clients/snowflake/merge_test.go | 15 ++- clients/snowflake/snowflake_test.go | 11 +- lib/cdc/event.go | 4 +- lib/cdc/mongo/debezium.go | 30 ++--- lib/cdc/mongo/debezium_test.go | 12 +- lib/cdc/mysql/debezium.go | 15 +-- lib/cdc/mysql/debezium_test.go | 5 +- lib/cdc/postgres/debezium.go | 15 +-- lib/cdc/postgres/debezium_test.go | 20 ++-- lib/cdc/util/parser.go | 61 ----------- lib/cdc/util/parser_test.go | 78 ------------- lib/cdc/util/relational_event.go | 7 +- lib/cdc/util/relational_event_test.go | 11 +- lib/debezium/keys.go | 90 +++++++++++++++ lib/debezium/keys_test.go | 114 +++++++++++++++++++ lib/dwh/dml/merge.go | 62 ++++++++--- lib/dwh/dml/merge_test.go | 72 +++++++++++- lib/optimization/event.go | 2 +- models/event.go | 53 ++++++--- models/event_test.go | 82 +++++++++++++- models/flush/flush_test.go | 22 ++-- models/memory.go | 18 +-- models/memory_test.go | 40 ++++++- processes/consumer/process.go | 4 +- processes/consumer/process_test.go | 4 +- 29 files changed, 735 insertions(+), 297 deletions(-) delete mode 100644 lib/cdc/util/parser.go delete mode 100644 lib/cdc/util/parser_test.go create mode 100644 lib/debezium/keys.go create mode 100644 lib/debezium/keys_test.go diff --git a/Makefile b/Makefile index 5b283abf6..6c11ebb55 100644 --- a/Makefile +++ b/Makefile @@ -14,8 +14,9 @@ clean: .PHONY: generate generate: + go get github.com/maxbrunsfeld/counterfeiter/v6 go generate ./... - + go mod tidy .PHONY: build build: goreleaser build --clean diff --git a/clients/bigquery/merge.go b/clients/bigquery/merge.go index 5adbf5c42..271fa6e06 100644 --- a/clients/bigquery/merge.go +++ b/clients/bigquery/merge.go @@ -86,14 +86,17 @@ func merge(tableData *optimization.TableData) (string, error) { subQuery := strings.Join(rowValues, " UNION ALL ") - var specialCastForPrimaryKey bool - pkType, isOk := tableData.InMemoryColumns[tableData.PrimaryKey] - if isOk { - specialCastForPrimaryKey = pkType.Kind == typing.Struct.Kind - } - - return dml.MergeStatement(tableData.ToFqName(constants.BigQuery), subQuery, - tableData.PrimaryKey, tableData.IdempotentKey, cols, tableData.SoftDelete, specialCastForPrimaryKey) + return dml.MergeStatement(dml.MergeArgument{ + FqTableName: tableData.ToFqName(constants.BigQuery), + SubQuery: subQuery, + IdempotentKey: tableData.IdempotentKey, + PrimaryKeys: tableData.PrimaryKeys, + Columns: cols, + ColumnToType: tableData.InMemoryColumns, + SoftDelete: tableData.SoftDelete, + // BigQuery specifically needs it. + SpecialCastingRequired: true, + }) } func (s *Store) Merge(ctx context.Context, tableData *optimization.TableData) error { diff --git a/clients/bigquery/merge_test.go b/clients/bigquery/merge_test.go index 2e655235c..30619a7c2 100644 --- a/clients/bigquery/merge_test.go +++ b/clients/bigquery/merge_test.go @@ -19,7 +19,7 @@ func (b *BigQueryTestSuite) TestMergeNoDeleteFlag() { tableData := &optimization.TableData{ InMemoryColumns: cols, RowsData: nil, - PrimaryKey: "id", + PrimaryKeys: []string{"id"}, TopicConfig: kafkalib.TopicConfig{}, LatestCDCTs: time.Time{}, } @@ -29,6 +29,8 @@ func (b *BigQueryTestSuite) TestMergeNoDeleteFlag() { } func (b *BigQueryTestSuite) TestMerge() { + primaryKeys := []string{"id"} + cols := map[string]typing.KindDetails{ "id": typing.Integer, "name": typing.String, @@ -54,7 +56,7 @@ func (b *BigQueryTestSuite) TestMerge() { tableData := &optimization.TableData{ InMemoryColumns: cols, RowsData: rowData, - PrimaryKey: "id", + PrimaryKeys: primaryKeys, TopicConfig: topicConfig, LatestCDCTs: time.Time{}, } @@ -65,7 +67,11 @@ func (b *BigQueryTestSuite) TestMerge() { // Check if MERGE INTO FQ Table exists. assert.True(b.T(), strings.Contains(mergeSQL, "MERGE INTO shop.customer c"), mergeSQL) // Check for equality merge - assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprintf("c.%s = cc.%s", tableData.PrimaryKey, tableData.PrimaryKey))) + + for _, pk := range primaryKeys { + assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprintf("c.%s = cc.%s", pk, pk))) + } + for _, rowData := range tableData.RowsData { for col, val := range rowData { switch cols[col] { @@ -107,10 +113,139 @@ func (b *BigQueryTestSuite) TestMergeJSONKey() { Schema: "public", } + primaryKeys := []string{"id"} + + tableData := &optimization.TableData{ + InMemoryColumns: cols, + RowsData: rowData, + PrimaryKeys: primaryKeys, + TopicConfig: topicConfig, + LatestCDCTs: time.Time{}, + } + + mergeSQL, err := merge(tableData) + + assert.NoError(b.T(), err, "merge failed") + // Check if MERGE INTO FQ Table exists. + assert.True(b.T(), strings.Contains(mergeSQL, "MERGE INTO shop.customer c"), mergeSQL) + // Check for equality merge + + for _, primaryKey := range primaryKeys { + assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", primaryKey, primaryKey))) + } + + for _, rowData := range tableData.RowsData { + for col, val := range rowData { + switch cols[col] { + case typing.String, typing.Array, typing.Struct: + val = fmt.Sprintf("'%v'", val) + } + + assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprint(val)), map[string]interface{}{ + "merge": mergeSQL, + "val": val, + }) + } + } +} + +func (b *BigQueryTestSuite) TestMergeSimpleCompositeKey() { + cols := map[string]typing.KindDetails{ + "id": typing.String, + "idA": typing.String, + "name": typing.String, + constants.DeleteColumnMarker: typing.Boolean, + } + + rowData := make(map[string]map[string]interface{}) + for idx, name := range []string{"robin", "jacqueline", "dusty"} { + pkVal := fmt.Sprint(map[string]interface{}{ + "$oid": fmt.Sprintf("640127e4beeb1ccfc821c25c++%v", idx), + }) + + rowData[pkVal] = map[string]interface{}{ + "id": pkVal, + "name": name, + constants.DeleteColumnMarker: false, + } + } + + topicConfig := kafkalib.TopicConfig{ + Database: "shop", + TableName: "customer", + Schema: "public", + } + + primaryKeys := []string{"id", "idA"} + tableData := &optimization.TableData{ + InMemoryColumns: cols, + RowsData: rowData, + PrimaryKeys: primaryKeys, + TopicConfig: topicConfig, + LatestCDCTs: time.Time{}, + } + + mergeSQL, err := merge(tableData) + + assert.NoError(b.T(), err, "merge failed") + // Check if MERGE INTO FQ Table exists. + assert.True(b.T(), strings.Contains(mergeSQL, "MERGE INTO shop.customer c"), mergeSQL) + // Check for equality merge + for _, primaryKey := range primaryKeys { + assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprintf("c.%s = cc.%s", primaryKey, primaryKey))) + } + + assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprintf("c.%s = cc.%s and c.%s = cc.%s", "id", "id", "idA", "idA")), mergeSQL) + for _, rowData := range tableData.RowsData { + for col, val := range rowData { + switch cols[col] { + case typing.String, typing.Array, typing.Struct: + val = fmt.Sprintf("'%v'", val) + } + + assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprint(val)), map[string]interface{}{ + "merge": mergeSQL, + "val": val, + }) + } + } +} + +func (b *BigQueryTestSuite) TestMergeJSONKeyAndCompositeHybrid() { + cols := map[string]typing.KindDetails{ + "id": typing.Struct, + "idA": typing.String, + "idB": typing.String, + "idC": typing.Struct, + "name": typing.String, + constants.DeleteColumnMarker: typing.Boolean, + } + + rowData := make(map[string]map[string]interface{}) + for idx, name := range []string{"robin", "jacqueline", "dusty"} { + pkVal := fmt.Sprint(map[string]interface{}{ + "$oid": fmt.Sprintf("640127e4beeb1ccfc821c25c++%v", idx), + }) + + rowData[pkVal] = map[string]interface{}{ + "id": pkVal, + "name": name, + constants.DeleteColumnMarker: false, + } + } + + topicConfig := kafkalib.TopicConfig{ + Database: "shop", + TableName: "customer", + Schema: "public", + } + + primaryKeys := []string{"id", "idA", "idB", "idC"} + tableData := &optimization.TableData{ InMemoryColumns: cols, RowsData: rowData, - PrimaryKey: "id", + PrimaryKeys: primaryKeys, TopicConfig: topicConfig, LatestCDCTs: time.Time{}, } @@ -121,7 +256,14 @@ func (b *BigQueryTestSuite) TestMergeJSONKey() { // Check if MERGE INTO FQ Table exists. assert.True(b.T(), strings.Contains(mergeSQL, "MERGE INTO shop.customer c"), mergeSQL) // Check for equality merge - assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", tableData.PrimaryKey, tableData.PrimaryKey))) + for _, primaryKey := range []string{"id", "idC"} { + assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", primaryKey, primaryKey)), mergeSQL) + } + + for _, primaryKey := range []string{"idA", "idB"} { + assert.True(b.T(), strings.Contains(mergeSQL, fmt.Sprintf("c.%s = cc.%s", primaryKey, primaryKey))) + } + for _, rowData := range tableData.RowsData { for col, val := range rowData { switch cols[col] { diff --git a/clients/snowflake/merge.go b/clients/snowflake/merge.go index 91a83dc05..b484e1f76 100644 --- a/clients/snowflake/merge.go +++ b/clients/snowflake/merge.go @@ -81,6 +81,13 @@ func getMergeStatement(tableData *optimization.TableData) (string, error) { subQuery := fmt.Sprintf("SELECT %s FROM (values %s) as %s(%s)", strings.Join(sflkCols, ","), strings.Join(tableValues, ","), tableData.TopicConfig.TableName, strings.Join(cols, ",")) - return dml.MergeStatement(tableData.ToFqName(constants.Snowflake), subQuery, - tableData.PrimaryKey, tableData.IdempotentKey, cols, tableData.SoftDelete, false) + return dml.MergeStatement(dml.MergeArgument{ + FqTableName: tableData.ToFqName(constants.Snowflake), + SubQuery: subQuery, + IdempotentKey: tableData.IdempotentKey, + PrimaryKeys: tableData.PrimaryKeys, + Columns: cols, + ColumnToType: tableData.InMemoryColumns, + SoftDelete: tableData.SoftDelete, + }) } diff --git a/clients/snowflake/merge_test.go b/clients/snowflake/merge_test.go index 9236ef50e..f2f65fc86 100644 --- a/clients/snowflake/merge_test.go +++ b/clients/snowflake/merge_test.go @@ -21,7 +21,7 @@ func (s *SnowflakeTestSuite) TestMergeNoDeleteFlag() { tableData := &optimization.TableData{ InMemoryColumns: cols, RowsData: nil, - PrimaryKey: "id", + PrimaryKeys: []string{"id"}, TopicConfig: kafkalib.TopicConfig{}, LatestCDCTs: time.Time{}, } @@ -54,10 +54,11 @@ func (s *SnowflakeTestSuite) TestMerge() { Schema: "public", } + primaryKeys := []string{"id"} tableData := &optimization.TableData{ InMemoryColumns: cols, RowsData: rowData, - PrimaryKey: "id", + PrimaryKeys: primaryKeys, TopicConfig: topicConfig, LatestCDCTs: time.Time{}, } @@ -71,7 +72,11 @@ func (s *SnowflakeTestSuite) TestMerge() { // Check if MERGE INTO FQ Table exists. assert.True(s.T(), strings.Contains(mergeSQL, "MERGE INTO shop.public.customer c")) - assert.True(s.T(), strings.Contains(mergeSQL, fmt.Sprintf("c.%s = cc.%s", tableData.PrimaryKey, tableData.PrimaryKey))) + + for _, primaryKey := range primaryKeys { + assert.True(s.T(), strings.Contains(mergeSQL, fmt.Sprintf("c.%s = cc.%s", primaryKey, primaryKey))) + } + for _, rowData := range tableData.RowsData { for col, val := range rowData { switch cols[col] { @@ -110,7 +115,7 @@ func (s *SnowflakeTestSuite) TestMergeWithSingleQuote() { tableData := &optimization.TableData{ InMemoryColumns: cols, RowsData: rowData, - PrimaryKey: "id", + PrimaryKeys: []string{"id"}, TopicConfig: topicConfig, LatestCDCTs: time.Time{}, } @@ -143,7 +148,7 @@ func (s *SnowflakeTestSuite) TestMergeJson() { tableData := &optimization.TableData{ InMemoryColumns: cols, RowsData: rowData, - PrimaryKey: "id", + PrimaryKeys: []string{"id"}, TopicConfig: topicConfig, LatestCDCTs: time.Time{}, } diff --git a/clients/snowflake/snowflake_test.go b/clients/snowflake/snowflake_test.go index a9933b1db..99c02d945 100644 --- a/clients/snowflake/snowflake_test.go +++ b/clients/snowflake/snowflake_test.go @@ -21,6 +21,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() { // TableData will think the column is invalid and tableConfig will think column = string // Before we call merge, it should reconcile it. columns := map[string]typing.KindDetails{ + "id": typing.String, "first_name": typing.String, "invalid_column": typing.Invalid, constants.DeleteColumnMarker: typing.Boolean, @@ -42,12 +43,13 @@ func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() { InMemoryColumns: columns, RowsData: rowsData, TopicConfig: topicConfig, - PrimaryKey: "id", + PrimaryKeys: []string{"id"}, Rows: 1, } s.store.configMap.AddTableToConfig(topicConfig.ToFqName(constants.Snowflake), types.NewDwhTableConfig( map[string]typing.KindDetails{ + "id": typing.String, "first_name": typing.String, constants.DeleteColumnMarker: typing.Boolean, }, nil, false, true)) @@ -86,7 +88,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeReestablishAuth() { InMemoryColumns: columns, RowsData: rowsData, TopicConfig: topicConfig, - PrimaryKey: "id", + PrimaryKeys: []string{"id"}, Rows: 1, } @@ -132,7 +134,7 @@ func (s *SnowflakeTestSuite) TestExecuteMerge() { InMemoryColumns: columns, RowsData: rowsData, TopicConfig: topicConfig, - PrimaryKey: "id", + PrimaryKeys: []string{"id"}, Rows: 1, } @@ -177,7 +179,7 @@ func (s *SnowflakeTestSuite) TestExecuteMergeDeletionFlagRemoval() { InMemoryColumns: columns, RowsData: rowsData, TopicConfig: topicConfig, - PrimaryKey: "id", + PrimaryKeys: []string{"id"}, Rows: 1, } @@ -228,7 +230,6 @@ func (s *SnowflakeTestSuite) TestExecuteMergeExitEarly() { err := s.store.Merge(s.ctx, &optimization.TableData{ InMemoryColumns: nil, RowsData: nil, - PrimaryKey: "", TopicConfig: kafkalib.TopicConfig{}, PartitionsToLastMessage: nil, LatestCDCTs: time.Time{}, diff --git a/lib/cdc/event.go b/lib/cdc/event.go index 96a76eaca..d0123948d 100644 --- a/lib/cdc/event.go +++ b/lib/cdc/event.go @@ -9,13 +9,13 @@ import ( type Format interface { Labels() []string // Labels() to return a list of strings to maintain backward compatibility. - GetPrimaryKey(ctx context.Context, key []byte, tc *kafkalib.TopicConfig) (string, interface{}, error) + GetPrimaryKey(ctx context.Context, key []byte, tc *kafkalib.TopicConfig) (map[string]interface{}, error) GetEventFromBytes(ctx context.Context, bytes []byte) (Event, error) } type Event interface { GetExecutionTime() time.Time - GetData(ctx context.Context, pkName string, pkVal interface{}, config *kafkalib.TopicConfig) map[string]interface{} + GetData(ctx context.Context, pkMap map[string]interface{}, config *kafkalib.TopicConfig) map[string]interface{} } // FieldLabelKind is used when the schema is turned on. Each schema object will be labelled. diff --git a/lib/cdc/mongo/debezium.go b/lib/cdc/mongo/debezium.go index 4558946f0..611a6ca99 100644 --- a/lib/cdc/mongo/debezium.go +++ b/lib/cdc/mongo/debezium.go @@ -3,12 +3,11 @@ package mongo import ( "context" "encoding/json" - "fmt" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/debezium" "time" "github.com/artie-labs/transfer/lib/cdc" - "github.com/artie-labs/transfer/lib/cdc/util" "github.com/artie-labs/transfer/lib/kafkalib" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/mongo" @@ -67,26 +66,15 @@ func (d *Debezium) Labels() []string { return []string{constants.DBZMongoFormat} } -// GetPrimaryKey - Will read from the Kafka message's partition key to get the primary key for the row. -// TODO: This should support: key.converter.schemas.enable=true -func (d *Debezium) GetPrimaryKey(ctx context.Context, key []byte, tc *kafkalib.TopicConfig) (pkName string, pkValue interface{}, err error) { - switch tc.CDCKeyFormat { - case "org.apache.kafka.connect.json.JsonConverter": - return util.ParseJSONKey(key) - case "org.apache.kafka.connect.storage.StringConverter": - return util.ParseStringKey(key) - default: - err = fmt.Errorf("format: %s is not supported", tc.CDCKeyFormat) - } - - return +func (d *Debezium) GetPrimaryKey(ctx context.Context, key []byte, tc *kafkalib.TopicConfig) (kvMap map[string]interface{}, err error) { + return debezium.ParsePartitionKey(key, tc.CDCKeyFormat) } func (s *SchemaEventPayload) GetExecutionTime() time.Time { return time.UnixMilli(s.Payload.Source.TsMs).UTC() } -func (s *SchemaEventPayload) GetData(ctx context.Context, pkName string, pkVal interface{}, tc *kafkalib.TopicConfig) map[string]interface{} { +func (s *SchemaEventPayload) GetData(ctx context.Context, pkMap map[string]interface{}, tc *kafkalib.TopicConfig) map[string]interface{} { retMap := make(map[string]interface{}) if len(s.Payload.AfterMap) == 0 { // This is a delete event, so mark it as deleted. @@ -95,7 +83,10 @@ func (s *SchemaEventPayload) GetData(ctx context.Context, pkName string, pkVal i // the PK. We can explore simplifying this interface in the future by leveraging before. retMap = map[string]interface{}{ constants.DeleteColumnMarker: true, - pkName: pkVal, + } + + for k, v := range pkMap { + retMap[k] = v } // If idempotency key is an empty string, don't put it in the event data @@ -106,7 +97,10 @@ func (s *SchemaEventPayload) GetData(ctx context.Context, pkName string, pkVal i retMap = s.Payload.AfterMap // We need this because there's an edge case with Debezium // Where _id gets rewritten as id in the partition key. - retMap[pkName] = pkVal + for k, v := range pkMap { + retMap[k] = v + } + retMap[constants.DeleteColumnMarker] = false } diff --git a/lib/cdc/mongo/debezium_test.go b/lib/cdc/mongo/debezium_test.go index eb8ab5b85..af70d127b 100644 --- a/lib/cdc/mongo/debezium_test.go +++ b/lib/cdc/mongo/debezium_test.go @@ -3,7 +3,6 @@ package mongo import ( "context" "encoding/json" - "fmt" "github.com/artie-labs/transfer/lib/cdc" "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/debezium" @@ -21,9 +20,10 @@ func (p *MongoTestSuite) TestGetPrimaryKey() { CDCKeyFormat: "org.apache.kafka.connect.storage.StringConverter", } - pkName, pkVal, err := p.GetPrimaryKey(context.Background(), []byte(valString), tc) - assert.Equal(p.T(), pkName, "id") - assert.Equal(p.T(), fmt.Sprint(pkVal), fmt.Sprint(1001)) // Don't have to deal with float and int conversion + pkMap, err := p.GetPrimaryKey(context.Background(), []byte(valString), tc) + pkVal, isOk := pkMap["id"] + assert.True(p.T(), isOk) + assert.Equal(p.T(), pkVal, "1001") assert.Equal(p.T(), err, nil) } @@ -127,7 +127,7 @@ func (p *MongoTestSuite) TestMongoDBEventCustomer() { evt, err := p.Debezium.GetEventFromBytes(ctx, []byte(payload)) assert.NoError(p.T(), err) - evtData := evt.GetData(context.Background(), "_id", 1003, &kafkalib.TopicConfig{}) + evtData := evt.GetData(context.Background(), map[string]interface{}{"_id": 1003}, &kafkalib.TopicConfig{}) assert.Equal(p.T(), evtData["_id"], 1003) assert.Equal(p.T(), evtData["first_name"], "Robin") @@ -179,7 +179,7 @@ func (p *MongoTestSuite) TestMongoDBEventCustomerBefore() { evt, err := p.Debezium.GetEventFromBytes(ctx, []byte(payload)) assert.NoError(p.T(), err) - evtData := evt.GetData(context.Background(), "_id", 1003, &kafkalib.TopicConfig{}) + evtData := evt.GetData(context.Background(), map[string]interface{}{"_id": 1003}, &kafkalib.TopicConfig{}) assert.Equal(p.T(), evtData["_id"], 1003) assert.Equal(p.T(), evtData[constants.DeleteColumnMarker], true) diff --git a/lib/cdc/mysql/debezium.go b/lib/cdc/mysql/debezium.go index 9c5ad2507..894f393d8 100644 --- a/lib/cdc/mysql/debezium.go +++ b/lib/cdc/mysql/debezium.go @@ -3,10 +3,10 @@ package mysql import ( "context" "encoding/json" - "fmt" "github.com/artie-labs/transfer/lib/cdc" "github.com/artie-labs/transfer/lib/cdc/util" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/debezium" "github.com/artie-labs/transfer/lib/kafkalib" ) @@ -31,15 +31,6 @@ func (d *Debezium) Labels() []string { return []string{constants.DBZMySQLFormat} } -func (d *Debezium) GetPrimaryKey(ctx context.Context, key []byte, tc *kafkalib.TopicConfig) (pkName string, pkValue interface{}, err error) { - switch tc.CDCKeyFormat { - case "org.apache.kafka.connect.json.JsonConverter": - return util.ParseJSONKey(key) - case "org.apache.kafka.connect.storage.StringConverter": - return util.ParseStringKey(key) - default: - err = fmt.Errorf("format: %s is not supported", tc.CDCKeyFormat) - } - - return +func (d *Debezium) GetPrimaryKey(ctx context.Context, key []byte, tc *kafkalib.TopicConfig) (kvMap map[string]interface{}, err error) { + return debezium.ParsePartitionKey(key, tc.CDCKeyFormat) } diff --git a/lib/cdc/mysql/debezium_test.go b/lib/cdc/mysql/debezium_test.go index 66872fab3..71fce5e82 100644 --- a/lib/cdc/mysql/debezium_test.go +++ b/lib/cdc/mysql/debezium_test.go @@ -289,7 +289,10 @@ func (m *MySQLTestSuite) TestGetEventFromBytes() { assert.NoError(m.T(), err) assert.Equal(m.T(), time.Date(2023, time.March, 13, 19, 19, 24, 0, time.UTC), evt.GetExecutionTime()) - evtData := evt.GetData(context.Background(), "id", 1001, &kafkalib.TopicConfig{}) + kvMap := map[string]interface{}{ + "id": 1001, + } + evtData := evt.GetData(context.Background(), kvMap, &kafkalib.TopicConfig{}) assert.Equal(m.T(), evtData["id"], 1001) assert.Equal(m.T(), evtData["first_name"], "Sally") assert.Equal(m.T(), evtData["bool_test"], false) diff --git a/lib/cdc/postgres/debezium.go b/lib/cdc/postgres/debezium.go index 6b01fb690..f62656c06 100644 --- a/lib/cdc/postgres/debezium.go +++ b/lib/cdc/postgres/debezium.go @@ -3,10 +3,10 @@ package postgres import ( "context" "encoding/json" - "fmt" "github.com/artie-labs/transfer/lib/cdc" "github.com/artie-labs/transfer/lib/cdc/util" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/debezium" "github.com/artie-labs/transfer/lib/kafkalib" ) @@ -31,15 +31,6 @@ func (d *Debezium) Labels() []string { return []string{constants.DBZPostgresFormat, constants.DBZPostgresAltFormat} } -func (d *Debezium) GetPrimaryKey(ctx context.Context, key []byte, tc *kafkalib.TopicConfig) (pkName string, pkValue interface{}, err error) { - switch tc.CDCKeyFormat { - case "org.apache.kafka.connect.json.JsonConverter": - return util.ParseJSONKey(key) - case "org.apache.kafka.connect.storage.StringConverter": - return util.ParseStringKey(key) - default: - err = fmt.Errorf("format: %s is not supported", tc.CDCKeyFormat) - } - - return +func (d *Debezium) GetPrimaryKey(ctx context.Context, key []byte, tc *kafkalib.TopicConfig) (kvMap map[string]interface{}, err error) { + return debezium.ParsePartitionKey(key, tc.CDCKeyFormat) } diff --git a/lib/cdc/postgres/debezium_test.go b/lib/cdc/postgres/debezium_test.go index 93a996a73..7b7e17ba2 100644 --- a/lib/cdc/postgres/debezium_test.go +++ b/lib/cdc/postgres/debezium_test.go @@ -2,7 +2,6 @@ package postgres import ( "context" - "fmt" "github.com/artie-labs/transfer/lib/kafkalib" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/ext" @@ -18,17 +17,20 @@ var validTc = &kafkalib.TopicConfig{ func (p *PostgresTestSuite) TestGetPrimaryKey() { valString := `{"id": 47}` - pkName, pkVal, err := p.GetPrimaryKey(context.Background(), []byte(valString), validTc) - assert.Equal(p.T(), pkName, "id") - assert.Equal(p.T(), fmt.Sprint(pkVal), fmt.Sprint(47)) // Don't have to deal with float and int conversion + pkMap, err := p.GetPrimaryKey(context.Background(), []byte(valString), validTc) + + val, isOk := pkMap["id"] + assert.True(p.T(), isOk) + assert.Equal(p.T(), val, float64(47)) assert.Equal(p.T(), err, nil) } func (p *PostgresTestSuite) TestGetPrimaryKeyUUID() { valString := `{"uuid": "ca0cefe9-45cf-44fa-a2ab-ec5e7e5522a3"}` - pkName, pkVal, err := p.GetPrimaryKey(context.Background(), []byte(valString), validTc) - assert.Equal(p.T(), pkName, "uuid") - assert.Equal(p.T(), fmt.Sprint(pkVal), "ca0cefe9-45cf-44fa-a2ab-ec5e7e5522a3") + pkMap, err := p.GetPrimaryKey(context.Background(), []byte(valString), validTc) + val, isOk := pkMap["uuid"] + assert.True(p.T(), isOk) + assert.Equal(p.T(), val, "ca0cefe9-45cf-44fa-a2ab-ec5e7e5522a3") assert.Equal(p.T(), err, nil) } @@ -76,7 +78,7 @@ func (p *PostgresTestSuite) TestPostgresEvent() { evt, err := p.Debezium.GetEventFromBytes(context.Background(), []byte(payload)) assert.Nil(p.T(), err) - evtData := evt.GetData(context.Background(), "id", 59, &kafkalib.TopicConfig{}) + evtData := evt.GetData(context.Background(), map[string]interface{}{"id": 59}, &kafkalib.TopicConfig{}) assert.Equal(p.T(), evtData["id"], float64(59)) assert.Equal(p.T(), evtData["item"], "Barings Participation Investors") @@ -177,7 +179,7 @@ func (p *PostgresTestSuite) TestPostgresEventWithSchemaAndTimestampNoTZ() { evt, err := p.Debezium.GetEventFromBytes(context.Background(), []byte(payload)) assert.Nil(p.T(), err) - evtData := evt.GetData(context.Background(), "id", 1001, &kafkalib.TopicConfig{}) + evtData := evt.GetData(context.Background(), map[string]interface{}{"id": 1001}, &kafkalib.TopicConfig{}) // Testing typing. assert.Equal(p.T(), evtData["id"], 1001) diff --git a/lib/cdc/util/parser.go b/lib/cdc/util/parser.go deleted file mode 100644 index 5eb666f18..000000000 --- a/lib/cdc/util/parser.go +++ /dev/null @@ -1,61 +0,0 @@ -package util - -import ( - "encoding/json" - "fmt" - "strings" -) - -// ParseStringKey expects the key format to look like Struct{id=47} -func ParseStringKey(key []byte) (pkName string, pkValue interface{}, err error) { - if len(key) == 0 { - err = fmt.Errorf("key is nil") - return - } - - keyString := string(key) - if len(keyString) < 8 { - return "", "", - fmt.Errorf("key length too short, actual: %v, key: %s", len(keyString), keyString) - } - - // Strip out the leading Struct{ and trailing } - pkParts := strings.Split(keyString[7:len(keyString)-1], "=") - if len(pkParts) != 2 { - return "", "", fmt.Errorf("key length incorrect, actual: %v, key: %s", len(keyString), keyString) - } - - return pkParts[0], pkParts[1], nil -} - -func ParseJSONKey(key []byte) (pkName string, pkValue interface{}, err error) { - if len(key) == 0 { - err = fmt.Errorf("key is nil") - return - } - - var pkStruct map[string]interface{} - err = json.Unmarshal(key, &pkStruct) - if err != nil { - return - } - - _, isOk := pkStruct["payload"] - if isOk { - var castOk bool - // strip the schema and focus in on payload - pkStruct, castOk = pkStruct["payload"].(map[string]interface{}) - if !castOk { - return "", "", fmt.Errorf("key object is malformated") - } - } - - // Given that this is the format, we will only have 1 key in here. - for k, v := range pkStruct { - pkName = k - pkValue = v - break - } - - return -} diff --git a/lib/cdc/util/parser_test.go b/lib/cdc/util/parser_test.go deleted file mode 100644 index 65d813383..000000000 --- a/lib/cdc/util/parser_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package util - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestParseStringKey(t *testing.T) { - _, _, err := ParseStringKey(nil) - assert.Error(t, err) - assert.Contains(t, err.Error(), "key is nil") - - pkName, pkVal, err := ParseStringKey([]byte("Struct{id=47}")) - assert.Nil(t, err) - assert.Equal(t, pkName, "id") - assert.Equal(t, pkVal, "47") - - pkName, pkVal, err = ParseStringKey([]byte("Struct{uuid=d4a5bc26-9ae6-4dd4-8894-39cbcd2d526c}")) - assert.Nil(t, err) - assert.Equal(t, pkName, "uuid") - assert.Equal(t, pkVal, "d4a5bc26-9ae6-4dd4-8894-39cbcd2d526c") - - _, _, err = ParseStringKey([]byte("{id=")) - assert.Error(t, err) - assert.Contains(t, err.Error(), "key length too short") - - _, _, err = ParseStringKey([]byte("Struct{id=")) - assert.Error(t, err) - assert.Contains(t, err.Error(), "key length incorrect") -} - -func TestParseJSONKey(t *testing.T) { - _, _, err := ParseJSONKey(nil) - assert.Error(t, err) - assert.Contains(t, err.Error(), "key is nil") - - pkName, pkVal, err := ParseJSONKey([]byte(`{"id": 47}`)) - assert.Nil(t, err) - assert.Equal(t, pkName, "id") - assert.Equal(t, pkVal, float64(47)) - - pkName, pkVal, err = ParseJSONKey([]byte(`{"uuid": "d4a5bc26-9ae6-4dd4-8894-39cbcd2d526c"}`)) - assert.Nil(t, err) - assert.Equal(t, pkName, "uuid") - assert.Equal(t, pkVal, "d4a5bc26-9ae6-4dd4-8894-39cbcd2d526c") - - _, _, err = ParseJSONKey([]byte("{id:")) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid character") - - _, _, err = ParseJSONKey([]byte(`{"id":`)) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unexpected end of JSON") -} - -func TestParseJSONKeyWithSchema(t *testing.T) { - pkName, pkVal, err := ParseJSONKey([]byte(`{ - "schema": { - "type": "struct", - "fields": [{ - "type": "int32", - "optional": false, - "default": 0, - "field": "id" - }], - "optional": false, - "name": "dbserver1.inventory.customers.Key" - }, - "payload": { - "id": 1002 - } -}`)) - - assert.NoError(t, err) - assert.Equal(t, "id", pkName) - assert.Equal(t, float64(1002), pkVal) -} diff --git a/lib/cdc/util/relational_event.go b/lib/cdc/util/relational_event.go index 5d84296b1..224dea35b 100644 --- a/lib/cdc/util/relational_event.go +++ b/lib/cdc/util/relational_event.go @@ -37,7 +37,7 @@ func (s *SchemaEventPayload) GetExecutionTime() time.Time { return time.UnixMilli(s.Payload.Source.TsMs).UTC() } -func (s *SchemaEventPayload) GetData(ctx context.Context, pkName string, pkVal interface{}, tc *kafkalib.TopicConfig) map[string]interface{} { +func (s *SchemaEventPayload) GetData(ctx context.Context, pkMap map[string]interface{}, tc *kafkalib.TopicConfig) map[string]interface{} { retMap := make(map[string]interface{}) if len(s.Payload.After) == 0 { // This is a delete payload, so mark it as deleted. @@ -46,7 +46,10 @@ func (s *SchemaEventPayload) GetData(ctx context.Context, pkName string, pkVal i // the PK. We can explore simplifying this interface in the future by leveraging before. retMap = map[string]interface{}{ constants.DeleteColumnMarker: true, - pkName: pkVal, + } + + for k, v := range pkMap { + retMap[k] = v } // If idempotency key is an empty string, don't put it in the payload data diff --git a/lib/cdc/util/relational_event_test.go b/lib/cdc/util/relational_event_test.go index a036dac27..7081cecd2 100644 --- a/lib/cdc/util/relational_event_test.go +++ b/lib/cdc/util/relational_event_test.go @@ -40,7 +40,7 @@ func TestGetDataTestInsert(t *testing.T) { }, } - evtData := schemaEventPayload.GetData(context.Background(), "pk", 1, &tc) + evtData := schemaEventPayload.GetData(context.Background(), map[string]interface{}{"pk": 1}, &tc) assert.Equal(t, len(after), len(evtData), "has deletion flag") deletionFlag, isOk := evtData[constants.DeleteColumnMarker] @@ -66,7 +66,8 @@ func TestGetDataTestDelete(t *testing.T) { }, } - evtData := schemaEventPayload.GetData(context.Background(), "pk", 1, tc) + kvMap := map[string]interface{}{"pk": 1} + evtData := schemaEventPayload.GetData(context.Background(), kvMap, tc) shouldDelete, isOk := evtData[constants.DeleteColumnMarker] assert.True(t, isOk) assert.True(t, shouldDelete.(bool)) @@ -76,7 +77,7 @@ func TestGetDataTestDelete(t *testing.T) { assert.Equal(t, evtData[tc.IdempotentKey], now.Format(time.RFC3339)) tc.IdempotentKey = "" - evtData = schemaEventPayload.GetData(context.Background(), "pk", 1, tc) + evtData = schemaEventPayload.GetData(context.Background(), kvMap, tc) _, isOk = evtData[tc.IdempotentKey] assert.False(t, isOk, evtData) } @@ -109,7 +110,9 @@ func TestGetDataTestUpdate(t *testing.T) { }, } - evtData := schemaEventPayload.GetData(context.Background(), "pk", 1, &tc) + kvMap := map[string]interface{}{"pk": 1} + + evtData := schemaEventPayload.GetData(context.Background(), kvMap, &tc) assert.Equal(t, len(after), len(evtData), "has deletion flag") deletionFlag, isOk := evtData[constants.DeleteColumnMarker] diff --git a/lib/debezium/keys.go b/lib/debezium/keys.go new file mode 100644 index 000000000..dd43ba8f6 --- /dev/null +++ b/lib/debezium/keys.go @@ -0,0 +1,90 @@ +package debezium + +import ( + "encoding/json" + "fmt" + "strings" +) + +const ( + KeyFormatJSON = "org.apache.kafka.connect.json.JsonConverter" + KeyFormatString = "org.apache.kafka.connect.storage.StringConverter" + + stringPrefix = "Struct{" + stringSuffix = "}" +) + +func ParsePartitionKey(key []byte, cdcKeyFormat string) (map[string]interface{}, error) { + switch cdcKeyFormat { + case KeyFormatJSON: + return parsePartitionKeyStruct(key) + case KeyFormatString: + return parsePartitionKeyString(key) + + } + return nil, fmt.Errorf("format: %s is not supported", cdcKeyFormat) +} + +// parsePartitionKeyString is used to parse the partition key when it is getting emitted in the string format. +// This is not the recommended approach because through serializing a Struct into a string notation, the operation is buggy and potentially irreversible. +// Kafka's string serialization will emit the message to look like: Struct{k=v,k1=v1} +// However, if the k or v has `,` or `=` within it, it is not escaped and thus difficult to delineate between a separator or a continuation of the column or value. +// In the case where there are multiple `=`, we will use the first one to separate between the key and value. +// TL;DR - Use `org.apache.kafka.connect.json.JsonConverter` over `org.apache.kafka.connect.storage.StringConverter` +func parsePartitionKeyString(key []byte) (map[string]interface{}, error) { + // Key will look like key: Struct{quarter_id=1,course_id=course1,student_id=1} + if len(key) == 0 { + return nil, fmt.Errorf("key is nil") + } + + keyString := string(key) + if len(stringPrefix+stringSuffix) >= len(keyString) { + return nil, fmt.Errorf("key is too short") + } + + if !(strings.HasPrefix(keyString, stringPrefix) && strings.HasSuffix(keyString, stringSuffix)) { + return nil, fmt.Errorf("incorrect key structure") + } + + retMap := make(map[string]interface{}) + parsedKeyString := keyString[len(stringPrefix) : len(keyString)-1] + for _, kvPartString := range strings.Split(parsedKeyString, ",") { + kvParts := strings.Split(kvPartString, "=") + if len(kvParts) < 2 { + return nil, fmt.Errorf("malformed key value pair: %s", kvPartString) + } + + retMap[kvParts[0]] = strings.Join(kvParts[1:], "=") + } + + return retMap, nil +} + +func parsePartitionKeyStruct(key []byte) (map[string]interface{}, error) { + if len(key) == 0 { + return nil, fmt.Errorf("key is nil") + } + + var pkStruct map[string]interface{} + err := json.Unmarshal(key, &pkStruct) + if err != nil { + return nil, fmt.Errorf("failed to json unmarshal, error: %v", err) + } + + if len(pkStruct) == 0 { + return nil, fmt.Errorf("key is nil") + } + + _, isOk := pkStruct["payload"] + if !isOk { + // pkStruct does not have schema enabled + return pkStruct, nil + } + + pkStruct, isOk = pkStruct["payload"].(map[string]interface{}) + if !isOk { + return nil, fmt.Errorf("key object is malformated") + } + + return pkStruct, nil +} diff --git a/lib/debezium/keys_test.go b/lib/debezium/keys_test.go new file mode 100644 index 000000000..2c744f8b3 --- /dev/null +++ b/lib/debezium/keys_test.go @@ -0,0 +1,114 @@ +package debezium + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParsePartitionKeyString(t *testing.T) { + kv, err := parsePartitionKeyString([]byte("Struct{hi=world,foo=bar}")) + assert.NoError(t, err) + assert.Equal(t, kv["hi"], "world") + assert.Equal(t, kv["foo"], "bar") + + kv, err = parsePartitionKeyString([]byte("Struct{hi==world}")) + assert.NoError(t, err) + assert.Equal(t, kv["hi"], "=world") + + kv, err = parsePartitionKeyString([]byte("Struct{id=47}")) + assert.NoError(t, err) + assert.Equal(t, kv["id"], "47") + + kv, err = parsePartitionKeyString([]byte("Struct{uuid=d4a5bc26-9ae6-4dd4-8894-39cbcd2d526c}")) + assert.Nil(t, err) + assert.Equal(t, kv["uuid"], "d4a5bc26-9ae6-4dd4-8894-39cbcd2d526c") + + badDataCases := []string{ + "", + "Struct{", + "Struct{}", + "}", + "Struct{uuid=a,,}", + "Struct{,,}", + } + + for _, badData := range badDataCases { + _, err = parsePartitionKeyString([]byte(badData)) + assert.Error(t, err) + } +} + +func TestParsePartitionKeyStruct(t *testing.T) { + badDataCases := []string{ + "", + "{}", + "{id:", + `{"id":`, + } + + for _, badData := range badDataCases { + _, err := parsePartitionKeyStruct([]byte(badData)) + assert.Error(t, err, badData) + } + + kv, err := parsePartitionKeyStruct([]byte(`{"id": 47}`)) + assert.Nil(t, err) + assert.Equal(t, kv["id"], float64(47)) + + kv, err = parsePartitionKeyStruct([]byte(`{"uuid": "d4a5bc26-9ae6-4dd4-8894-39cbcd2d526c"}`)) + assert.Nil(t, err) + assert.Equal(t, kv["uuid"], "d4a5bc26-9ae6-4dd4-8894-39cbcd2d526c") + + kv, err = parsePartitionKeyStruct([]byte(`{ + "schema": { + "type": "struct", + "fields": [{ + "type": "int32", + "optional": false, + "default": 0, + "field": "id" + }], + "optional": false, + "name": "dbserver1.inventory.customers.Key" + }, + "payload": { + "id": 1002 + } +}`)) + + assert.NoError(t, err) + assert.Equal(t, kv["id"], float64(1002)) + + // Composite key + compositeKeyString := `{ + "schema": { + "type": "struct", + "fields": [{ + "type": "int32", + "optional": false, + "field": "quarter_id" + }, { + "type": "string", + "optional": false, + "field": "course_id" + }, { + "type": "int32", + "optional": false, + "field": "student_id" + }], + "optional": false, + "name": "dbserver1.inventory.course_grades.Key" + }, + "payload": { + "quarter_id": 1, + "course_id": "course1", + "student_id": 1 + } +}` + + kv, err = parsePartitionKeyStruct([]byte(compositeKeyString)) + assert.NoError(t, err) + assert.Equal(t, kv["quarter_id"], float64(1)) + assert.Equal(t, kv["student_id"], float64(1)) + assert.Equal(t, kv["course_id"], "course1") +} diff --git a/lib/dwh/dml/merge.go b/lib/dwh/dml/merge.go index 67061cb1d..eb6e7ae11 100644 --- a/lib/dwh/dml/merge.go +++ b/lib/dwh/dml/merge.go @@ -3,13 +3,29 @@ package dml import ( "errors" "fmt" + "github.com/artie-labs/transfer/lib/typing" "strings" "github.com/artie-labs/transfer/lib/array" "github.com/artie-labs/transfer/lib/config/constants" ) -func MergeStatement(fqTableName, subQuery, pk, idempotentKey string, cols []string, softDelete bool, specialStructCastMergeKey bool) (string, error) { +type MergeArgument struct { + // TODO: Test + FqTableName string + SubQuery string + IdempotentKey string + PrimaryKeys []string + Columns []string + ColumnToType map[string]typing.KindDetails + + // SpecialCastingRequired - This is used for columns that have JSON value. This is required for BigQuery + // We will be casting the value in this column as such: `TO_JSON_STRING()` + SpecialCastingRequired bool + SoftDelete bool +} + +func MergeStatement(m MergeArgument) (string, error) { // We should not need idempotency key for DELETE // This is based on the assumption that the primary key would be atomically increasing or UUID based // With AI, the sequence will increment (never decrement). And UUID is there to prevent universal hash collision @@ -18,17 +34,27 @@ func MergeStatement(fqTableName, subQuery, pk, idempotentKey string, cols []stri // We also need to do staged table's idempotency key is GTE target table's idempotency key // This is because Snowflake does not respect NS granularity. var idempotentClause string - if idempotentKey != "" { - idempotentClause = fmt.Sprintf("AND cc.%s >= c.%s ", idempotentKey, idempotentKey) + if m.IdempotentKey != "" { + idempotentClause = fmt.Sprintf("AND cc.%s >= c.%s ", m.IdempotentKey, m.IdempotentKey) } - equalitySQL := fmt.Sprintf("c.%s = cc.%s", pk, pk) - if specialStructCastMergeKey { - // BigQuery requires special casting to compare two JSON objects. - equalitySQL = fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", pk, pk) + var equalitySQLParts []string + for _, primaryKey := range m.PrimaryKeys { + equalitySQL := fmt.Sprintf("c.%s = cc.%s", primaryKey, primaryKey) + typeKind, isOk := m.ColumnToType[primaryKey] + if !isOk { + return "", fmt.Errorf("error: column: %s does not exist in columnToType: %v", primaryKey, m.ColumnToType) + } + + if typeKind.Kind == typing.Struct.Kind { + // BigQuery requires special casting to compare two JSON objects. + equalitySQL = fmt.Sprintf("TO_JSON_STRING(c.%s) = TO_JSON_STRING(cc.%s)", primaryKey, primaryKey) + } + + equalitySQLParts = append(equalitySQLParts, equalitySQL) } - if softDelete { + if m.SoftDelete { return fmt.Sprintf(` MERGE INTO %s c using (%s) as cc on %s when matched %sthen UPDATE @@ -41,19 +67,19 @@ func MergeStatement(fqTableName, subQuery, pk, idempotentKey string, cols []stri ( %s ); - `, fqTableName, subQuery, equalitySQL, + `, m.FqTableName, m.SubQuery, strings.Join(equalitySQLParts, " and "), // Update + Soft Deletion - idempotentClause, array.ColumnsUpdateQuery(cols, "cc"), + idempotentClause, array.ColumnsUpdateQuery(m.Columns, "cc"), // Insert - constants.DeleteColumnMarker, strings.Join(cols, ","), - array.StringsJoinAddPrefix(cols, ",", "cc.")), nil + constants.DeleteColumnMarker, strings.Join(m.Columns, ","), + array.StringsJoinAddPrefix(m.Columns, ",", "cc.")), nil } // We also need to remove __artie flags since it does not exist in the destination table var removed bool - for idx, col := range cols { + for idx, col := range m.Columns { if col == constants.DeleteColumnMarker { - cols = append(cols[:idx], cols[idx+1:]...) + m.Columns = append(m.Columns[:idx], m.Columns[idx+1:]...) removed = true break } @@ -76,13 +102,13 @@ func MergeStatement(fqTableName, subQuery, pk, idempotentKey string, cols []stri ( %s ); - `, fqTableName, subQuery, equalitySQL, + `, m.FqTableName, m.SubQuery, strings.Join(equalitySQLParts, " and "), // Delete constants.DeleteColumnMarker, // Update - constants.DeleteColumnMarker, idempotentClause, array.ColumnsUpdateQuery(cols, "cc"), + constants.DeleteColumnMarker, idempotentClause, array.ColumnsUpdateQuery(m.Columns, "cc"), // Insert - constants.DeleteColumnMarker, strings.Join(cols, ","), - array.StringsJoinAddPrefix(cols, ",", "cc.")), nil + constants.DeleteColumnMarker, strings.Join(m.Columns, ","), + array.StringsJoinAddPrefix(m.Columns, ",", "cc.")), nil } diff --git a/lib/dwh/dml/merge_test.go b/lib/dwh/dml/merge_test.go index 21f599275..f7c5ccb7d 100644 --- a/lib/dwh/dml/merge_test.go +++ b/lib/dwh/dml/merge_test.go @@ -3,6 +3,7 @@ package dml import ( "fmt" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/typing" "github.com/stretchr/testify/assert" "strings" "testing" @@ -30,7 +31,16 @@ func TestMergeStatementSoftDelete(t *testing.T) { strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) for _, idempotentKey := range []string{"", "updated_at"} { - mergeSQL, err := MergeStatement(fqTable, subQuery, "id", idempotentKey, cols, true, false) + mergeSQL, err := MergeStatement(MergeArgument{ + FqTableName: fqTable, + SubQuery: subQuery, + IdempotentKey: idempotentKey, + PrimaryKeys: []string{"id"}, + Columns: cols, + ColumnToType: map[string]typing.KindDetails{"id": typing.String}, + SpecialCastingRequired: false, + SoftDelete: true, + }) assert.NoError(t, err) assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) // Soft deletion flag being passed. @@ -60,8 +70,16 @@ func TestMergeStatement(t *testing.T) { // select cc.foo, cc.bar from (values (12, 34), (44, 55)) as cc(foo, bar); subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - - mergeSQL, err := MergeStatement(fqTable, subQuery, "id", "", cols, false, false) + mergeSQL, err := MergeStatement(MergeArgument{ + FqTableName: fqTable, + SubQuery: subQuery, + IdempotentKey: "", + PrimaryKeys: []string{"id"}, + Columns: cols, + ColumnToType: map[string]typing.KindDetails{"id": typing.String}, + SpecialCastingRequired: false, + SoftDelete: false, + }) assert.NoError(t, err) assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) assert.False(t, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) @@ -86,8 +104,54 @@ func TestMergeStatementIdempotentKey(t *testing.T) { subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - mergeSQL, err := MergeStatement(fqTable, subQuery, "id", "updated_at", cols, false, false) + mergeSQL, err := MergeStatement(MergeArgument{ + FqTableName: fqTable, + SubQuery: subQuery, + IdempotentKey: "updated_at", + PrimaryKeys: []string{"id"}, + Columns: cols, + ColumnToType: map[string]typing.KindDetails{"id": typing.String}, + SpecialCastingRequired: false, + SoftDelete: false, + }) + assert.NoError(t, err) + assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) + assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) +} + +func TestMergeStatementCompositeKey(t *testing.T) { + fqTable := "database.schema.table" + cols := []string{ + "id", + "another_id", + "bar", + "updated_at", + constants.DeleteColumnMarker, + } + + tableValues := []string{ + fmt.Sprintf("('%s', '%s', '%s', '%v', false)", "1", "3", "456", time.Now().Round(0).UTC()), + fmt.Sprintf("('%s', '%s', '%s', '%v', false)", "2", "2", "bb", time.Now().Round(0).UTC()), + fmt.Sprintf("('%s', '%s', '%s', '%v', false)", "3", "1", "dd", time.Now().Round(0).UTC()), + } + + // select cc.foo, cc.bar from (values (12, 34), (44, 55)) as cc(foo, bar); + subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", + strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) + + mergeSQL, err := MergeStatement(MergeArgument{ + FqTableName: fqTable, + SubQuery: subQuery, + IdempotentKey: "updated_at", + PrimaryKeys: []string{"id", "another_id"}, + Columns: cols, + ColumnToType: map[string]typing.KindDetails{"id": typing.String, "another_id": typing.String}, + SpecialCastingRequired: false, + SoftDelete: false, + }) assert.NoError(t, err) assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable)), mergeSQL) assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at")), fmt.Sprintf("Idempotency key: %s", mergeSQL)) + + assert.True(t, strings.Contains(mergeSQL, fmt.Sprintf("cc on c.id = cc.id and c.another_id = cc.another_id"))) } diff --git a/lib/optimization/event.go b/lib/optimization/event.go index 154bcc79c..18de346d6 100644 --- a/lib/optimization/event.go +++ b/lib/optimization/event.go @@ -13,7 +13,7 @@ import ( type TableData struct { InMemoryColumns map[string]typing.KindDetails // list of columns RowsData map[string]map[string]interface{} // pk -> { col -> val } - PrimaryKey string + PrimaryKeys []string kafkalib.TopicConfig // Partition to the latest offset(s). diff --git a/models/event.go b/models/event.go index efb7d9b18..4aa8f4647 100644 --- a/models/event.go +++ b/models/event.go @@ -2,7 +2,9 @@ package models import ( "context" + "fmt" "github.com/artie-labs/transfer/lib/config/constants" + "sort" "time" "github.com/artie-labs/transfer/lib/array" @@ -11,30 +13,32 @@ import ( ) type Event struct { - Table string - PrimaryKeyName string - PrimaryKeyValue interface{} - Data map[string]interface{} // json serialized column data - ExecutionTime time.Time // When the SQL command was executed + Table string + PrimaryKeyMap map[string]interface{} + Data map[string]interface{} // json serialized column data + ExecutionTime time.Time // When the SQL command was executed } -func ToMemoryEvent(ctx context.Context, event cdc.Event, pkName string, pkValue interface{}, tc *kafkalib.TopicConfig) Event { +func ToMemoryEvent(ctx context.Context, event cdc.Event, pkMap map[string]interface{}, tc *kafkalib.TopicConfig) Event { return Event{ - Table: tc.TableName, - PrimaryKeyName: pkName, - PrimaryKeyValue: pkValue, - ExecutionTime: event.GetExecutionTime(), - Data: event.GetData(ctx, pkName, pkValue, tc), + Table: tc.TableName, + PrimaryKeyMap: pkMap, + ExecutionTime: event.GetExecutionTime(), + Data: event.GetData(ctx, pkMap, tc), } } func (e *Event) IsValid() bool { // Does it have a PK or table set? - if array.Empty([]string{e.Table, e.PrimaryKeyName}) { + if array.Empty([]string{e.Table}) { return false } - if e.PrimaryKeyValue == nil { + if len(e.PrimaryKeyMap) == 0 { + return false + } + + if len(e.Data) == 0 { return false } @@ -46,3 +50,26 @@ func (e *Event) IsValid() bool { return true } + +// PrimaryKeys is returned in a sorted manner to be safe. +// We use PrimaryKeyValue() as our internal identifier within our db +// It is critical to make sure `PrimaryKeyValue()` is a deterministic call. +func (e *Event) PrimaryKeys() []string { + var keys []string + for key := range e.PrimaryKeyMap { + keys = append(keys, key) + } + + sort.Strings(keys) + return keys +} + +// PrimaryKeyValue - as per above, this needs to return a deterministic k/v string. +func (e *Event) PrimaryKeyValue() string { + var key string + for _, pk := range e.PrimaryKeys() { + key += fmt.Sprintf("%s=%v", pk, e.PrimaryKeyMap[pk]) + } + + return key +} diff --git a/models/event_test.go b/models/event_test.go index fad4e1ef9..49882b232 100644 --- a/models/event_test.go +++ b/models/event_test.go @@ -10,11 +10,15 @@ import ( type fakeEvent struct{} +var idMap = map[string]interface{}{ + "id": 123, +} + func (f fakeEvent) GetExecutionTime() time.Time { return time.Now() } -func (f fakeEvent) GetData(ctx context.Context, pkName string, pkVal interface{}, config *kafkalib.TopicConfig) map[string]interface{} { +func (f fakeEvent) GetData(ctx context.Context, pkMap map[string]interface{}, config *kafkalib.TopicConfig) map[string]interface{} { return map[string]interface{}{constants.DeleteColumnMarker: false} } @@ -25,10 +29,7 @@ func (m *ModelsTestSuite) TestEvent_IsValid() { e.Table = "foo" assert.False(m.T(), e.IsValid()) - e.PrimaryKeyName = "id" - assert.False(m.T(), e.IsValid()) - - e.PrimaryKeyValue = 123 + e.PrimaryKeyMap = idMap assert.False(m.T(), e.IsValid()) e.Data = make(map[string]interface{}) @@ -38,9 +39,78 @@ func (m *ModelsTestSuite) TestEvent_IsValid() { func (m *ModelsTestSuite) TestEvent_TableName() { var f fakeEvent - evt := ToMemoryEvent(context.Background(), f, "id", "123", &kafkalib.TopicConfig{ + evt := ToMemoryEvent(context.Background(), f, idMap, &kafkalib.TopicConfig{ TableName: "orders", }) assert.Equal(m.T(), "orders", evt.Table) } + +func (m *ModelsTestSuite) TestEventPrimaryKeys() { + evt := &Event{ + Table: "foo", + PrimaryKeyMap: map[string]interface{}{ + "id": true, + "id1": true, + "id2": true, + "id3": true, + "id4": true, + }, + } + + requiredKeys := []string{"id", "id1", "id2", "id3", "id4"} + for _, requiredKey := range requiredKeys { + var found bool + for _, primaryKey := range evt.PrimaryKeys() { + found = requiredKey == primaryKey + if found { + break + } + } + + assert.True(m.T(), found, requiredKey) + } + + anotherEvt := &Event{ + Table: "foo", + PrimaryKeyMap: map[string]interface{}{ + "id": 1, + "course_id": 2, + }, + } + + var found bool + possibilities := []string{"course_id=2id=1"} + pkVal := anotherEvt.PrimaryKeyValue() + for _, possibility := range possibilities { + if found = possibility == pkVal; found { + break + } + } + + assert.True(m.T(), found, anotherEvt.PrimaryKeyValue()) + + // Make sure the ordering for the pk is deterministic. + partsMap := make(map[string]bool) + for i := 0; i < 100; i++ { + partsMap[anotherEvt.PrimaryKeyValue()] = true + } + + assert.Equal(m.T(), len(partsMap), 1) +} + +func (m *ModelsTestSuite) TestPrimaryKeyValueDeterministic() { + evt := &Event{ + PrimaryKeyMap: map[string]interface{}{ + "aa": 1, + "bb": 5, + "zz": "ff", + "gg": "artie", + "dusty": "mini aussie", + }, + } + + for i := 0; i < 500*1000; i++ { + assert.Equal(m.T(), evt.PrimaryKeyValue(), "aa=1bb=5dusty=mini aussiegg=artiezz=ff") + } +} diff --git a/models/flush/flush_test.go b/models/flush/flush_test.go index 4905fb67e..19948434c 100644 --- a/models/flush/flush_test.go +++ b/models/flush/flush_test.go @@ -23,9 +23,10 @@ var topicConfig = &kafkalib.TopicConfig{ func (f *FlushTestSuite) TestMemoryBasic() { for i := 0; i < 5; i++ { event := models.Event{ - Table: "foo", - PrimaryKeyName: "id", - PrimaryKeyValue: fmt.Sprintf("pk-%d", i), + Table: "foo", + PrimaryKeyMap: map[string]interface{}{ + "id": fmt.Sprintf("pk-%d", i), + }, Data: map[string]interface{}{ constants.DeleteColumnMarker: true, "abc": "def", @@ -46,9 +47,10 @@ func (f *FlushTestSuite) TestShouldFlush() { for i := 0; i < int(float64(cfg.Config.BufferRows)*1.5); i++ { event := models.Event{ - Table: "postgres", - PrimaryKeyName: "id", - PrimaryKeyValue: fmt.Sprintf("pk-%d", i), + Table: "postgres", + PrimaryKeyMap: map[string]interface{}{ + "id": fmt.Sprintf("pk-%d", i), + }, Data: map[string]interface{}{ constants.DeleteColumnMarker: true, "pk": fmt.Sprintf("pk-%d", i), @@ -81,10 +83,12 @@ func (f *FlushTestSuite) TestMemoryConcurrency() { defer wg.Done() for i := 0; i < 5; i++ { event := models.Event{ - Table: tableName, - PrimaryKeyName: "id", - PrimaryKeyValue: fmt.Sprintf("pk-%d", i), + Table: tableName, + PrimaryKeyMap: map[string]interface{}{ + "id": fmt.Sprintf("pk-%d", i), + }, Data: map[string]interface{}{ + "id": fmt.Sprintf("pk-%d", i), constants.DeleteColumnMarker: true, "pk": fmt.Sprintf("pk-%d", i), "foo": "bar", diff --git a/models/memory.go b/models/memory.go index 696e38ab6..323dba67b 100644 --- a/models/memory.go +++ b/models/memory.go @@ -3,7 +3,6 @@ package models import ( "context" "errors" - "fmt" "github.com/artie-labs/transfer/lib/artie" "github.com/artie-labs/transfer/lib/config" "github.com/artie-labs/transfer/lib/kafkalib" @@ -54,14 +53,14 @@ func (e *Event) Save(ctx context.Context, topicConfig *kafkalib.TopicConfig, mes inMemoryDB.TableData[e.Table] = &optimization.TableData{ RowsData: map[string]map[string]interface{}{}, InMemoryColumns: map[string]typing.KindDetails{}, - PrimaryKey: e.PrimaryKeyName, + PrimaryKeys: e.PrimaryKeys(), TopicConfig: *topicConfig, PartitionsToLastMessage: map[string][]artie.Message{}, } } // Update the key, offset and TS - inMemoryDB.TableData[e.Table].RowsData[fmt.Sprint(e.PrimaryKeyValue)] = e.Data + inMemoryDB.TableData[e.Table].RowsData[e.PrimaryKeyValue()] = e.Data // If the message is Kafka, then we only need the latest one // If it's pubsub, we will store all of them in memory. This is because GCP pub/sub REQUIRES us to ack every single message @@ -76,10 +75,12 @@ func (e *Event) Save(ctx context.Context, topicConfig *kafkalib.TopicConfig, mes // Increment row count inMemoryDB.TableData[e.Table].Rows += 1 + // TODO: Test. // Update col if necessary + sanitizedData := make(map[string]interface{}) for col, val := range e.Data { // columns need to all be normalized and lower cased. - col = strings.ToLower(col) + newColName := strings.ToLower(col) // Columns here could contain spaces. Every destination treats spaces in a column differently. // So far, Snowflake accepts them when escaped properly, however BigQuery does not accept it. @@ -89,7 +90,7 @@ func (e *Event) Save(ctx context.Context, topicConfig *kafkalib.TopicConfig, mes containsSpace, col = stringutil.EscapeSpaces(col) if containsSpace { // Write the message back if the column has changed. - e.Data[col] = val + sanitizedData[newColName] = val } if val == "__debezium_unavailable_value" { @@ -97,8 +98,6 @@ func (e *Event) Save(ctx context.Context, topicConfig *kafkalib.TopicConfig, mes // TL;DR - Sometimes a column that is unchanged within a DML will not be emitted // DBZ has stubbed it out by providing this value, so we will skip it when we see it. // See: https://issues.redhat.com/browse/DBZ-4276 - delete(e.Data, col) - // We are directly adding this column to our in-memory database // This ensures that this column exists, we just have an invalid value (so we will not replicate over). // However, this will ensure that we do not drop the column within the destination @@ -120,8 +119,13 @@ func (e *Event) Save(ctx context.Context, topicConfig *kafkalib.TopicConfig, mes } } } + + sanitizedData[newColName] = val } + // Swap out sanitizedData <> data. + e.Data = sanitizedData + settings := config.FromContext(ctx) return inMemoryDB.TableData[e.Table].Rows > settings.Config.BufferRows, nil } diff --git a/models/memory_test.go b/models/memory_test.go index 4827e6218..511c16d59 100644 --- a/models/memory_test.go +++ b/models/memory_test.go @@ -23,8 +23,10 @@ func (m *ModelsTestSuite) SaveEvent() { anotherLowerCol := "dusty the mini aussie" event := Event{ - Table: "foo", - PrimaryKeyValue: "123", + Table: "foo", + PrimaryKeyMap: map[string]interface{}{ + "id": "123", + }, Data: map[string]interface{}{ constants.DeleteColumnMarker: true, expectedCol: "dusty", @@ -53,8 +55,10 @@ func (m *ModelsTestSuite) SaveEvent() { badColumn := "other" edgeCaseEvent := Event{ - Table: "foo", - PrimaryKeyValue: "12344", + Table: "foo", + PrimaryKeyMap: map[string]interface{}{ + "id": "12344", + }, Data: map[string]interface{}{ constants.DeleteColumnMarker: true, expectedCol: "dusty", @@ -69,4 +73,32 @@ func (m *ModelsTestSuite) SaveEvent() { val, isOk := GetMemoryDB().TableData["foo"].InMemoryColumns[badColumn] assert.True(m.T(), isOk) assert.Equal(m.T(), val, typing.Invalid) + + assert.False(m.T(), true) } + +//func (m *ModelsTestSuite) TestEvent_SaveCasing() { +// assert.True(m.T(), false) +// +// event := Event{ +// Table: "foo", +// PrimaryKeyMap: map[string]interface{}{ +// "id": "123", +// }, +// Data: map[string]interface{}{ +// constants.DeleteColumnMarker: true, +// "randomCol": "dusty", +// "anotherCOL": 13.37, +// }, +// } +// +// kafkaMsg := kafka.Message{} +// _, err := event.Save(m.ctx, topicConfig, artie.NewMessage(&kafkaMsg, nil, kafkaMsg.Topic)) +// +// fmt.Println("inMemoryDB", inMemoryDB.TableData["foo"].RowsData) +// fmt.Println("here", event.Data) +// +// assert.False(m.T(), true) +// +// assert.Nil(m.T(), err) +//} diff --git a/processes/consumer/process.go b/processes/consumer/process.go index 097875c31..7d1d400b8 100644 --- a/processes/consumer/process.go +++ b/processes/consumer/process.go @@ -30,7 +30,7 @@ func processMessage(ctx context.Context, msg artie.Message, topicToConfigFmtMap tags["schema"] = topicConfig.tc.Schema tags["table"] = topicConfig.tc.TableName - pkName, pkValue, err := topicConfig.GetPrimaryKey(ctx, msg.Key(), topicConfig.tc) + pkMap, err := topicConfig.GetPrimaryKey(ctx, msg.Key(), topicConfig.tc) if err != nil { tags["what"] = "marshall_pk_err" return false, fmt.Errorf("cannot unmarshall key, key: %s, err: %v", string(msg.Key()), err) @@ -42,7 +42,7 @@ func processMessage(ctx context.Context, msg artie.Message, topicToConfigFmtMap return false, fmt.Errorf("cannot unmarshall event, err: %v", err) } - evt := models.ToMemoryEvent(ctx, event, pkName, pkValue, topicConfig.tc) + evt := models.ToMemoryEvent(ctx, event, pkMap, topicConfig.tc) shouldFlush, err = evt.Save(ctx, topicConfig.tc, msg) if err != nil { tags["what"] = "save_fail" diff --git a/processes/consumer/process_test.go b/processes/consumer/process_test.go index ccbc3dc7c..73d968d87 100644 --- a/processes/consumer/process_test.go +++ b/processes/consumer/process_test.go @@ -124,12 +124,12 @@ func TestProcessMessageFailures(t *testing.T) { } // Tombstone means deletion - val, isOk := memoryDB.TableData[table].RowsData["1"][constants.DeleteColumnMarker] + val, isOk := memoryDB.TableData[table].RowsData["id=1"][constants.DeleteColumnMarker] assert.True(t, isOk) assert.True(t, val.(bool)) // Non tombstone = no delete. - val, isOk = memoryDB.TableData[table].RowsData["2"][constants.DeleteColumnMarker] + val, isOk = memoryDB.TableData[table].RowsData["id=2"][constants.DeleteColumnMarker] assert.True(t, isOk) assert.False(t, val.(bool))