diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..8e80fe01 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @artie-labs/engineering diff --git a/.github/workflows/gha-go-test.yaml b/.github/workflows/gha-go-test.yaml index 7672158d..e52f5a40 100644 --- a/.github/workflows/gha-go-test.yaml +++ b/.github/workflows/gha-go-test.yaml @@ -10,6 +10,8 @@ jobs: uses: actions/setup-go@v5 with: go-version: 1.23 + - name: Run vet + run: make vet - name: Run staticcheck env: SC_VERSION: "2024.1" diff --git a/.goreleaser.yaml b/.goreleaser.yaml index d02e2f46..676256d6 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -60,7 +60,7 @@ archives: checksum: name_template: 'checksums.txt' snapshot: - name_template: "{{ incpatch .Version }}-next" + version_template: "{{ incpatch .Version }}-next" changelog: sort: asc filters: diff --git a/Makefile b/Makefile index d5599626..78c99e80 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,7 @@ +.PHONY: vet +vet: + go vet ./... + .PHONY: static static: staticcheck ./... @@ -39,3 +43,9 @@ generate: go get github.com/maxbrunsfeld/counterfeiter/v6 go generate ./... go mod tidy + +.PHONY: upgrade +upgrade: + go get github.com/artie-labs/transfer + go mod tidy + echo "Upgrade complete" diff --git a/config/mssql.go b/config/mssql.go index e39e2539..aa88ef5d 100644 --- a/config/mssql.go +++ b/config/mssql.go @@ -29,6 +29,8 @@ type MSSQLTable struct { OptionalPrimaryKeyValStart string `yaml:"optionalPrimaryKeyValStart,omitempty"` OptionalPrimaryKeyValEnd string `yaml:"optionalPrimaryKeyValEnd,omitempty"` ExcludeColumns []string `yaml:"excludeColumns,omitempty"` + // IncludeColumns - List of columns that should be included in the change event record. 
+ IncludeColumns []string `yaml:"includeColumns,omitempty"` } func (m *MSSQL) ToDSN() string { @@ -95,6 +97,11 @@ func (m *MSSQL) Validate() error { if stringutil.Empty(table.Name, table.Schema) { return fmt.Errorf("table name and schema must be passed in") } + + // You should not be able to filter and exclude columns at the same time + if len(table.ExcludeColumns) > 0 && len(table.IncludeColumns) > 0 { + return fmt.Errorf("cannot exclude and include columns at the same time") + } } return nil diff --git a/config/mssql_test.go b/config/mssql_test.go index 54decdfa..3704169e 100644 --- a/config/mssql_test.go +++ b/config/mssql_test.go @@ -129,6 +129,26 @@ func TestMSSQL_Validate(t *testing.T) { assert.ErrorContains(t, m.Validate(), "table name and schema must be passed in") } + { + // Exclude and include columns at the same time + m := &MSSQL{ + Host: "host", + Port: 1, + Username: "username", + Password: "password", + Database: "database", + Tables: []*MSSQLTable{ + { + Name: "name", + Schema: "schema", + IncludeColumns: []string{"foo"}, + ExcludeColumns: []string{"bar"}, + }, + }, + } + + assert.ErrorContains(t, m.Validate(), "cannot exclude and include columns at the same time") + } { // Valid m := &MSSQL{ diff --git a/config/mysql.go b/config/mysql.go index 11f67b31..fede9503 100644 --- a/config/mysql.go +++ b/config/mysql.go @@ -39,6 +39,8 @@ type MySQLTable struct { OptionalPrimaryKeyValStart string `yaml:"optionalPrimaryKeyValStart,omitempty"` OptionalPrimaryKeyValEnd string `yaml:"optionalPrimaryKeyValEnd,omitempty"` ExcludeColumns []string `yaml:"excludeColumns,omitempty"` + // IncludeColumns - List of columns that should be included in the change event record. 
+ IncludeColumns []string `yaml:"includeColumns,omitempty"` } func (m *MySQLTable) GetBatchSize() uint { @@ -91,6 +93,11 @@ func (m *MySQL) Validate() error { if table.Name == "" { return fmt.Errorf("table name must be passed in") } + + // You should not be able to filter and exclude columns at the same time + if len(table.ExcludeColumns) > 0 && len(table.IncludeColumns) > 0 { + return fmt.Errorf("cannot exclude and include columns at the same time") + } } return nil diff --git a/config/mysql_test.go b/config/mysql_test.go index 7bcce090..75372b5c 100644 --- a/config/mysql_test.go +++ b/config/mysql_test.go @@ -96,6 +96,17 @@ func TestMySQL_Validate(t *testing.T) { c.Tables = append(c.Tables, &MySQLTable{}) assert.ErrorContains(t, c.Validate(), "table name must be passed in") } + { + // exclude and include at the same time + c := createValidConfig() + c.Tables = append(c.Tables, &MySQLTable{ + Name: "foo", + IncludeColumns: []string{"foo"}, + ExcludeColumns: []string{"bar"}, + }) + + assert.ErrorContains(t, c.Validate(), "cannot exclude and include columns at the same time") + } } func TestMySQL_ToDSN(t *testing.T) { diff --git a/config/postgres.go b/config/postgres.go index 6ec0f470..f752d7c0 100644 --- a/config/postgres.go +++ b/config/postgres.go @@ -39,9 +39,12 @@ type PostgreSQLTable struct { // Optional settings BatchSize uint `yaml:"batchSize,omitempty"` + PrimaryKeysOverride []string `yaml:"primaryKeysOverride,omitempty"` OptionalPrimaryKeyValStart string `yaml:"optionalPrimaryKeyValStart,omitempty"` OptionalPrimaryKeyValEnd string `yaml:"optionalPrimaryKeyValEnd,omitempty"` ExcludeColumns []string `yaml:"excludeColumns,omitempty"` + // IncludeColumns - List of columns that should be included in the change event record. 
+ IncludeColumns []string `yaml:"includeColumns,omitempty"` } func (p *PostgreSQLTable) GetBatchSize() uint { @@ -98,6 +101,11 @@ func (p *PostgreSQL) Validate() error { if table.Schema == "" { return fmt.Errorf("schema must be passed in") } + + // You should not be able to filter and exclude columns at the same time + if len(table.ExcludeColumns) > 0 && len(table.IncludeColumns) > 0 { + return fmt.Errorf("cannot exclude and include columns at the same time") + } } return nil diff --git a/config/postgres_test.go b/config/postgres_test.go index 55eeacc5..e5c42fce 100644 --- a/config/postgres_test.go +++ b/config/postgres_test.go @@ -117,6 +117,26 @@ func TestPostgreSQL_Validate(t *testing.T) { assert.ErrorContains(t, p.Validate(), "schema must be passed in") } + { + // Filtering and excluding at the same time + p := &PostgreSQL{ + Host: "host", + Port: 1, + Username: "username", + Password: "password", + Database: "database", + Tables: []*PostgreSQLTable{ + { + Name: "name", + Schema: "schema", + ExcludeColumns: []string{"a"}, + IncludeColumns: []string{"b"}, + }, + }, + } + + assert.ErrorContains(t, p.Validate(), "cannot exclude and include columns at the same time") + } { // Valid p := &PostgreSQL{ diff --git a/go.mod b/go.mod index 1ed1e8ee..7dabe01a 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/DataDog/datadog-go/v5 v5.5.0 - github.com/artie-labs/transfer v1.26.8 + github.com/artie-labs/transfer v1.27.11 github.com/aws/aws-sdk-go-v2 v1.30.3 github.com/aws/aws-sdk-go-v2/config v1.27.27 github.com/aws/aws-sdk-go-v2/credentials v1.17.27 diff --git a/go.sum b/go.sum index 17dbc8d5..5a7ac86b 100644 --- a/go.sum +++ b/go.sum @@ -93,8 +93,8 @@ github.com/apache/thrift v0.0.0-20181112125854-24918abba929/go.mod h1:cp2SuWMxlE github.com/apache/thrift v0.14.2/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/apache/thrift v0.17.0 h1:cMd2aj52n+8VoAtvSvLn4kDC3aZ6IAkBuqWQ2IDu7wo= github.com/apache/thrift v0.17.0/go.mod 
h1:OLxhMRJxomX+1I/KUw03qoV3mMz16BwaKI+d4fPBx7Q= -github.com/artie-labs/transfer v1.26.8 h1:aNhd4f3KwHOl0NsCwS1c4SJfU+CGveleqQgMgCAZG/0= -github.com/artie-labs/transfer v1.26.8/go.mod h1:BlYxzzlXGHOMNSgbpcjzw1zQSD/wXmb93NoPBhOmcqA= +github.com/artie-labs/transfer v1.27.11 h1:2J5kV/q2RmB7PUJrqIMx+gL9Rp5zJS6Ey72koqSuEpk= +github.com/artie-labs/transfer v1.27.11/go.mod h1:+a/UhlQVRIpdz3muS1yhSvyX42RQL0LHOdovGZfEsDE= github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= github.com/aws/aws-sdk-go-v2 v1.16.12/go.mod h1:C+Ym0ag2LIghJbXhfXZ0YEEp49rBWowxKzJLUoob0ts= github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= diff --git a/integration_tests/mongo/main.go b/integration_tests/mongo/main.go index 8581705a..ed78eac0 100644 --- a/integration_tests/mongo/main.go +++ b/integration_tests/mongo/main.go @@ -4,6 +4,12 @@ import ( "context" "encoding/json" "fmt" + "log/slog" + "math/rand/v2" + "os" + "reflect" + "time" + "github.com/artie-labs/reader/config" "github.com/artie-labs/reader/integration_tests/utils" "github.com/artie-labs/reader/lib" @@ -16,11 +22,6 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" - "log/slog" - "math/rand/v2" - "os" - "reflect" - "time" ) func main() { @@ -96,25 +97,25 @@ func testTypes(ctx context.Context, db *mongo.Database, mongoCfg config.MongoDB) ts := time.Date(2020, 10, 5, 12, 0, 0, 0, time.UTC) _, err = collection.InsertOne(ctx, bson.D{ - {"_id", objId}, - {"string", "This is a string"}, - {"int32", int32(32)}, - {"int64", int64(64)}, - {"double", 3.14}, - {"bool", true}, - {"datetime", ts}, - {"embeddedDocument", bson.D{ - {"field1", "value1"}, - {"field2", "value2"}, + {Key: "_id", Value: objId}, + {Key: "string", Value: "This is a string"}, + {Key: "int32", Value: int32(32)}, + {Key: "int64", Value: int64(64)}, + {Key: "double", Value: 3.14}, + {Key: "bool", Value: true}, + {Key: 
"datetime", Value: ts}, + {Key: "embeddedDocument", Value: bson.D{ + {Key: "field1", Value: "value1"}, + {Key: "field2", Value: "value2"}, }}, - {"embeddedMap", bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}}, - {"array", bson.A{"item1", 2, true, 3.14}}, - {"binary", []byte("binary data")}, - {"objectId", objId}, - {"null", nil}, - {"timestamp", primitive.Timestamp{T: uint32(ts.Unix()), I: 1}}, - {"minKey", primitive.MinKey{}}, - {"maxKey", primitive.MaxKey{}}, + {Key: "embeddedMap", Value: bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}}, + {Key: "array", Value: bson.A{"item1", 2, true, 3.14}}, + {Key: "binary", Value: []byte("binary data")}, + {Key: "objectId", Value: objId}, + {Key: "null", Value: nil}, + {Key: "timestamp", Value: primitive.Timestamp{T: uint32(ts.Unix()), I: 1}}, + {Key: "minKey", Value: primitive.MinKey{}}, + {Key: "maxKey", Value: primitive.MaxKey{}}, }) if err != nil { return fmt.Errorf("failed to insert row: %w", err) @@ -169,12 +170,12 @@ func testTypes(ctx context.Context, db *mongo.Database, mongoCfg config.MongoDB) return fmt.Errorf("failed to get event from bytes: %w", err) } - pkMap, err := dbz.GetPrimaryKey(actualPkBytes, &kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) + pkMap, err := dbz.GetPrimaryKey(actualPkBytes, kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) if err != nil { return fmt.Errorf("failed to get primary key: %w", err) } - data, err := evt.GetData(pkMap, &kafkalib.TopicConfig{}) + data, err := evt.GetData(pkMap, kafkalib.TopicConfig{}) if err != nil { return fmt.Errorf("failed to get data: %w", err) } @@ -182,11 +183,11 @@ func testTypes(ctx context.Context, db *mongo.Database, mongoCfg config.MongoDB) expectedPayload := map[string]any{ "objectId": "66a95fae3776c2f21f0ff568", "array": []any{"item1", int32(2), true, 3.14}, - "datetime": ts.Format(ext.ISO8601), + "datetime": ext.NewExtendedTime(ts, ext.TimestampTzKindType, "2006-01-02T15:04:05.999-07:00"), "int64": int64(64), 
"__artie_delete": false, "__artie_only_set_delete": false, - "timestamp": ts.Format(ext.ISO8601), + "timestamp": ext.NewExtendedTime(ts, ext.TimestampTzKindType, "2006-01-02T15:04:05.999-07:00"), "embeddedDocument": `{"field1":"value1","field2":"value2"}`, "embeddedMap": `{"foo":"bar","hello":"world","pi":3.14159}`, "binary": `{"$binary":{"base64":"YmluYXJ5IGRhdGE=","subType":"00"}}`, diff --git a/lib/debezium/converters/decimal.go b/lib/debezium/converters/decimal.go index 1a460eda..e7342184 100644 --- a/lib/debezium/converters/decimal.go +++ b/lib/debezium/converters/decimal.go @@ -2,21 +2,36 @@ package converters import ( "fmt" + "github.com/artie-labs/transfer/lib/debezium" + "github.com/artie-labs/transfer/lib/debezium/converters" "github.com/artie-labs/transfer/lib/typing" "github.com/cockroachdb/apd/v3" ) -type decimalConverter struct { +// encodeDecimalWithScale is used to encode a [*apd.Decimal] to `org.apache.kafka.connect.data.Decimal` +// using a specific scale. +func encodeDecimalWithScale(decimal *apd.Decimal, scale int32) ([]byte, error) { + targetExponent := -scale // Negate scale since [Decimal.Exponent] is negative. + if decimal.Exponent != targetExponent { + // Return an error if the scales are different, this maintains parity with `org.apache.kafka.connect.data.Decimal`. 
+ // https://github.com/a0x8o/kafka/blob/54eff6af115ee647f60129f2ce6a044cb17215d0/connect/api/src/main/java/org/apache/kafka/connect/data/Decimal.java#L69 + return nil, fmt.Errorf("value scale (%d) is different from schema scale (%d)", -decimal.Exponent, scale) + } + bytes, _ := converters.EncodeDecimal(decimal) + return bytes, nil +} + +type DecimalConverter struct { scale uint16 precision *int } -func NewDecimalConverter(scale uint16, precision *int) decimalConverter { - return decimalConverter{scale: scale, precision: precision} +func NewDecimalConverter(scale uint16, precision *int) DecimalConverter { + return DecimalConverter{scale: scale, precision: precision} } -func (d decimalConverter) ToField(name string) debezium.Field { +func (d DecimalConverter) ToField(name string) debezium.Field { field := debezium.Field{ FieldName: name, Type: debezium.Bytes, @@ -33,7 +48,7 @@ func (d decimalConverter) ToField(name string) debezium.Field { return field } -func (d decimalConverter) Convert(value any) (any, error) { +func (d DecimalConverter) Convert(value any) (any, error) { stringValue, err := typing.AssertType[string](value) if err != nil { return nil, err @@ -44,7 +59,7 @@ func (d decimalConverter) Convert(value any) (any, error) { return nil, fmt.Errorf(`unable to use %q as a decimal: %w`, stringValue, err) } - return debezium.EncodeDecimalWithScale(decimal, int32(d.scale)), nil + return encodeDecimalWithScale(decimal, int32(d.scale)) } type VariableNumericConverter struct{} @@ -68,9 +83,9 @@ func (VariableNumericConverter) Convert(value any) (any, error) { return nil, fmt.Errorf(`unable to use %q as a decimal: %w`, stringValue, err) } - bytes, scale := debezium.EncodeDecimal(decimal) + bytes, scale := converters.EncodeDecimal(decimal) return map[string]any{ - "scale": int32(scale), + "scale": scale, "value": bytes, }, nil } diff --git a/lib/debezium/converters/decimal_test.go b/lib/debezium/converters/decimal_test.go index 74d266c4..8446ff1c 100644 --- 
a/lib/debezium/converters/decimal_test.go +++ b/lib/debezium/converters/decimal_test.go @@ -5,10 +5,137 @@ import ( "testing" "github.com/artie-labs/transfer/lib/debezium" - "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/debezium/converters" + "github.com/artie-labs/transfer/lib/numbers" + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/decimal" "github.com/stretchr/testify/assert" ) +func TestEncodeDecimalWithScale(t *testing.T) { + mustEncodeAndDecodeDecimal := func(value string, scale int32) string { + bytes, err := encodeDecimalWithScale(numbers.MustParseDecimal(value), scale) + assert.NoError(t, err) + return converters.DecodeDecimal(bytes, scale).String() + } + + mustReturnError := func(value string, scale int32) error { + _, err := encodeDecimalWithScale(numbers.MustParseDecimal(value), scale) + assert.Error(t, err) + return err + } + + // Whole numbers: + for i := range 100_000 { + strValue := fmt.Sprint(i) + assert.Equal(t, strValue, mustEncodeAndDecodeDecimal(strValue, 0)) + if i != 0 { + strValue := "-" + strValue + assert.Equal(t, strValue, mustEncodeAndDecodeDecimal(strValue, 0)) + } + } + + // Scale of 15 that is equal to the amount of decimal places in the value: + assert.Equal(t, "145.183000000000000", mustEncodeAndDecodeDecimal("145.183000000000000", 15)) + assert.Equal(t, "-145.183000000000000", mustEncodeAndDecodeDecimal("-145.183000000000000", 15)) + // If scale is smaller than the amount of decimal places then an error should be returned: + assert.ErrorContains(t, mustReturnError("145.183000000000000", 14), "value scale (15) is different from schema scale (14)") + // If scale is larger than the amount of decimal places then an error should be returned: + assert.ErrorContains(t, mustReturnError("-145.183000000000005", 16), "value scale (15) is different from schema scale (16)") + + assert.Equal(t, "-9063701308.217222135", mustEncodeAndDecodeDecimal("-9063701308.217222135", 
9)) + assert.Equal(t, "-74961544796695.89960242", mustEncodeAndDecodeDecimal("-74961544796695.89960242", 8)) + + testCases := []struct { + name string + value string + scale int32 + }{ + { + name: "0 scale", + value: "5", + }, + { + name: "2 scale", + value: "23131319.99", + scale: 2, + }, + { + name: "5 scale", + value: "9.12345", + scale: 5, + }, + { + name: "negative number", + value: "-105.2813669", + scale: 7, + }, + // Longitude #1 + { + name: "long 1", + value: "-75.765611", + scale: 6, + }, + // Latitude #1 + { + name: "lat", + value: "40.0335495", + scale: 7, + }, + // Long #2 + { + name: "long 2", + value: "-119.65575", + scale: 5, + }, + { + name: "lat 2", + value: "36.3303", + scale: 4, + }, + { + name: "long 3", + value: "-81.76254098", + scale: 8, + }, + { + name: "amount", + value: "6408.355", + scale: 3, + }, + { + name: "total", + value: "1.05", + scale: 2, + }, + { + name: "negative number: 2^16 - 255", + value: "-65281", + scale: 0, + }, + { + name: "negative number: 2^16 - 1", + value: "-65535", + scale: 0, + }, + { + name: "number with a scale of 15", + value: "0.000022998904125", + scale: 15, + }, + { + name: "number with a scale of 15", + value: "145.183000000000000", + scale: 15, + }, + } + + for _, testCase := range testCases { + actual := mustEncodeAndDecodeDecimal(testCase.value, testCase.scale) + assert.Equal(t, testCase.value, actual, testCase.name) + } +} + func TestDecimalConverter_ToField(t *testing.T) { { // Without precision @@ -25,7 +152,7 @@ func TestDecimalConverter_ToField(t *testing.T) { } { // With precision - converter := NewDecimalConverter(2, ptr.ToInt(3)) + converter := NewDecimalConverter(2, typing.ToPtr(3)) expected := debezium.Field{ Type: "bytes", FieldName: "col", @@ -57,9 +184,10 @@ func TestDecimalConverter_Convert(t *testing.T) { assert.NoError(t, err) bytes, ok := converted.([]byte) assert.True(t, ok) - actualValue, err := converter.ToField("").DecodeDecimal(bytes) + + actualValue, err := 
converter.ToField("").ParseValue(bytes) assert.NoError(t, err) - assert.Equal(t, "1.23", fmt.Sprint(actualValue)) + assert.Equal(t, "1.23", actualValue.(*decimal.Decimal).String()) } } @@ -95,8 +223,9 @@ func TestVariableNumericConverter_Convert(t *testing.T) { converted, err := converter.Convert("12.34") assert.NoError(t, err) assert.Equal(t, map[string]any{"scale": int32(2), "value": []byte{0x4, 0xd2}}, converted) - actualValue, err := converter.ToField("").DecodeDebeziumVariableDecimal(converted) + + actualValue, err := converters.NewVariableDecimal().Convert(converted) assert.NoError(t, err) - assert.Equal(t, "12.34", actualValue.String()) + assert.Equal(t, "12.34", actualValue.(*decimal.Decimal).String()) } } diff --git a/lib/debezium/converters/money.go b/lib/debezium/converters/money.go index d634010d..2c64a620 100644 --- a/lib/debezium/converters/money.go +++ b/lib/debezium/converters/money.go @@ -56,5 +56,5 @@ func (m MoneyConverter) Convert(value any) (any, error) { return nil, fmt.Errorf(`unable to use %q as a money value: %w`, valString, err) } - return debezium.EncodeDecimalWithScale(decimal, int32(m.Scale())), nil + return encodeDecimalWithScale(decimal, int32(m.Scale())) } diff --git a/lib/debezium/converters/money_test.go b/lib/debezium/converters/money_test.go index 6e6d6e6c..a5c0a14f 100644 --- a/lib/debezium/converters/money_test.go +++ b/lib/debezium/converters/money_test.go @@ -1,10 +1,11 @@ package converters import ( + "github.com/artie-labs/transfer/lib/typing" "testing" - "github.com/artie-labs/reader/lib/ptr" transferDbz "github.com/artie-labs/transfer/lib/debezium" + "github.com/artie-labs/transfer/lib/typing/decimal" "github.com/stretchr/testify/assert" ) @@ -17,7 +18,7 @@ func TestMoney_Scale(t *testing.T) { { // Specified converter := MoneyConverter{ - ScaleOverride: ptr.ToUint16(3), + ScaleOverride: typing.ToPtr(uint16(3)), } assert.Equal(t, uint16(3), converter.Scale()) } @@ -41,9 +42,12 @@ func TestMoneyConverter_Convert(t 
*testing.T) { decodeValue := func(value any) string { bytes, ok := value.([]byte) assert.True(t, ok) - val, err := decimalField.DecodeDecimal(bytes) + + valueConverter, err := decimalField.ToValueConverter() + assert.NoError(t, err) + val, err := valueConverter.Convert(bytes) assert.NoError(t, err) - return val.String() + return val.(*decimal.Decimal).String() } { // Converter where mutateString is true @@ -66,11 +70,9 @@ func TestMoneyConverter_Convert(t *testing.T) { assert.Equal(t, "1234.56", decodeValue(converted)) } { - // string with $, comma, and no cents - converted, err := converter.Convert("$1000,234") - assert.NoError(t, err) - assert.Equal(t, []byte{0x5, 0xf6, 0x3c, 0x68}, converted) - assert.Equal(t, "1000234.00", decodeValue(converted)) + // string with missing cents + _, err := converter.Convert("$1000,234") + assert.ErrorContains(t, err, "value scale (0) is different from schema scale (2)") } { // Malformed string - empty string. diff --git a/lib/debezium/converters/time.go b/lib/debezium/converters/time.go index b8fbe41c..ba63011b 100644 --- a/lib/debezium/converters/time.go +++ b/lib/debezium/converters/time.go @@ -171,7 +171,7 @@ func (ZonedTimestampConverter) ToField(name string) debezium.Field { return debezium.Field{ FieldName: name, Type: debezium.String, - DebeziumType: debezium.DateTimeWithTimezone, + DebeziumType: debezium.ZonedTimestamp, } } @@ -188,7 +188,10 @@ func (ZonedTimestampConverter) Convert(value any) (any, error) { return nil, nil } - return timeValue.Format(time.RFC3339Nano), nil + // A string representation of a timestamp with timezone information, where the timezone is GMT. + // This layout supports upto microsecond precision. 
+ layout := "2006-01-02T15:04:05.999999Z" + return timeValue.UTC().Format(layout), nil } type YearConverter struct{} diff --git a/lib/debezium/converters/time_test.go b/lib/debezium/converters/time_test.go index cd7db878..503370b5 100644 --- a/lib/debezium/converters/time_test.go +++ b/lib/debezium/converters/time_test.go @@ -1,32 +1,23 @@ package converters import ( - "fmt" "math" "testing" "time" - "github.com/artie-labs/transfer/lib/debezium" + "github.com/artie-labs/transfer/lib/debezium/converters" + "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/ext" "github.com/stretchr/testify/assert" ) func parseUsingTransfer(converter ValueConverter, value int64) (*ext.ExtendedTime, error) { - if transferConverter := converter.ToField("foo").ToValueConverter(); transferConverter != nil { - val, err := transferConverter.Convert(value) - if err != nil { - return nil, err - } - - extTime, isOk := val.(*ext.ExtendedTime) - if !isOk { - return nil, fmt.Errorf("expected *ext.ExtendedTime got %T", val) - } - - return extTime, nil + parsedValue, err := converter.ToField("foo").ParseValue(value) + if err != nil { + return nil, err } - return debezium.FromDebeziumTypeToTime(converter.ToField("foo").DebeziumType, value) + return typing.AssertType[*ext.ExtendedTime](parsedValue) } func TestTimeConverter_Convert(t *testing.T) { @@ -92,8 +83,8 @@ func TestMicroTimeConverter_Convert(t *testing.T) { assert.NoError(t, err) transferValue, err := parseUsingTransfer(converter, value.(int64)) assert.NoError(t, err) - assert.Equal(t, time.Date(1970, time.January, 1, 1, 2, 3, 0, time.UTC), transferValue.Time) - assert.Equal(t, ext.TimeKindType, transferValue.NestedKind.Type) + assert.Equal(t, time.Date(1970, time.January, 1, 1, 2, 3, 0, time.UTC), transferValue.GetTime()) + assert.Equal(t, ext.TimeKindType, transferValue.GetNestedKind().Type) } } @@ -202,8 +193,8 @@ func TestDateConverter_Convert(t *testing.T) { assert.NoError(t, err) transferValue, err 
:= parseUsingTransfer(converter, int64(value.(int32))) assert.NoError(t, err) - assert.Equal(t, time.Date(2023, time.May, 3, 0, 0, 0, 0, time.UTC), transferValue.Time) - assert.Equal(t, ext.DateKindType, transferValue.NestedKind.Type) + assert.Equal(t, time.Date(2023, time.May, 3, 0, 0, 0, 0, time.UTC), transferValue.GetTime()) + assert.Equal(t, ext.DateKindType, transferValue.GetNestedKind().Type) } } @@ -254,8 +245,8 @@ func TestMicroTimestampConverter_Convert(t *testing.T) { assert.NoError(t, err) transferValue, err := parseUsingTransfer(converter, value.(int64)) assert.NoError(t, err) - assert.Equal(t, timeValue, transferValue.Time) - assert.Equal(t, ext.DateTimeKindType, transferValue.NestedKind.Type) + assert.Equal(t, timeValue, transferValue.GetTime()) + assert.Equal(t, ext.TimestampTzKindType, transferValue.GetNestedKind().Type) } } @@ -295,9 +286,51 @@ func TestZonedTimestampConverter_Convert(t *testing.T) { } { // time.Time - value, err := converter.Convert(time.Date(2001, 2, 3, 4, 5, 0, 0, time.UTC)) + _ts := time.Date(2001, 2, 3, 4, 5, 0, 0, time.UTC) + value, err := converter.Convert(_ts) assert.NoError(t, err) assert.Equal(t, "2001-02-03T04:05:00Z", value) + + // Check Transfer to ensure no precision loss + ts, err := converters.ZonedTimestamp{}.Convert(value) + assert.NoError(t, err) + assert.Equal(t, _ts, ts.(*ext.ExtendedTime).GetTime()) + } + { + // time.Time (ms) + _ts := time.Date(2001, 2, 3, 4, 5, 1, 900000, time.UTC) + value, err := converter.Convert(_ts) + assert.NoError(t, err) + assert.Equal(t, "2001-02-03T04:05:01.0009Z", value) + + // Check Transfer to ensure no precision loss + ts, err := converters.ZonedTimestamp{}.Convert(value) + assert.NoError(t, err) + assert.Equal(t, _ts, ts.(*ext.ExtendedTime).GetTime()) + } + { + // time.Time (microseconds) + _ts := time.Date(2001, 2, 3, 4, 5, 1, 909000, time.UTC) + value, err := converter.Convert(_ts) + assert.NoError(t, err) + assert.Equal(t, "2001-02-03T04:05:01.000909Z", value) + + // Check 
Transfer to ensure no precision loss + ts, err := converters.ZonedTimestamp{}.Convert(value) + assert.NoError(t, err) + assert.Equal(t, _ts, ts.(*ext.ExtendedTime).GetTime()) + } + { + // Different timezone + _ts := time.Date(2001, 2, 3, 4, 5, 0, 0, time.FixedZone("CET", 1*60*60)) + value, err := converter.Convert(_ts) + assert.NoError(t, err) + assert.Equal(t, "2001-02-03T03:05:00Z", value) + + // Check Transfer to ensure no precision loss + ts, err := converters.ZonedTimestamp{}.Convert(value) + assert.NoError(t, err) + assert.Equal(t, _ts.UTC(), ts.(*ext.ExtendedTime).GetTime()) } } diff --git a/lib/mongo/change_event.go b/lib/mongo/change_event.go index 81386d96..f9eaf219 100644 --- a/lib/mongo/change_event.go +++ b/lib/mongo/change_event.go @@ -88,12 +88,17 @@ func NewChangeEvent(rawChangeEvent bson.M) (*ChangeEvent, error) { fullDocumentBeforeChange, isOk := rawChangeEvent["fullDocumentBeforeChange"] if isOk { - castedFullDocumentBeforeChange, isOk := fullDocumentBeforeChange.(bson.M) - if !isOk { - return nil, fmt.Errorf("expected fullDocumentBeforeChange to be bson.M, got: %T", fullDocumentBeforeChange) + switch castedFullDoc := fullDocumentBeforeChange.(type) { + case bson.M: + changeEvent.fullDocumentBeforeChange = &castedFullDoc + case nil: + // This may happen if the row was purged before we can read it + changeEvent.fullDocumentBeforeChange = &bson.M{ + "_id": objectID, + } + default: + return nil, fmt.Errorf("expected fullDocumentBeforeChange to be bson.M or nil, got: %T", fullDocumentBeforeChange) + } - - changeEvent.fullDocumentBeforeChange = &castedFullDocumentBeforeChange } return changeEvent, nil diff --git a/lib/mongo/message.go b/lib/mongo/message.go index b9726a2f..1e4b33bb 100644 --- a/lib/mongo/message.go +++ b/lib/mongo/message.go @@ -7,12 +7,12 @@ import ( "github.com/artie-labs/transfer/lib/cdc/mongo" "github.com/artie-labs/transfer/lib/debezium" + "github.com/artie-labs/transfer/lib/typing" "go.mongodb.org/mongo-driver/bson" 
"go.mongodb.org/mongo-driver/bson/primitive" "github.com/artie-labs/reader/config" "github.com/artie-labs/reader/lib" - "github.com/artie-labs/reader/lib/ptr" ) type Message struct { @@ -97,7 +97,7 @@ func ParseMessage(after bson.M, before *bson.M, op string) (*Message, error) { return nil, fmt.Errorf("failed to marshal document to JSON extended: %w", err) } - msg.beforeJSONExtendedString = ptr.ToPtr(string(beforeRow)) + msg.beforeJSONExtendedString = typing.ToPtr(string(beforeRow)) } return msg, nil diff --git a/lib/mongo/message_test.go b/lib/mongo/message_test.go index 78317dd5..70af4a74 100644 --- a/lib/mongo/message_test.go +++ b/lib/mongo/message_test.go @@ -3,6 +3,7 @@ package mongo import ( "encoding/json" "fmt" + "github.com/artie-labs/transfer/lib/typing/ext" "testing" "time" @@ -31,7 +32,7 @@ func TestParseMessagePartitionKey(t *testing.T) { assert.NoError(t, err) var dbz transferMongo.Debezium - pkMap, err := dbz.GetPrimaryKey(rawMsgBytes, &kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) + pkMap, err := dbz.GetPrimaryKey(rawMsgBytes, kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) assert.NoError(t, err) assert.Equal(t, "507f1f77bcf86cd799439011", pkMap["_id"]) } @@ -88,7 +89,7 @@ func TestParseMessage(t *testing.T) { assert.NoError(t, err) var dbz transferMongo.Debezium - pkMap, err := dbz.GetPrimaryKey(rawPkBytes, &kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) + pkMap, err := dbz.GetPrimaryKey(rawPkBytes, kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}) assert.NoError(t, err) rawMsgBytes, err := json.Marshal(rawMsg.Event()) @@ -105,13 +106,13 @@ func TestParseMessage(t *testing.T) { "decimal": "1234.5", "subDocument": `{"nestedString":"Nested value"}`, "array": []any{"apple", "banana", "cherry"}, - "datetime": "2024-02-13T20:37:48+00:00", + "datetime": ext.NewExtendedTime(time.Date(2024, time.February, 13, 20, 37, 48, 0, time.UTC), ext.TimestampTzKindType, "2006-01-02T15:04:05.999-07:00"), "trueValue": true, 
"falseValue": false, "nullValue": nil, } - actualKVMap, err := kvMap.GetData(pkMap, &kafkalib.TopicConfig{}) + actualKVMap, err := kvMap.GetData(pkMap, kafkalib.TopicConfig{}) assert.NoError(t, err) for expectedKey, expectedVal := range expectedMap { actualVal, isOk := actualKVMap[expectedKey] diff --git a/lib/mssql/schema/schema_test.go b/lib/mssql/schema/schema_test.go index b11e9542..7bf726e0 100644 --- a/lib/mssql/schema/schema_test.go +++ b/lib/mssql/schema/schema_test.go @@ -1,12 +1,9 @@ package schema import ( - "testing" - - ptr2 "github.com/artie-labs/reader/lib/ptr" - - "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/typing" "github.com/stretchr/testify/assert" + "testing" ) func TestParseColumnDataType(t *testing.T) { @@ -63,7 +60,7 @@ func TestParseColumnDataType(t *testing.T) { { // valid for _, colKind := range []string{"numeric", "decimal"} { - dataType, opts, err := ParseColumnDataType(colKind, ptr.ToInt(1), ptr2.ToUint16(2), nil) + dataType, opts, err := ParseColumnDataType(colKind, typing.ToPtr(1), typing.ToPtr(uint16(2)), nil) assert.NoError(t, err, colKind) assert.NotNil(t, opts, colKind) assert.Equal(t, Numeric, dataType, colKind) @@ -74,7 +71,7 @@ func TestParseColumnDataType(t *testing.T) { { // invalid, precision is missing for _, colKind := range []string{"numeric", "decimal"} { - dataType, opts, err := ParseColumnDataType(colKind, nil, ptr2.ToUint16(2), nil) + dataType, opts, err := ParseColumnDataType(colKind, nil, typing.ToPtr(uint16(2)), nil) assert.ErrorContains(t, err, "expected precision and scale to be not-nil", colKind) assert.Nil(t, opts, colKind) assert.Equal(t, -1, int(dataType), colKind) @@ -86,7 +83,7 @@ func TestParseColumnDataType(t *testing.T) { { // Default for i := 0; i <= 3; i++ { - dataType, opts, err := ParseColumnDataType("time", nil, nil, ptr.ToInt(i)) + dataType, opts, err := ParseColumnDataType("time", nil, nil, typing.ToPtr(i)) assert.NoError(t, err, i) assert.Nil(t, opts, i) 
assert.Equal(t, Time, dataType, i) @@ -95,7 +92,7 @@ func TestParseColumnDataType(t *testing.T) { { // Micro for i := 4; i <= 6; i++ { - dataType, opts, err := ParseColumnDataType("time", nil, nil, ptr.ToInt(i)) + dataType, opts, err := ParseColumnDataType("time", nil, nil, typing.ToPtr(i)) assert.NoError(t, err, i) assert.Nil(t, opts, i) assert.Equal(t, TimeMicro, dataType, i) @@ -103,7 +100,7 @@ func TestParseColumnDataType(t *testing.T) { } { // Nano - dataType, opts, err := ParseColumnDataType("time", nil, nil, ptr.ToInt(7)) + dataType, opts, err := ParseColumnDataType("time", nil, nil, typing.ToPtr(7)) assert.NoError(t, err) assert.Nil(t, opts) assert.Equal(t, TimeNano, dataType) @@ -139,7 +136,7 @@ func TestParseColumnDataType(t *testing.T) { { // Default for i := 0; i <= 3; i++ { - dataType, opts, err := ParseColumnDataType("datetime2", nil, nil, ptr.ToInt(i)) + dataType, opts, err := ParseColumnDataType("datetime2", nil, nil, typing.ToPtr(i)) assert.NoError(t, err, i) assert.Nil(t, opts, i) assert.Equal(t, Datetime2, dataType, i) @@ -148,7 +145,7 @@ func TestParseColumnDataType(t *testing.T) { { // Micro for i := 4; i <= 6; i++ { - dataType, opts, err := ParseColumnDataType("datetime2", nil, nil, ptr.ToInt(i)) + dataType, opts, err := ParseColumnDataType("datetime2", nil, nil, typing.ToPtr(i)) assert.NoError(t, err, i) assert.Nil(t, opts, i) assert.Equal(t, Datetime2Micro, dataType, i) @@ -156,7 +153,7 @@ func TestParseColumnDataType(t *testing.T) { } { // nano - dataType, opts, err := ParseColumnDataType("datetime2", nil, nil, ptr.ToInt(7)) + dataType, opts, err := ParseColumnDataType("datetime2", nil, nil, typing.ToPtr(7)) assert.NoError(t, err) assert.Nil(t, opts) assert.Equal(t, Datetime2Nano, dataType) diff --git a/lib/mysql/schema/convert.go b/lib/mysql/schema/convert.go index 95201be2..bf8d41e3 100644 --- a/lib/mysql/schema/convert.go +++ b/lib/mysql/schema/convert.go @@ -12,22 +12,34 @@ import ( const DateTimeFormat = "2006-01-02 15:04:05.999999999" 
// ConvertValue takes a value returned from the MySQL driver and converts it to a native Go type. -func ConvertValue(value any, colType DataType) (any, error) { +func ConvertValue(value any, colType DataType, opts *Opts) (any, error) { if value == nil { return nil, nil } switch colType { case Bit: + if opts == nil || opts.Size == nil { + return nil, fmt.Errorf("bit column has no size") + } + // Bits castValue, ok := value.([]byte) if !ok { return nil, fmt.Errorf("expected []byte got %T for value: %v", value, value) } - if len(castValue) != 1 || castValue[0] > 1 { - return nil, fmt.Errorf("bit value is invalid: %v", value) + + switch *opts.Size { + case 0: + return nil, fmt.Errorf("bit column has size 0, valid range is [1, 64]") + case 1: + if len(castValue) != 1 || castValue[0] > 1 { + return nil, fmt.Errorf("bit value is invalid: %v", value) + } + return castValue[0] == 1, nil + default: + return castValue, nil } - return castValue[0] == 1, nil case Boolean: castVal, ok := value.(int64) if !ok { @@ -177,7 +189,7 @@ func ConvertValues(values []any, cols []Column) error { for i, value := range values { col := cols[i] - convertedVal, err := ConvertValue(value, col.Type) + convertedVal, err := ConvertValue(value, col.Type, col.Opts) if err != nil { return fmt.Errorf("failed to convert value for column %q: %w", col.Name, err) } diff --git a/lib/mysql/schema/convert_test.go b/lib/mysql/schema/convert_test.go index 17695e77..cf50ef19 100644 --- a/lib/mysql/schema/convert_test.go +++ b/lib/mysql/schema/convert_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/artie-labs/transfer/lib/typing" "github.com/stretchr/testify/assert" ) @@ -20,6 +21,7 @@ func TestConvertValue(t *testing.T) { tests := []struct { name string dataType DataType + opts *Opts value any expected any expectedErr string @@ -34,24 +36,28 @@ func TestConvertValue(t *testing.T) { name: "bit - 0 value", dataType: Bit, value: []byte{byte(0)}, + opts: &Opts{Size: typing.ToPtr(1)}, expected: false, 
}, { name: "bit - 1 value", dataType: Bit, value: []byte{byte(1)}, + opts: &Opts{Size: typing.ToPtr(1)}, expected: true, }, { name: "bit - 2 value", dataType: Bit, value: []byte{byte(2)}, + opts: &Opts{Size: typing.ToPtr(1)}, expectedErr: "bit value is invalid", }, { name: "bit - 2 bytes", dataType: Bit, value: []byte{byte(1), byte(1)}, + opts: &Opts{Size: typing.ToPtr(1)}, expectedErr: "bit value is invalid", }, { @@ -279,7 +285,7 @@ func TestConvertValue(t *testing.T) { } for _, tc := range tests { - value, err := ConvertValue(tc.value, tc.dataType) + value, err := ConvertValue(tc.value, tc.dataType, tc.opts) if tc.expectedErr == "" { assert.NoError(t, err, tc.name) assert.Equal(t, tc.expected, value, tc.name) @@ -293,7 +299,7 @@ func TestConvertValues(t *testing.T) { columns := []Column{ {Name: "a", Type: Int}, {Name: "b", Type: Varchar}, - {Name: "c", Type: Bit}, + {Name: "c", Type: Bit, Opts: &Opts{Size: typing.ToPtr(1)}}, } { diff --git a/lib/mysql/schema/schema.go b/lib/mysql/schema/schema.go index 12ccc68e..eed96f0d 100644 --- a/lib/mysql/schema/schema.go +++ b/lib/mysql/schema/schema.go @@ -4,16 +4,14 @@ import ( "database/sql" "errors" "fmt" - ptr2 "github.com/artie-labs/reader/lib/ptr" "log/slog" "strconv" "strings" - "github.com/artie-labs/transfer/lib/ptr" - "github.com/artie-labs/reader/lib/rdbms" "github.com/artie-labs/reader/lib/rdbms/column" "github.com/artie-labs/reader/lib/rdbms/primary_key" + "github.com/artie-labs/transfer/lib/typing" ) type DataType int @@ -172,13 +170,18 @@ func parseColumnDataType(originalS string) (DataType, *Opts, error) { if err != nil { return -1, nil, fmt.Errorf("failed to parse scale value %q: %w", s, err) } - return Decimal, &Opts{Precision: ptr.ToInt(precision), Scale: ptr2.ToUint16(uint16(scale))}, nil + return Decimal, &Opts{Precision: typing.ToPtr(precision), Scale: typing.ToPtr(uint16(scale))}, nil case "float": return Float, nil, nil case "double": return Double, nil, nil case "bit": - return Bit, nil, nil + 
size, err := strconv.Atoi(metadata) + if err != nil { + return -1, nil, fmt.Errorf("failed to parse metadata value %q: %w", s, err) + } + + return Bit, &Opts{Size: typing.ToPtr(size)}, nil case "date": return Date, nil, nil case "datetime": @@ -196,12 +199,12 @@ func parseColumnDataType(originalS string) (DataType, *Opts, error) { if err != nil { return -1, nil, fmt.Errorf("failed to parse varchar size: %w", err) } - return Varchar, &Opts{Size: ptr.ToInt(size)}, nil + return Varchar, &Opts{Size: typing.ToPtr(size)}, nil case "binary": return Binary, nil, nil case "varbinary": return Varbinary, nil, nil - case "blob": + case "blob", "tinyblob", "mediumblob", "longblob": return Blob, nil, nil case "text": return Text, nil, nil diff --git a/lib/mysql/schema/schema_test.go b/lib/mysql/schema/schema_test.go index f399a205..2304a344 100644 --- a/lib/mysql/schema/schema_test.go +++ b/lib/mysql/schema/schema_test.go @@ -1,12 +1,9 @@ package schema import ( - "testing" - - ptr2 "github.com/artie-labs/reader/lib/ptr" - - "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/typing" "github.com/stretchr/testify/assert" + "testing" ) func TestQuoteIdentifier(t *testing.T) { @@ -15,80 +12,86 @@ func TestQuoteIdentifier(t *testing.T) { } func TestParseColumnDataType(t *testing.T) { - testCases := []struct { - input string - expectedType DataType - expectedOpts *Opts - expectedErr string - }{ - { - input: "int", - expectedType: Int, - }, - { - input: "tinyint(1)", - expectedType: Boolean, - }, - { - input: "varchar(255)", - expectedType: Varchar, - expectedOpts: &Opts{Size: ptr.ToInt(255)}, - }, + { + // Invalid { - input: "decimal(5,2)", - expectedType: Decimal, - expectedOpts: &Opts{ - Precision: ptr.ToInt(5), - Scale: ptr2.ToUint16(2), - }, - }, + _, _, err := parseColumnDataType("int(10 unsigned") + assert.ErrorContains(t, err, `malformed data type: "int(10 unsigned"`) + } { - input: "int(10) unsigned", - expectedType: BigInt, - expectedOpts: nil, - 
}, + _, _, err := parseColumnDataType("foo") + assert.ErrorContains(t, err, `unknown data type: "foo"`) + } { - input: "tinyint unsigned", - expectedType: SmallInt, - expectedOpts: nil, - }, + _, _, err := parseColumnDataType("varchar(") + assert.ErrorContains(t, err, `malformed data type: "varchar("`) + } + } + { + // Integers { - input: "smallint unsigned", - expectedType: Int, - expectedOpts: nil, - }, + // int + dataType, _, err := parseColumnDataType("int") + assert.NoError(t, err) + assert.Equal(t, Int, dataType) + } { - input: "mediumint unsigned", - expectedType: Int, - expectedOpts: nil, - }, + // int unsigned + dataType, _, err := parseColumnDataType("int unsigned") + assert.NoError(t, err) + assert.Equal(t, BigInt, dataType) + } { - input: "int unsigned", - expectedType: BigInt, - expectedOpts: nil, - }, + // int(10) unsigned + dataType, _, err := parseColumnDataType("int(10) unsigned") + assert.NoError(t, err) + assert.Equal(t, BigInt, dataType) + } { - input: "int(10 unsigned", - expectedErr: `malformed data type: "int(10 unsigned"`, - }, + // tinyint + dataType, _, err := parseColumnDataType("tinyint") + assert.NoError(t, err) + assert.Equal(t, TinyInt, dataType) + } { - input: "foo", - expectedErr: `unknown data type: "foo"`, - }, + // tinyint unsigned + dataType, _, err := parseColumnDataType("tinyint unsigned") + assert.NoError(t, err) + assert.Equal(t, SmallInt, dataType) + } { - input: "varchar(", - expectedErr: `malformed data type: "varchar("`, - }, + // mediumint unsigned + dataType, _, err := parseColumnDataType("mediumint unsigned") + assert.NoError(t, err) + assert.Equal(t, Int, dataType) + } } - - for _, testCase := range testCases { - colType, opts, err := parseColumnDataType(testCase.input) - if testCase.expectedErr == "" { + { + // tinyint(1) or boolean + dataType, _, err := parseColumnDataType("tinyint(1)") + assert.NoError(t, err) + assert.Equal(t, Boolean, dataType) + } + { + // String + dataType, opts, err := 
parseColumnDataType("varchar(255)") + assert.NoError(t, err) + assert.Equal(t, Varchar, dataType) + assert.Equal(t, &Opts{Size: typing.ToPtr(255)}, opts) + } + { + // Decimal + dataType, opts, err := parseColumnDataType("decimal(5,2)") + assert.NoError(t, err) + assert.Equal(t, Decimal, dataType) + assert.Equal(t, &Opts{Precision: typing.ToPtr(5), Scale: typing.ToPtr(uint16(2))}, opts) + } + { + // Blob + for _, blob := range []string{"blob", "tinyblob", "mediumblob", "longblob"} { + dataType, _, err := parseColumnDataType(blob) assert.NoError(t, err) - assert.Equal(t, testCase.expectedType, colType, testCase.input) - assert.Equal(t, testCase.expectedOpts, opts, testCase.input) - } else { - assert.ErrorContains(t, err, testCase.expectedErr, testCase.input) + assert.Equal(t, Blob, dataType, blob) } } } diff --git a/lib/postgres/schema/schema_test.go b/lib/postgres/schema/schema_test.go index b61864e9..4d38147a 100644 --- a/lib/postgres/schema/schema_test.go +++ b/lib/postgres/schema/schema_test.go @@ -1,12 +1,9 @@ package schema import ( - "testing" - - ptr2 "github.com/artie-labs/reader/lib/ptr" - - "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/typing" "github.com/stretchr/testify/assert" + "testing" ) func TestParseColumnDataType(t *testing.T) { @@ -81,8 +78,8 @@ func TestParseColumnDataType(t *testing.T) { { name: "numeric - with scale + precision", colKind: "numeric", - scale: ptr2.ToUint16(2), - precision: ptr.ToInt(3), + scale: typing.ToPtr(uint16(2)), + precision: typing.ToPtr(3), expectedDataType: Numeric, expectedOpts: &Opts{ Scale: 2, @@ -102,25 +99,25 @@ func TestParseColumnDataType(t *testing.T) { { name: "hstore", colKind: "user-defined", - udtName: ptr.ToString("hstore"), + udtName: typing.ToPtr("hstore"), expectedDataType: HStore, }, { name: "geometry", colKind: "user-defined", - udtName: ptr.ToString("geometry"), + udtName: typing.ToPtr("geometry"), expectedDataType: Geometry, }, { name: "geography", colKind: 
"user-defined", - udtName: ptr.ToString("geography"), + udtName: typing.ToPtr("geography"), expectedDataType: Geography, }, { name: "user-defined text", colKind: "user-defined", - udtName: ptr.ToString("foo"), + udtName: typing.ToPtr("foo"), expectedDataType: UserDefinedText, }, { diff --git a/lib/postgres/table.go b/lib/postgres/table.go index 91e4bbfa..ea4e2151 100644 --- a/lib/postgres/table.go +++ b/lib/postgres/table.go @@ -18,7 +18,7 @@ type Table struct { PrimaryKeys []string } -func LoadTable(db *sql.DB, _schema string, name string) (*Table, error) { +func LoadTable(db *sql.DB, _schema string, name string, primaryKeysOverride []string) (*Table, error) { tbl := &Table{ Name: name, Schema: _schema, @@ -29,8 +29,12 @@ func LoadTable(db *sql.DB, _schema string, name string) (*Table, error) { return nil, fmt.Errorf("failed to describe table %s.%s: %w", tbl.Schema, tbl.Name, err) } - if tbl.PrimaryKeys, err = schema.FetchPrimaryKeys(db, tbl.Schema, tbl.Name); err != nil { - return nil, fmt.Errorf("failed to retrieve primary keys: %w", err) + if len(primaryKeysOverride) > 0 { + tbl.PrimaryKeys = primaryKeysOverride + } else { + if tbl.PrimaryKeys, err = schema.FetchPrimaryKeys(db, tbl.Schema, tbl.Name); err != nil { + return nil, fmt.Errorf("failed to retrieve primary keys: %w", err) + } } return tbl, nil diff --git a/lib/ptr/ptr.go b/lib/ptr/ptr.go deleted file mode 100644 index 37259d63..00000000 --- a/lib/ptr/ptr.go +++ /dev/null @@ -1,9 +0,0 @@ -package ptr - -func ToUint16(val uint16) *uint16 { - return &val -} - -func ToPtr[T any](val T) *T { - return &val -} diff --git a/lib/rdbms/column/column.go b/lib/rdbms/column/column.go index 0a79405e..99e502a4 100644 --- a/lib/rdbms/column/column.go +++ b/lib/rdbms/column/column.go @@ -49,3 +49,26 @@ func FilterOutExcludedColumns[T ~int, O any](columns []Column[T, O], excludeName } return result, nil } + +// FilterForIncludedColumns returns a list of columns including only those that match `includeNames`. 
+// All primary keys must be included, else it'll return an error. +func FilterForIncludedColumns[T ~int, O any](columns []Column[T, O], includeNames []string, primaryKeys []string) ([]Column[T, O], error) { + if len(includeNames) == 0 { + return columns, nil + } + + // All primary keys must be included + for _, key := range primaryKeys { + if !slices.Contains(includeNames, key) { + return nil, fmt.Errorf("primary key column %q must be included", key) + } + } + + var result []Column[T, O] + for _, column := range columns { + if slices.Contains(includeNames, column.Name) { + result = append(result, column) + } + } + return result, nil +} diff --git a/lib/rdbms/column/column_test.go b/lib/rdbms/column/column_test.go index 405382f3..9899969b 100644 --- a/lib/rdbms/column/column_test.go +++ b/lib/rdbms/column/column_test.go @@ -158,3 +158,35 @@ func TestFilterOutExcludedColumns(t *testing.T) { assert.ErrorContains(t, err, `cannot exclude primary key column "bar"`) } } + +func TestFilterForIncludedColumns(t *testing.T) { + { + // Empty `includeNames` + value, err := FilterForIncludedColumns([]mockColumn{{Name: "foo"}}, []string{}, []string{}) + assert.NoError(t, err) + assert.Equal(t, value, []mockColumn{{Name: "foo"}}) + } + { + // Non-empty `includeNames`, included column is not in list + value, err := FilterForIncludedColumns([]mockColumn{{Name: "foo"}}, []string{"bar"}, []string{}) + assert.NoError(t, err) + assert.Equal(t, value, []mockColumn(nil)) + } + { + // Non-empty `includeNames`, included column is in list + value, err := FilterForIncludedColumns([]mockColumn{{Name: "foo"}, {Name: "bar"}}, []string{"bar"}, []string{}) + assert.NoError(t, err) + assert.Equal(t, value, []mockColumn{{Name: "bar"}}) + } + { + // Non-empty `includeNames`, included column is in list, primary key is not included + _, err := FilterForIncludedColumns([]mockColumn{{Name: "foo"}, {Name: "bar"}}, []string{"bar"}, []string{"foo"}) + assert.ErrorContains(t, err, `primary key column "foo" 
must be included`) + } + { + // Non-empty `includeNames`, included column is in list, primary key is included + value, err := FilterForIncludedColumns([]mockColumn{{Name: "foo"}, {Name: "bar"}}, []string{"foo", "bar"}, []string{"foo"}) + assert.NoError(t, err) + assert.Equal(t, value, []mockColumn{{Name: "foo"}, {Name: "bar"}}) + } +} diff --git a/lib/s3lib/s3lib_test.go b/lib/s3lib/s3lib_test.go index fe77db95..79fd7e72 100644 --- a/lib/s3lib/s3lib_test.go +++ b/lib/s3lib/s3lib_test.go @@ -1,73 +1,43 @@ package s3lib import ( - "testing" - "github.com/stretchr/testify/assert" + "testing" ) func TestBucketAndPrefixFromFilePath(t *testing.T) { - tcs := []struct { - name string - fp string - expectedBucket string - expectedPrefix string - expectedErr string - }{ - { - name: "valid path (w/ S3 prefix)", - fp: "s3://bucket/prefix", - expectedBucket: "bucket", - expectedPrefix: "prefix", - }, + { + // Invalid { - name: "valid path (w/ S3 prefix) with trailing slash", - fp: "s3://bucket/prefix/", - expectedBucket: "bucket", - expectedPrefix: "prefix/", - }, - { - name: "valid path (w/ S3 prefix) with multiple slashes", - fp: "s3://bucket/prefix/with/multiple/slashes", - expectedBucket: "bucket", - expectedPrefix: "prefix/with/multiple/slashes", - }, - // Without S3 prefix - { - name: "valid path (w/o S3 prefix)", - fp: "bucket/prefix", - expectedBucket: "bucket", - expectedPrefix: "prefix", - }, + // Empty string + bucket, prefix, err := BucketAndPrefixFromFilePath("") + assert.ErrorContains(t, err, "invalid S3 path, missing prefix") + assert.Empty(t, bucket) + assert.Empty(t, prefix) + } { - name: "valid path (w/o S3 prefix) with trailing slash", - fp: "bucket/prefix/", - expectedBucket: "bucket", - expectedPrefix: "prefix/", - }, + // Bucket only, no prefix + bucket, prefix, err := BucketAndPrefixFromFilePath("bucket") + assert.ErrorContains(t, err, "invalid S3 path, missing prefix") + assert.Empty(t, bucket) + assert.Empty(t, prefix) + } + } + { + // Valid { - name: 
"valid path (w/o S3 prefix) with multiple slashes", - fp: "bucket/prefix/with/multiple/slashes", - expectedBucket: "bucket", - expectedPrefix: "prefix/with/multiple/slashes", - }, + // No S3 prefix + bucket, prefix, err := BucketAndPrefixFromFilePath("bucket/prefix") + assert.NoError(t, err) + assert.Equal(t, "bucket", bucket) + assert.Equal(t, "prefix", prefix) + } { - name: "invalid path", - fp: "s3://bucket", - expectedErr: "invalid S3 path, missing prefix", - }, - } - - for _, tc := range tcs { - actualBucket, actualPrefix, actualErr := BucketAndPrefixFromFilePath(tc.fp) - if tc.expectedErr != "" { - assert.ErrorContains(t, actualErr, tc.expectedErr, tc.name) - } else { - assert.NoError(t, actualErr, tc.name) - - // Now check the actualBucket and prefix - assert.Equal(t, tc.expectedBucket, actualBucket, tc.name) - assert.Equal(t, tc.expectedPrefix, actualPrefix, tc.name) + // S3 prefix + bucket, prefix, err := BucketAndPrefixFromFilePath("s3://bucket/prefix") + assert.NoError(t, err) + assert.Equal(t, "bucket", bucket) + assert.Equal(t, "prefix", prefix) } } } diff --git a/sources/dynamodb/stream/shard.go b/sources/dynamodb/stream/shard.go index a12a817f..14d36010 100644 --- a/sources/dynamodb/stream/shard.go +++ b/sources/dynamodb/stream/shard.go @@ -7,7 +7,7 @@ import ( "time" "github.com/artie-labs/transfer/lib/jitter" - "github.com/artie-labs/transfer/lib/ptr" + "github.com/artie-labs/transfer/lib/typing" "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams" "github.com/aws/aws-sdk-go-v2/service/dynamodbstreams/types" @@ -77,13 +77,13 @@ func (s *Store) processShard(ctx context.Context, shard types.Shard, writer writ } iteratorInput := &dynamodbstreams.GetShardIteratorInput{ - StreamArn: ptr.ToString(s.streamArn), + StreamArn: typing.ToPtr(s.streamArn), ShardId: shard.ShardId, ShardIteratorType: iteratorType, } if startingSequenceNumber != "" { - iteratorInput.SequenceNumber = ptr.ToString(startingSequenceNumber) + iteratorInput.SequenceNumber = 
typing.ToPtr(startingSequenceNumber) } iteratorOutput, err := s.streams.GetShardIterator(ctx, iteratorInput) @@ -97,7 +97,7 @@ func (s *Store) processShard(ctx context.Context, shard types.Shard, writer writ for shardIterator != nil { getRecordsInput := &dynamodbstreams.GetRecordsInput{ ShardIterator: shardIterator, - Limit: ptr.ToInt32(1000), + Limit: typing.ToPtr(int32(1000)), } getRecordsOutput, err := s.streams.GetRecords(ctx, getRecordsInput) diff --git a/sources/mongo/streaming.go b/sources/mongo/streaming.go index 5f4fc74d..b562a418 100644 --- a/sources/mongo/streaming.go +++ b/sources/mongo/streaming.go @@ -38,9 +38,9 @@ func newStreamingIterator(ctx context.Context, db *mongo.Database, cfg config.Mo // We only care about DMLs, the full list can be found here: https://www.mongodb.com/docs/manual/reference/change-events/ pipeline := mongo.Pipeline{ - {{"$match", bson.D{ - {"operationType", bson.D{ - {"$in", bson.A{"insert", "update", "delete", "replace"}}, + {{Key: "$match", Value: bson.D{ + {Key: "operationType", Value: bson.D{ + {Key: "$in", Value: bson.A{"insert", "update", "delete", "replace"}}, }}, }}}, } diff --git a/sources/mssql/adapter/adapter.go b/sources/mssql/adapter/adapter.go index 0b848fb4..1e7cc738 100644 --- a/sources/mssql/adapter/adapter.go +++ b/sources/mssql/adapter/adapter.go @@ -9,9 +9,9 @@ import ( "github.com/artie-labs/reader/lib/debezium/transformer" "github.com/artie-labs/reader/lib/mssql" "github.com/artie-labs/reader/lib/mssql/schema" - ptr2 "github.com/artie-labs/reader/lib/ptr" "github.com/artie-labs/reader/lib/rdbms/column" "github.com/artie-labs/reader/lib/rdbms/scan" + "github.com/artie-labs/transfer/lib/typing" ) const defaultErrorRetries = 10 @@ -31,11 +31,18 @@ func NewMSSQLAdapter(db *sql.DB, dbName string, tableCfg config.MSSQLTable) (MSS return MSSQLAdapter{}, fmt.Errorf("failed to load metadata for table %s.%s: %w", tableCfg.Schema, tableCfg.Name, err) } + // Exclude columns (if any) from the table metadata columns, 
err := column.FilterOutExcludedColumns(table.Columns(), tableCfg.ExcludeColumns, table.PrimaryKeys()) if err != nil { return MSSQLAdapter{}, err } + // Include columns (if any) from the table metadata + columns, err = column.FilterForIncludedColumns(columns, tableCfg.IncludeColumns, table.PrimaryKeys()) + if err != nil { + return MSSQLAdapter{}, err + } + fieldConverters := make([]transformer.FieldConverter, len(columns)) for i, col := range columns { converter, err := valueConverterForType(col.Type, col.Opts) @@ -94,7 +101,7 @@ func valueConverterForType(dataType schema.DataType, opts *schema.Opts) (convert case schema.Money: return converters.MoneyConverter{ // MSSQL uses scale of 4 for money - ScaleOverride: ptr2.ToUint16(4), + ScaleOverride: typing.ToPtr(uint16(4)), }, nil case schema.String, schema.UniqueIdentifier: return converters.StringPassthrough{}, nil diff --git a/sources/mysql/adapter/adapter.go b/sources/mysql/adapter/adapter.go index a3ffaf69..793557ed 100644 --- a/sources/mysql/adapter/adapter.go +++ b/sources/mysql/adapter/adapter.go @@ -33,11 +33,18 @@ func NewMySQLAdapter(db *sql.DB, dbName string, tableCfg config.MySQLTable) (MyS return MySQLAdapter{}, fmt.Errorf("failed to load metadata for table %q: %w", tableCfg.Name, err) } + // Exclude columns (if any) from the table metadata columns, err := column.FilterOutExcludedColumns(table.Columns, tableCfg.ExcludeColumns, table.PrimaryKeys) if err != nil { return MySQLAdapter{}, err } + // Include columns (if any) from the table metadata + columns, err = column.FilterForIncludedColumns(columns, tableCfg.IncludeColumns, table.PrimaryKeys) + if err != nil { + return MySQLAdapter{}, err + } + return newMySQLAdapter(db, dbName, *table, columns, tableCfg.ToScannerConfig(defaultErrorRetries)) } @@ -83,7 +90,17 @@ func (m MySQLAdapter) PartitionKeys() []string { func valueConverterForType(d schema.DataType, opts *schema.Opts) (converters.ValueConverter, error) { switch d { - case schema.Bit, schema.Boolean: 
+ case schema.Bit: + if opts == nil || opts.Size == nil { + return nil, fmt.Errorf("size is required for bit type") + } + + if *opts.Size == 1 { + return converters.BooleanPassthrough{}, nil + } + + return converters.BytesPassthrough{}, nil + case schema.Boolean: return converters.BooleanPassthrough{}, nil case schema.TinyInt, schema.SmallInt: return converters.Int16Passthrough{}, nil diff --git a/sources/mysql/adapter/adapter_test.go b/sources/mysql/adapter/adapter_test.go index 5e113950..65247b63 100644 --- a/sources/mysql/adapter/adapter_test.go +++ b/sources/mysql/adapter/adapter_test.go @@ -1,12 +1,10 @@ package adapter import ( + "github.com/artie-labs/transfer/lib/typing" "testing" - ptr2 "github.com/artie-labs/reader/lib/ptr" - "github.com/artie-labs/transfer/lib/debezium" - "github.com/artie-labs/transfer/lib/ptr" "github.com/stretchr/testify/assert" "github.com/artie-labs/reader/lib/mysql" @@ -70,13 +68,27 @@ func TestValueConverterForType(t *testing.T) { expectedErr: "unable get value converter for DataType(-1)", }, { - name: "bit", + name: "bit(1)", dataType: schema.Bit, + opts: &schema.Opts{ + Size: typing.ToPtr(1), + }, expected: debezium.Field{ Type: "boolean", FieldName: colName, }, }, + { + name: "bit(5)", + dataType: schema.Bit, + opts: &schema.Opts{ + Size: typing.ToPtr(5), + }, + expected: debezium.Field{ + Type: "bytes", + FieldName: colName, + }, + }, { name: "tinyint", dataType: schema.TinyInt, @@ -137,8 +149,8 @@ func TestValueConverterForType(t *testing.T) { name: "decimal", dataType: schema.Decimal, opts: &schema.Opts{ - Scale: ptr2.ToUint16(3), - Precision: ptr.ToInt(5), + Scale: typing.ToPtr(uint16(3)), + Precision: typing.ToPtr(5), }, expected: debezium.Field{ Type: "bytes", diff --git a/sources/postgres/adapter/adapter.go b/sources/postgres/adapter/adapter.go index 7492ae86..5a3f16c5 100644 --- a/sources/postgres/adapter/adapter.go +++ b/sources/postgres/adapter/adapter.go @@ -26,16 +26,23 @@ type PostgresAdapter struct { func 
NewPostgresAdapter(db *sql.DB, tableCfg config.PostgreSQLTable) (PostgresAdapter, error) { slog.Info("Loading metadata for table") - table, err := postgres.LoadTable(db, tableCfg.Schema, tableCfg.Name) + table, err := postgres.LoadTable(db, tableCfg.Schema, tableCfg.Name, tableCfg.PrimaryKeysOverride) if err != nil { return PostgresAdapter{}, fmt.Errorf("failed to load metadata for table %s.%s: %w", tableCfg.Schema, tableCfg.Name, err) } + // Exclude columns (if any) from the table metadata columns, err := column.FilterOutExcludedColumns(table.Columns, tableCfg.ExcludeColumns, table.PrimaryKeys) if err != nil { return PostgresAdapter{}, err } + // Include columns (if any) from the table metadata + columns, err = column.FilterForIncludedColumns(columns, tableCfg.IncludeColumns, table.PrimaryKeys) + if err != nil { + return PostgresAdapter{}, err + } + fieldConverters := make([]transformer.FieldConverter, len(columns)) for i, col := range columns { converter, err := valueConverterForType(col.Type, col.Opts) diff --git a/sources/postgres/adapter/adapter_test.go b/sources/postgres/adapter/adapter_test.go index 80e5071c..6e924e90 100644 --- a/sources/postgres/adapter/adapter_test.go +++ b/sources/postgres/adapter/adapter_test.go @@ -9,6 +9,7 @@ import ( "github.com/artie-labs/reader/lib/postgres" "github.com/artie-labs/reader/lib/postgres/schema" "github.com/artie-labs/transfer/lib/debezium" + "github.com/artie-labs/transfer/lib/typing/decimal" ) func TestPostgresAdapter_TableName(t *testing.T) { @@ -262,10 +263,10 @@ func TestValueConverterForType_Convert(t *testing.T) { if tc.numericValue { bytes, ok := actualValue.([]byte) assert.True(t, ok) - field := converter.ToField(tc.col.Name) - val, err := field.DecodeDecimal(bytes) + + val, err := converter.ToField(tc.col.Name).ParseValue(bytes) assert.NoError(t, err, tc.name) - assert.Equal(t, tc.expectedValue, val.String(), tc.name) + assert.Equal(t, tc.expectedValue, val.(*decimal.Decimal).String(), tc.name) } else { 
assert.Equal(t, tc.expectedValue, actualValue, tc.name) } diff --git a/writers/transfer/writer.go b/writers/transfer/writer.go index a7b26888..717df136 100644 --- a/writers/transfer/writer.go +++ b/writers/transfer/writer.go @@ -27,7 +27,7 @@ type Writer struct { cfg config.Config statsD mtr.Client inMemDB *models.DatabaseData - tc *kafkalib.TopicConfig + tc kafkalib.TopicConfig destination destination.Baseline primaryKeys []string @@ -46,7 +46,7 @@ func NewWriter(cfg config.Config, statsD mtr.Client) (*Writer, error) { cfg: cfg, statsD: statsD, inMemDB: models.NewMemoryDB(), - tc: cfg.Kafka.TopicConfigs[0], + tc: *cfg.Kafka.TopicConfigs[0], } if utils.IsOutputBaseline(cfg) { @@ -241,7 +241,7 @@ func (w *Writer) OnComplete() error { } slog.Info("Running dedupe...", slog.String("table", tableName)) - tableID := w.destination.IdentifierFor(*w.tc, tableName) + tableID := w.destination.IdentifierFor(w.tc, tableName) start := time.Now() dwh, isOk := w.destination.(destination.DataWarehouse) diff --git a/writers/transfer/writer_test.go b/writers/transfer/writer_test.go index 0618a2c1..e09763fa 100644 --- a/writers/transfer/writer_test.go +++ b/writers/transfer/writer_test.go @@ -29,12 +29,8 @@ func TestWriter_MessageToEvent(t *testing.T) { assert.NoError(t, err) writer := Writer{ - cfg: transferCfg.Config{ - SharedTransferConfig: transferCfg.SharedTransferConfig{}, - }, - tc: &kafkalib.TopicConfig{ - CDCKeyFormat: kafkalib.JSONKeyFmt, - }, + cfg: transferCfg.Config{}, + tc: kafkalib.TopicConfig{CDCKeyFormat: kafkalib.JSONKeyFmt}, } evtOut, err := writer.messageToEvent(message) diff --git a/writers/writer.go b/writers/writer.go index f740ca99..38013990 100644 --- a/writers/writer.go +++ b/writers/writer.go @@ -55,8 +55,11 @@ func (w *Writer) Write(ctx context.Context, iter iterator.Iterator[[]lib.RawMess } } - if err := w.destinationWriter.OnComplete(); err != nil { - return 0, fmt.Errorf("failed running destination OnComplete: %w", err) + // Only run [OnComplete] if we 
wrote messages out. Otherwise, primary keys may not be loaded. + if count > 0 { + if err := w.destinationWriter.OnComplete(); err != nil { + return 0, fmt.Errorf("failed running destination OnComplete: %w", err) + } } return count, nil