diff --git a/lib/debezium/decimal.go b/lib/debezium/decimal.go index e30a05ebb..b7d966899 100644 --- a/lib/debezium/decimal.go +++ b/lib/debezium/decimal.go @@ -24,6 +24,11 @@ func EncodeDecimal(value string, scale uint16) ([]byte, error) { return nil, fmt.Errorf("unable to use %q as a floating-point number", value) } + return encodeBigInt(bigIntValue), nil +} + +// encodeBigInt encodes a [big.Int] into a byte slice using two's complement. +func encodeBigInt(bigIntValue *big.Int) []byte { data := bigIntValue.Bytes() // [Bytes] returns the absolute value of the number. if bigIntValue.Sign() < 0 { // Convert to two's complement if the number is negative @@ -55,11 +60,11 @@ func EncodeDecimal(value string, scale uint16) ([]byte, error) { data = append([]byte{0x00}, data...) } } - return data, nil + return data } -// DecodeDecimal is used to decode `org.apache.kafka.connect.data.Decimal`. -func DecodeDecimal(data []byte, precision *int, scale int) *decimal.Decimal { +// decodeBigInt decodes a [big.Int] from a byte slice that has been encoded using two's complement. +func decodeBigInt(data []byte) *big.Int { bigInt := new(big.Int) // If the data represents a negative number, the sign bit will be set. @@ -77,8 +82,13 @@ func DecodeDecimal(data []byte, precision *int, scale int) *decimal.Decimal { bigInt.SetBytes(data) } + return bigInt +} + +// DecodeDecimal is used to decode `org.apache.kafka.connect.data.Decimal`. +func DecodeDecimal(data []byte, precision *int, scale int) *decimal.Decimal { // Convert the big integer to a big float - bigFloat := new(big.Float).SetInt(bigInt) + bigFloat := new(big.Float).SetInt(decodeBigInt(data)) // Compute divisor as 10^scale with big.Int's Exp, then convert to big.Float scaleInt := big.NewInt(int64(scale)) diff --git a/lib/debezium/decimal_test.go b/lib/debezium/decimal_test.go index d7e1a2268..80b6c4588 100644 --- a/lib/debezium/decimal_test.go +++ b/lib/debezium/decimal_test.go @@ -2,11 +2,42 @@ package debezium import ( "fmt" + "math/big" "testing" "github.com/stretchr/testify/assert" ) +func TestEncodeBigInt(t *testing.T) { + assert.Equal(t, []byte{}, encodeBigInt(big.NewInt(0))) + assert.Equal(t, []byte{0x01}, encodeBigInt(big.NewInt(1))) + assert.Equal(t, []byte{0xff}, encodeBigInt(big.NewInt(-1))) + assert.Equal(t, []byte{0x11}, encodeBigInt(big.NewInt(17))) + assert.Equal(t, []byte{0x7f}, encodeBigInt(big.NewInt(127))) + assert.Equal(t, []byte{0x81}, encodeBigInt(big.NewInt(-127))) + assert.Equal(t, []byte{0x00, 0x80}, encodeBigInt(big.NewInt(128))) + assert.Equal(t, []byte{0xff, 0x80}, encodeBigInt(big.NewInt(-128))) + assert.Equal(t, []byte{0x00, 0xff}, encodeBigInt(big.NewInt(255))) + assert.Equal(t, []byte{0x01, 0x00}, encodeBigInt(big.NewInt(256))) +} + +func TestDecodeBigInt(t *testing.T) { + assert.Equal(t, big.NewInt(0), decodeBigInt([]byte{})) + assert.Equal(t, big.NewInt(127), decodeBigInt([]byte{0x7f})) + assert.Equal(t, big.NewInt(-127), decodeBigInt([]byte{0x81})) + assert.Equal(t, big.NewInt(128), decodeBigInt([]byte{0x00, 0x80})) + assert.Equal(t, big.NewInt(-128), decodeBigInt([]byte{0xff, 0x80})) + + for i := range 100_000 { + bigInt := big.NewInt(int64(i)) + + assert.Equal(t, bigInt, decodeBigInt(encodeBigInt(bigInt))) + + negBigInt := bigInt.Neg(bigInt) + assert.Equal(t, negBigInt, decodeBigInt(encodeBigInt(negBigInt))) + } +} + func encodeAndDecodeDecimal(value string, scale uint16) (string, error) { bytes, err := EncodeDecimal(value, scale) if err != nil {