Skip to content

Commit

Permalink
More
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Jun 25, 2024
1 parent c74b6ca commit 5af2604
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 64 deletions.
9 changes: 5 additions & 4 deletions lib/debezium/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,23 +91,24 @@ func decimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decima

// EncodeDecimal is used to encode a [apd.Decimal] to [org.apache.kafka.connect.data.Decimal].
// The scale of the value (which is the negated exponent of the decimal) is returned as the second argument.
func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32, error) {
func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) {
bigIntValue := decimal.Coeff.MathBigInt()
if decimal.Negative {
bigIntValue.Neg(bigIntValue)
}

return encodeBigInt(bigIntValue), -decimal.Exponent, nil
return encodeBigInt(bigIntValue), -decimal.Exponent
}

// EncodeDecimalWithScale is used to encode a [apd.Decimal] to [org.apache.kafka.connect.data.Decimal] using a specific
// scale.
func EncodeDecimalWithScale(decimal *apd.Decimal, scale int32) ([]byte, int32, error) {
func EncodeDecimalWithScale(decimal *apd.Decimal, scale int32) []byte {
targetExponent := -scale // Negate scale since [Decimal.Exponent] is negative.
if decimal.Exponent != targetExponent {
decimal = decimalWithNewExponent(decimal, targetExponent)
}
return EncodeDecimal(decimal)
out, _ := EncodeDecimal(decimal)
return out
}

// DecodeDecimal is used to decode `org.apache.kafka.connect.data.Decimal`.
Expand Down
123 changes: 63 additions & 60 deletions lib/debezium/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,12 @@ func TestDecodeBigInt(t *testing.T) {
}
}

