diff --git a/lib/debezium/decimal.go b/lib/debezium/decimal.go index b7d531518..2e745e62d 100644 --- a/lib/debezium/decimal.go +++ b/lib/debezium/decimal.go @@ -62,7 +62,7 @@ func decodeBigInt(data []byte) *big.Int { return bigInt } -// decimalWithNewExponent takes a [apd.Decimal] and returns a new [apd.Decimal] with a the given exponent. +// decimalWithNewExponent takes a [*apd.Decimal] and returns a new [*apd.Decimal] with a the given exponent. // If the new exponent is less precise then the extra digits will be truncated. func decimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decimal { exponentDelta := newExponent - decimal.Exponent // Exponent is negative. @@ -89,7 +89,7 @@ func decimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decima } } -// EncodeDecimal is used to encode a [apd.Decimal] to [org.apache.kafka.connect.data.Decimal]. +// EncodeDecimal is used to encode a [*apd.Decimal] to [org.apache.kafka.connect.data.Decimal]. // The scale of the value (which is the negated exponent of the decimal) is returned as the second argument. func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) { bigIntValue := decimal.Coeff.MathBigInt() @@ -100,7 +100,7 @@ func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) { return encodeBigInt(bigIntValue), -decimal.Exponent } -// EncodeDecimalWithScale is used to encode a [apd.Decimal] to [org.apache.kafka.connect.data.Decimal] +// EncodeDecimalWithScale is used to encode a [*apd.Decimal] to [org.apache.kafka.connect.data.Decimal]. // using a specific scale. func EncodeDecimalWithScale(decimal *apd.Decimal, scale int32) []byte { targetExponent := -scale // Negate scale since [Decimal.Exponent] is negative. diff --git a/lib/debezium/decimal_test.go b/lib/debezium/decimal_test.go index e3642dbed..e74bd7789 100644 --- a/lib/debezium/decimal_test.go +++ b/lib/debezium/decimal_test.go @@ -6,6 +6,7 @@ import ( "math/big" "testing" + "github.com/artie-labs/transfer/lib/numbers" "github.com/cockroachdb/apd/v3" "github.com/stretchr/testify/assert" ) @@ -41,14 +42,6 @@ func TestDecodeBigInt(t *testing.T) { } } -func mustParseDecimal(value string) *apd.Decimal { - decimal, _, err := apd.NewFromString(value) - if err != nil { - panic(err) - } - return decimal -} - func TestDecimalWithNewExponent(t *testing.T) { assert.Equal(t, "0", decimalWithNewExponent(apd.New(0, 0), 0).Text('f')) assert.Equal(t, "00", decimalWithNewExponent(apd.New(0, 1), 1).Text('f')) @@ -57,21 +50,21 @@ func TestDecimalWithNewExponent(t *testing.T) { assert.Equal(t, "0.0", decimalWithNewExponent(apd.New(0, 0), -1).Text('f')) // Same exponent: - assert.Equal(t, "12.349", decimalWithNewExponent(mustParseDecimal("12.349"), -3).Text('f')) + assert.Equal(t, "12.349", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -3).Text('f')) // More precise exponent: - assert.Equal(t, "12.3490", decimalWithNewExponent(mustParseDecimal("12.349"), -4).Text('f')) - assert.Equal(t, "12.34900", decimalWithNewExponent(mustParseDecimal("12.349"), -5).Text('f')) + assert.Equal(t, "12.3490", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -4).Text('f')) + assert.Equal(t, "12.34900", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -5).Text('f')) // Lest precise exponent: // Extra digits should be truncated rather than rounded. - assert.Equal(t, "12.34", decimalWithNewExponent(mustParseDecimal("12.349"), -2).Text('f')) - assert.Equal(t, "12.3", decimalWithNewExponent(mustParseDecimal("12.349"), -1).Text('f')) - assert.Equal(t, "12", decimalWithNewExponent(mustParseDecimal("12.349"), 0).Text('f')) - assert.Equal(t, "10", decimalWithNewExponent(mustParseDecimal("12.349"), 1).Text('f')) + assert.Equal(t, "12.34", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -2).Text('f')) + assert.Equal(t, "12.3", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -1).Text('f')) + assert.Equal(t, "12", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), 0).Text('f')) + assert.Equal(t, "10", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), 1).Text('f')) } func TestEncodeDecimal(t *testing.T) { testEncodeDecimal := func(value string, expectedScale int32) { - bytes, scale := EncodeDecimal(mustParseDecimal(value)) + bytes, scale := EncodeDecimal(numbers.MustParseDecimal(value)) actual := DecodeDecimal(bytes, nil, int(scale)).String() assert.Equal(t, value, actual, value) assert.Equal(t, expectedScale, scale, value) @@ -91,7 +84,7 @@ func TestEncodeDecimal(t *testing.T) { func TestEncodeDecimalWithScale(t *testing.T) { mustEncodeAndDecodeDecimal := func(value string, scale int32) string { - bytes := EncodeDecimalWithScale(mustParseDecimal(value), scale) + bytes := EncodeDecimalWithScale(numbers.MustParseDecimal(value), scale) return DecodeDecimal(bytes, nil, int(scale)).String() } diff --git a/lib/numbers/numbers.go b/lib/numbers/numbers.go index f414bbd83..d4ea8f4fd 100644 --- a/lib/numbers/numbers.go +++ b/lib/numbers/numbers.go @@ -1,6 +1,17 @@ package numbers +import "github.com/cockroachdb/apd/v3" + // BetweenEq - Looks something like this. start <= number <= end func BetweenEq[T int | int32 | int64](start, end, number T) bool { return number >= start && number <= end } + +// MustParseDecimal parses a string to a [*apd.Decimal] or panics -- used for tests. +func MustParseDecimal(value string) *apd.Decimal { + decimal, _, err := apd.NewFromString(value) + if err != nil { + panic(err) + } + return decimal +}