diff --git a/lib/debezium/converters/decimal.go b/lib/debezium/converters/decimal.go index 1a460eda..08a960bc 100644 --- a/lib/debezium/converters/decimal.go +++ b/lib/debezium/converters/decimal.go @@ -2,11 +2,50 @@ package converters import ( "fmt" + "github.com/artie-labs/transfer/lib/debezium" "github.com/artie-labs/transfer/lib/typing" "github.com/cockroachdb/apd/v3" ) +// decimalWithNewExponent takes a [*apd.Decimal] and returns a new [*apd.Decimal] with a the given exponent. +// If the new exponent is less precise then the extra digits will be truncated. +func decimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decimal { + exponentDelta := newExponent - decimal.Exponent // Exponent is negative. + + if exponentDelta == 0 { + return new(apd.Decimal).Set(decimal) + } + + coefficient := new(apd.BigInt).Set(&decimal.Coeff) + + if exponentDelta < 0 { + multiplier := new(apd.BigInt).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(-exponentDelta)), nil) + coefficient.Mul(coefficient, multiplier) + } else if exponentDelta > 0 { + divisor := new(apd.BigInt).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(exponentDelta)), nil) + coefficient.Div(coefficient, divisor) + } + + return &apd.Decimal{ + Form: decimal.Form, + Negative: decimal.Negative, + Exponent: newExponent, + Coeff: *coefficient, + } +} + +// encodeDecimalWithScale is used to encode a [*apd.Decimal] to `org.apache.kafka.connect.data.Decimal` +// using a specific scale. +func encodeDecimalWithScale(decimal *apd.Decimal, scale int32) []byte { + targetExponent := -scale // Negate scale since [Decimal.Exponent] is negative. + if decimal.Exponent != targetExponent { + decimal = decimalWithNewExponent(decimal, targetExponent) + } + bytes, _ := debezium.EncodeDecimal(decimal) + return bytes +} + type decimalConverter struct { scale uint16 precision *int @@ -44,7 +83,7 @@ func (d decimalConverter) Convert(value any) (any, error) { return nil, fmt.Errorf(`unable to use %q as a decimal: %w`, stringValue, err) } - return debezium.EncodeDecimalWithScale(decimal, int32(d.scale)), nil + return encodeDecimalWithScale(decimal, int32(d.scale)), nil } type VariableNumericConverter struct{} diff --git a/lib/debezium/converters/decimal_test.go b/lib/debezium/converters/decimal_test.go index 74d266c4..a82d605e 100644 --- a/lib/debezium/converters/decimal_test.go +++ b/lib/debezium/converters/decimal_test.go @@ -5,10 +5,167 @@ import ( "testing" "github.com/artie-labs/transfer/lib/debezium" + "github.com/artie-labs/transfer/lib/numbers" "github.com/artie-labs/transfer/lib/ptr" + "github.com/cockroachdb/apd/v3" "github.com/stretchr/testify/assert" ) +func TestDecimalWithNewExponent(t *testing.T) { + assert.Equal(t, "0", decimalWithNewExponent(apd.New(0, 0), 0).Text('f')) + assert.Equal(t, "00", decimalWithNewExponent(apd.New(0, 1), 1).Text('f')) + assert.Equal(t, "0", decimalWithNewExponent(apd.New(0, 100), 0).Text('f')) + assert.Equal(t, "00", decimalWithNewExponent(apd.New(0, 0), 1).Text('f')) + assert.Equal(t, "0.0", decimalWithNewExponent(apd.New(0, 0), -1).Text('f')) + + // Same exponent: + assert.Equal(t, "12.349", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -3).Text('f')) + // More precise exponent: + assert.Equal(t, "12.3490", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -4).Text('f')) + assert.Equal(t, "12.34900", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -5).Text('f')) + // Lest precise exponent: + // Extra digits should be truncated rather than rounded. + assert.Equal(t, "12.34", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -2).Text('f')) + assert.Equal(t, "12.3", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -1).Text('f')) + assert.Equal(t, "12", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), 0).Text('f')) + assert.Equal(t, "10", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), 1).Text('f')) +} + +func TestEncodeDecimalWithScale(t *testing.T) { + mustEncodeAndDecodeDecimal := func(value string, scale int32) string { + bytes := encodeDecimalWithScale(numbers.MustParseDecimal(value), scale) + return debezium.DecodeDecimal(bytes, scale).String() + } + + // Whole numbers: + for i := range 100_000 { + strValue := fmt.Sprint(i) + assert.Equal(t, strValue, mustEncodeAndDecodeDecimal(strValue, 0)) + if i != 0 { + strValue := "-" + strValue + assert.Equal(t, strValue, mustEncodeAndDecodeDecimal(strValue, 0)) + } + } + + // Scale of 15 that is equal to the amount of decimal places in the value: + assert.Equal(t, "145.183000000000000", mustEncodeAndDecodeDecimal("145.183000000000000", 15)) + assert.Equal(t, "-145.183000000000000", mustEncodeAndDecodeDecimal("-145.183000000000000", 15)) + // If scale is smaller than the amount of decimal places then the extra places should be truncated without rounding: + assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000000", 14)) + assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000005", 14)) + assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000005", 14)) + assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000009", 14)) + assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000009", 14)) + assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000000", 14)) + assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000001", 14)) + assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000001", 14)) + assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000004", 14)) + assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000004", 14)) + // If scale is larger than the amount of decimal places then the extra places should be padded with zeros: + assert.Equal(t, "145.1830000000000000", mustEncodeAndDecodeDecimal("145.183000000000000", 16)) + assert.Equal(t, "-145.1830000000000000", mustEncodeAndDecodeDecimal("-145.183000000000000", 16)) + assert.Equal(t, "145.1830000000000010", mustEncodeAndDecodeDecimal("145.183000000000001", 16)) + assert.Equal(t, "-145.1830000000000010", mustEncodeAndDecodeDecimal("-145.183000000000001", 16)) + assert.Equal(t, "145.1830000000000040", mustEncodeAndDecodeDecimal("145.183000000000004", 16)) + assert.Equal(t, "-145.1830000000000040", mustEncodeAndDecodeDecimal("-145.183000000000004", 16)) + assert.Equal(t, "145.1830000000000050", mustEncodeAndDecodeDecimal("145.183000000000005", 16)) + assert.Equal(t, "-145.1830000000000050", mustEncodeAndDecodeDecimal("-145.183000000000005", 16)) + assert.Equal(t, "145.1830000000000090", mustEncodeAndDecodeDecimal("145.183000000000009", 16)) + assert.Equal(t, "-145.1830000000000090", mustEncodeAndDecodeDecimal("-145.183000000000009", 16)) + + assert.Equal(t, "-9063701308.217222135", mustEncodeAndDecodeDecimal("-9063701308.217222135", 9)) + assert.Equal(t, "-74961544796695.89960242", mustEncodeAndDecodeDecimal("-74961544796695.89960242", 8)) + + testCases := []struct { + name string + value string + scale int32 + }{ + { + name: "0 scale", + value: "5", + }, + { + name: "2 scale", + value: "23131319.99", + scale: 2, + }, + { + name: "5 scale", + value: "9.12345", + scale: 5, + }, + { + name: "negative number", + value: "-105.2813669", + scale: 7, + }, + // Longitude #1 + { + name: "long 1", + value: "-75.765611", + scale: 6, + }, + // Latitude #1 + { + name: "lat", + value: "40.0335495", + scale: 7, + }, + // Long #2 + { + name: "long 2", + value: "-119.65575", + scale: 5, + }, + { + name: "lat 2", + value: "36.3303", + scale: 4, + }, + { + name: "long 3", + value: "-81.76254098", + scale: 8, + }, + { + name: "amount", + value: "6408.355", + scale: 3, + }, + { + name: "total", + value: "1.05", + scale: 2, + }, + { + name: "negative number: 2^16 - 255", + value: "-65281", + scale: 0, + }, + { + name: "negative number: 2^16 - 1", + value: "-65535", + scale: 0, + }, + { + name: "number with a scale of 15", + value: "0.000022998904125", + scale: 15, + }, + { + name: "number with a scale of 15", + value: "145.183000000000000", + scale: 15, + }, + } + + for _, testCase := range testCases { + actual := mustEncodeAndDecodeDecimal(testCase.value, testCase.scale) + assert.Equal(t, testCase.value, actual, testCase.name) + } +} + func TestDecimalConverter_ToField(t *testing.T) { { // Without precision diff --git a/lib/debezium/converters/money.go b/lib/debezium/converters/money.go index d634010d..c1205062 100644 --- a/lib/debezium/converters/money.go +++ b/lib/debezium/converters/money.go @@ -56,5 +56,5 @@ func (m MoneyConverter) Convert(value any) (any, error) { return nil, fmt.Errorf(`unable to use %q as a money value: %w`, valString, err) } - return debezium.EncodeDecimalWithScale(decimal, int32(m.Scale())), nil + return encodeDecimalWithScale(decimal, int32(m.Scale())), nil }