func encodeAndDecodeDecimal(value string, scale int32) (string, error) {
decimal, _, err := new(apd.Decimal).SetString(value)
func mustParseDecimal(value string) *apd.Decimal {
decimal, _, err := apd.NewFromString(value)
if err != nil {
return "", fmt.Errorf("unable to use %q as a floating-point number: %w", value, err)
panic(fmt.Errorf("unable to use %q as a floating-point number: %w", value, err))
}

bytes, _, err := EncodeDecimalWithScale(decimal, scale)
if err != nil {
return "", err
}
return DecodeDecimal(bytes, nil, int(scale)).String(), nil
}

func mustEncodeAndDecodeDecimal(value string, scale int32) string {
out, err := encodeAndDecodeDecimal(value, scale)
if err != nil {
panic(err)
}
return out
}

func mustParseDecimalFromString(in string) *apd.Decimal {
out, _, err := apd.NewFromString(in)
if err != nil {
panic(err)
}
return out
return decimal
}

func TestDecimalWithNewExponent(t *testing.T) {
Expand All @@ -78,56 +57,81 @@ func TestDecimalWithNewExponent(t *testing.T) {
assert.Equal(t, "0.0", decimalWithNewExponent(apd.New(0, 0), -1).Text('f'))

// Same exponent:
assert.Equal(t, "12.349", decimalWithNewExponent(mustParseDecimalFromString("12.349"), -3).Text('f'))
assert.Equal(t, "12.349", decimalWithNewExponent(mustParseDecimal("12.349"), -3).Text('f'))
// More precise exponent:
assert.Equal(t, "12.3490", decimalWithNewExponent(mustParseDecimalFromString("12.349"), -4).Text('f'))
assert.Equal(t, "12.34900", decimalWithNewExponent(mustParseDecimalFromString("12.349"), -5).Text('f'))
assert.Equal(t, "12.3490", decimalWithNewExponent(mustParseDecimal("12.349"), -4).Text('f'))
assert.Equal(t, "12.34900", decimalWithNewExponent(mustParseDecimal("12.349"), -5).Text('f'))
// Lest precise exponent:
// Extra digits should be truncated rather than rounded.
assert.Equal(t, "12.34", decimalWithNewExponent(mustParseDecimalFromString("12.349"), -2).Text('f'))
assert.Equal(t, "12.3", decimalWithNewExponent(mustParseDecimalFromString("12.349"), -1).Text('f'))
assert.Equal(t, "12", decimalWithNewExponent(mustParseDecimalFromString("12.349"), 0).Text('f'))
assert.Equal(t, "10", decimalWithNewExponent(mustParseDecimalFromString("12.349"), 1).Text('f'))
assert.Equal(t, "12.34", decimalWithNewExponent(mustParseDecimal("12.349"), -2).Text('f'))
assert.Equal(t, "12.3", decimalWithNewExponent(mustParseDecimal("12.349"), -1).Text('f'))
assert.Equal(t, "12", decimalWithNewExponent(mustParseDecimal("12.349"), 0).Text('f'))
assert.Equal(t, "10", decimalWithNewExponent(mustParseDecimal("12.349"), 1).Text('f'))
}

func TestEncodeDecimal(t *testing.T) {
testValue := func(value string, expectedScale int32) {
bytes, scale := EncodeDecimal(mustParseDecimal(value))
result := DecodeDecimal(bytes, nil, int(scale)).String()
assert.Equal(t, result, value, value)
assert.Equal(t, expectedScale, scale, value)
}

testValue("0", 0)
testValue("0.0", 1)
testValue("0.00", 2)
testValue("0.00000", 5)
testValue("1", 0)
testValue("1.0", 1)
testValue("-1", 0)
testValue("-1.0", 1)
testValue("145.183000000000009", 15)
testValue("-145.183000000000009", 15)
}

func TestEncodeDecimalWithScale(t *testing.T) {
mustEncodeAndDecodeDecimalWithScale := func(value string, scale int32) string {
bytes := EncodeDecimalWithScale(mustParseDecimal(value), scale)
return DecodeDecimal(bytes, nil, int(scale)).String()
}

// Whole numbers:
for i := range 100_000 {
strValue := fmt.Sprint(i)
assert.Equal(t, strValue, mustEncodeAndDecodeDecimal(strValue, 0))
assert.Equal(t, strValue, mustEncodeAndDecodeDecimalWithScale(strValue, 0))
if i != 0 {
strValue := "-" + strValue
assert.Equal(t, strValue, mustEncodeAndDecodeDecimal(strValue, 0))
assert.Equal(t, strValue, mustEncodeAndDecodeDecimalWithScale(strValue, 0))
}
}

// Scale of 15 that is equal to the amount of decimal places in the value:
assert.Equal(t, "145.183000000000000", mustEncodeAndDecodeDecimal("145.183000000000000", 15))
assert.Equal(t, "-145.183000000000000", mustEncodeAndDecodeDecimal("-145.183000000000000", 15))
assert.Equal(t, "145.183000000000000", mustEncodeAndDecodeDecimalWithScale("145.183000000000000", 15))
assert.Equal(t, "-145.183000000000000", mustEncodeAndDecodeDecimalWithScale("-145.183000000000000", 15))
// If scale is smaller than the amount of decimal places then the extra places should be truncated without rounding:
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000000", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000005", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000005", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000009", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000009", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000000", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000001", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000001", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimal("145.183000000000004", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimal("-145.183000000000004", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimalWithScale("145.183000000000000", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimalWithScale("145.183000000000005", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimalWithScale("-145.183000000000005", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimalWithScale("145.183000000000009", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimalWithScale("-145.183000000000009", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimalWithScale("-145.183000000000000", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimalWithScale("145.183000000000001", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimalWithScale("-145.183000000000001", 14))
assert.Equal(t, "145.18300000000000", mustEncodeAndDecodeDecimalWithScale("145.183000000000004", 14))
assert.Equal(t, "-145.18300000000000", mustEncodeAndDecodeDecimalWithScale("-145.183000000000004", 14))
// If scale is larger than the amount of decimal places then the extra places should be padded with zeros:
assert.Equal(t, "145.1830000000000000", mustEncodeAndDecodeDecimal("145.183000000000000", 16))
assert.Equal(t, "-145.1830000000000000", mustEncodeAndDecodeDecimal("-145.183000000000000", 16))
assert.Equal(t, "145.1830000000000010", mustEncodeAndDecodeDecimal("145.183000000000001", 16))
assert.Equal(t, "-145.1830000000000010", mustEncodeAndDecodeDecimal("-145.183000000000001", 16))
assert.Equal(t, "145.1830000000000040", mustEncodeAndDecodeDecimal("145.183000000000004", 16))
assert.Equal(t, "-145.1830000000000040", mustEncodeAndDecodeDecimal("-145.183000000000004", 16))
assert.Equal(t, "145.1830000000000050", mustEncodeAndDecodeDecimal("145.183000000000005", 16))
assert.Equal(t, "-145.1830000000000050", mustEncodeAndDecodeDecimal("-145.183000000000005", 16))
assert.Equal(t, "145.1830000000000090", mustEncodeAndDecodeDecimal("145.183000000000009", 16))
assert.Equal(t, "-145.1830000000000090", mustEncodeAndDecodeDecimal("-145.183000000000009", 16))

assert.Equal(t, "-9063701308.217222135", mustEncodeAndDecodeDecimal("-9063701308.217222135", 9))
assert.Equal(t, "145.1830000000000000", mustEncodeAndDecodeDecimalWithScale("145.183000000000000", 16))
assert.Equal(t, "-145.1830000000000000", mustEncodeAndDecodeDecimalWithScale("-145.183000000000000", 16))
assert.Equal(t, "145.1830000000000010", mustEncodeAndDecodeDecimalWithScale("145.183000000000001", 16))
assert.Equal(t, "-145.1830000000000010", mustEncodeAndDecodeDecimalWithScale("-145.183000000000001", 16))
assert.Equal(t, "145.1830000000000040", mustEncodeAndDecodeDecimalWithScale("145.183000000000004", 16))
assert.Equal(t, "-145.1830000000000040", mustEncodeAndDecodeDecimalWithScale("-145.183000000000004", 16))
assert.Equal(t, "145.1830000000000050", mustEncodeAndDecodeDecimalWithScale("145.183000000000005", 16))
assert.Equal(t, "-145.1830000000000050", mustEncodeAndDecodeDecimalWithScale("-145.183000000000005", 16))
assert.Equal(t, "145.1830000000000090", mustEncodeAndDecodeDecimalWithScale("145.183000000000009", 16))
assert.Equal(t, "-145.1830000000000090", mustEncodeAndDecodeDecimalWithScale("-145.183000000000009", 16))

assert.Equal(t, "-9063701308.217222135", mustEncodeAndDecodeDecimalWithScale("-9063701308.217222135", 9))

testCases := []struct {
name string
Expand Down Expand Up @@ -214,8 +218,7 @@ func TestEncodeDecimal(t *testing.T) {
}

for _, testCase := range testCases {
actual, err := encodeAndDecodeDecimal(testCase.value, testCase.scale)
assert.NoError(t, err)
actual := mustEncodeAndDecodeDecimalWithScale(testCase.value, testCase.scale)
assert.Equal(t, testCase.value, actual, testCase.name)
}
}

0 comments on commit 5af2604

Please sign in to comment.