From 9d1f5abe53e5e66bd9c6ccdf8f307398778befda Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Wed, 26 Jun 2024 14:08:35 -0700 Subject: [PATCH] [debezium] Return `apd.Decimal` from `DecodeDecimal` (#765) --- clients/bigquery/storagewrite_test.go | 6 +-- clients/shared/utils.go | 1 + lib/debezium/decimal.go | 21 +++------- lib/debezium/decimal_test.go | 9 +++-- lib/debezium/types.go | 6 ++- lib/debezium/types_bench_test.go | 2 +- lib/parquetutil/parse_values_test.go | 4 +- lib/typing/decimal/decimal.go | 24 +++++++----- lib/typing/decimal/decimal_test.go | 55 +++++++++++++++++++++++++++ lib/typing/values/string_test.go | 4 +- 10 files changed, 94 insertions(+), 38 deletions(-) create mode 100644 lib/typing/decimal/decimal_test.go diff --git a/clients/bigquery/storagewrite_test.go b/clients/bigquery/storagewrite_test.go index 7101e7718..23f66115f 100644 --- a/clients/bigquery/storagewrite_test.go +++ b/clients/bigquery/storagewrite_test.go @@ -2,11 +2,11 @@ package bigquery import ( "encoding/json" - "math/big" "testing" "time" "cloud.google.com/go/bigquery/storage/apiv1/storagepb" + "github.com/artie-labs/transfer/lib/numbers" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/decimal" @@ -129,9 +129,9 @@ func TestRowToMessage(t *testing.T) { "c_float_int32": int32(1234), "c_float_int64": int64(1234), "c_float_string": "4444.55555", - "c_numeric": decimal.NewDecimal(nil, 5, big.NewFloat(3.1415926)), + "c_numeric": decimal.NewDecimal(nil, numbers.MustParseDecimal("3.14159")), "c_string": "foo bar", - "c_string_decimal": decimal.NewDecimal(nil, 5, big.NewFloat(1.618033)), + "c_string_decimal": decimal.NewDecimal(nil, numbers.MustParseDecimal("1.61803")), "c_time": ext.NewExtendedTime(time.Date(0, 0, 0, 4, 5, 6, 7, time.UTC), ext.TimeKindType, ""), "c_date": ext.NewExtendedTime(time.Date(2001, 2, 3, 0, 0, 0, 0, time.UTC), ext.DateKindType, ""), "c_datetime": ext.NewExtendedTime(time.Date(2001, 2, 3, 4, 5, 6, 7, time.UTC), ext.DateTimeKindType, ""), diff --git a/clients/shared/utils.go b/clients/shared/utils.go index 5f479fe79..1b7022053 100644 --- a/clients/shared/utils.go +++ b/clients/shared/utils.go @@ -44,6 +44,7 @@ func DefaultValue(column columns.Column, dialect sql.Dialect, additionalDateFmts return nil, fmt.Errorf("colVal is not type *decimal.Decimal") } + // TODO: Call [String] instead. return val.Value(), nil case typing.String.Kind: return sql.QuoteLiteral(fmt.Sprint(column.DefaultValue())), nil diff --git a/lib/debezium/decimal.go b/lib/debezium/decimal.go index 2e745e62d..7e46c997c 100644 --- a/lib/debezium/decimal.go +++ b/lib/debezium/decimal.go @@ -4,7 +4,6 @@ import ( "math/big" "slices" - "github.com/artie-labs/transfer/lib/typing/decimal" "github.com/cockroachdb/apd/v3" ) @@ -89,7 +88,7 @@ func decimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decima } } -// EncodeDecimal is used to encode a [*apd.Decimal] to [org.apache.kafka.connect.data.Decimal]. +// EncodeDecimal is used to encode a [*apd.Decimal] to `org.apache.kafka.connect.data.Decimal`. // The scale of the value (which is the negated exponent of the decimal) is returned as the second argument. func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) { bigIntValue := decimal.Coeff.MathBigInt() @@ -100,7 +99,7 @@ func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) { return encodeBigInt(bigIntValue), -decimal.Exponent } -// EncodeDecimalWithScale is used to encode a [*apd.Decimal] to [org.apache.kafka.connect.data.Decimal]. +// EncodeDecimalWithScale is used to encode a [*apd.Decimal] to `org.apache.kafka.connect.data.Decimal` // using a specific scale. func EncodeDecimalWithScale(decimal *apd.Decimal, scale int32) []byte { targetExponent := -scale // Negate scale since [Decimal.Exponent] is negative. @@ -112,17 +111,7 @@ func EncodeDecimalWithScale(decimal *apd.Decimal, scale int32) []byte { } // DecodeDecimal is used to decode `org.apache.kafka.connect.data.Decimal`. -func DecodeDecimal(data []byte, precision *int, scale int) *decimal.Decimal { - // Convert the big integer to a big float - bigFloat := new(big.Float).SetInt(decodeBigInt(data)) - - // Compute divisor as 10^scale with big.Int's Exp, then convert to big.Float - scaleInt := big.NewInt(int64(scale)) - ten := big.NewInt(10) - divisorInt := new(big.Int).Exp(ten, scaleInt, nil) - divisorFloat := new(big.Float).SetInt(divisorInt) - - // Perform the division - bigFloat.Quo(bigFloat, divisorFloat) - return decimal.NewDecimal(precision, scale, bigFloat) +func DecodeDecimal(data []byte, scale int32) *apd.Decimal { + bigInt := new(apd.BigInt).SetMathBigInt(decodeBigInt(data)) + return apd.NewWithBigInt(bigInt, -scale) } diff --git a/lib/debezium/decimal_test.go b/lib/debezium/decimal_test.go index e74bd7789..931468710 100644 --- a/lib/debezium/decimal_test.go +++ b/lib/debezium/decimal_test.go @@ -65,9 +65,11 @@ func TestDecimalWithNewExponent(t *testing.T) { func TestEncodeDecimal(t *testing.T) { testEncodeDecimal := func(value string, expectedScale int32) { bytes, scale := EncodeDecimal(numbers.MustParseDecimal(value)) - actual := DecodeDecimal(bytes, nil, int(scale)).String() - assert.Equal(t, value, actual, value) assert.Equal(t, expectedScale, scale, value) + + actual := DecodeDecimal(bytes, scale) + assert.Equal(t, value, actual.Text('f'), value) + assert.Equal(t, expectedScale, -actual.Exponent, value) } testEncodeDecimal("0", 0) @@ -85,7 +87,7 @@ func TestEncodeDecimal(t *testing.T) { func TestEncodeDecimalWithScale(t *testing.T) { mustEncodeAndDecodeDecimal := func(value string, scale int32) string { bytes := EncodeDecimalWithScale(numbers.MustParseDecimal(value), scale) - return DecodeDecimal(bytes, nil, int(scale)).String() + return DecodeDecimal(bytes, scale).String() } // Whole numbers: @@ -125,6 +127,7 @@ func TestEncodeDecimalWithScale(t *testing.T) { assert.Equal(t, "-145.1830000000000090", mustEncodeAndDecodeDecimal("-145.183000000000009", 16)) assert.Equal(t, "-9063701308.217222135", mustEncodeAndDecodeDecimal("-9063701308.217222135", 9)) + assert.Equal(t, "-74961544796695.89960242", mustEncodeAndDecodeDecimal("-74961544796695.89960242", 8)) testCases := []struct { name string diff --git a/lib/debezium/types.go b/lib/debezium/types.go index 655156815..c7af11e9c 100644 --- a/lib/debezium/types.go +++ b/lib/debezium/types.go @@ -236,7 +236,8 @@ func (f Field) DecodeDecimal(encoded []byte) (*decimal.Decimal, error) { if err != nil { return nil, fmt.Errorf("failed to get scale and/or precision: %w", err) } - return DecodeDecimal(encoded, precision, scale), nil + _decimal := DecodeDecimal(encoded, int32(scale)) + return decimal.NewDecimal(precision, _decimal), nil } func (f Field) DecodeDebeziumVariableDecimal(value any) (*decimal.Decimal, error) { @@ -259,5 +260,6 @@ func (f Field) DecodeDebeziumVariableDecimal(value any) (*decimal.Decimal, error if err != nil { return nil, err } - return DecodeDecimal(bytes, ptr.ToInt(decimal.PrecisionNotSpecified), scale), nil + _decimal := DecodeDecimal(bytes, int32(scale)) + return decimal.NewDecimal(ptr.ToInt(decimal.PrecisionNotSpecified), _decimal), nil } diff --git a/lib/debezium/types_bench_test.go b/lib/debezium/types_bench_test.go index 00e763e9d..c014c9fb7 100644 --- a/lib/debezium/types_bench_test.go +++ b/lib/debezium/types_bench_test.go @@ -19,7 +19,7 @@ func BenchmarkDecodeDecimal_P64_S10(b *testing.B) { assert.NoError(b, err) dec, err := field.DecodeDecimal(bytes) assert.NoError(b, err) - assert.Equal(b, "123456789012345678901234567890123456789012345678901234.1234567889", dec.Value()) + assert.Equal(b, "123456789012345678901234567890123456789012345678901234.1234567890", dec.String()) require.NoError(b, err) } } diff --git a/lib/parquetutil/parse_values_test.go b/lib/parquetutil/parse_values_test.go index 29b99ee6b..5174ddb7d 100644 --- a/lib/parquetutil/parse_values_test.go +++ b/lib/parquetutil/parse_values_test.go @@ -1,9 +1,9 @@ package parquetutil import ( - "math/big" "testing" + "github.com/artie-labs/transfer/lib/numbers" "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing/ext" @@ -66,7 +66,7 @@ func TestParseValue(t *testing.T) { }, { name: "decimal", - colVal: decimal.NewDecimal(ptr.ToInt(30), 5, big.NewFloat(5000.2232)), + colVal: decimal.NewDecimal(ptr.ToInt(30), numbers.MustParseDecimal("5000.22320")), colKind: columns.NewColumn("", eDecimal), expectedValue: "5000.22320", }, diff --git a/lib/typing/decimal/decimal.go b/lib/typing/decimal/decimal.go index 89625b2bf..00ba25949 100644 --- a/lib/typing/decimal/decimal.go +++ b/lib/typing/decimal/decimal.go @@ -1,16 +1,17 @@ package decimal import ( + "log/slog" "math/big" "github.com/artie-labs/transfer/lib/ptr" + "github.com/cockroachdb/apd/v3" ) -// Decimal is Artie's wrapper around *big.Float which can store large numbers w/ no precision loss. +// Decimal is Artie's wrapper around [*apd.Decimal] which can store large numbers w/ no precision loss. type Decimal struct { - scale int precision *int - value *big.Float + value *apd.Decimal } const ( @@ -21,8 +22,9 @@ const ( MaxPrecisionBeforeString = 38 ) -func NewDecimal(precision *int, scale int, value *big.Float) *Decimal { +func NewDecimal(precision *int, value *apd.Decimal) *Decimal { if precision != nil { + scale := int(-value.Exponent) if scale > *precision && *precision != -1 { // Note: -1 precision means it's not specified. @@ -33,14 +35,13 @@ func NewDecimal(precision *int, scale int, value *big.Float) *Decimal { } return &Decimal{ - scale: scale, precision: precision, value: value, } } func (d *Decimal) Scale() int { - return d.scale + return int(-d.value.Exponent) } func (d *Decimal) Precision() *int { @@ -51,7 +52,7 @@ func (d *Decimal) Precision() *int { // This is particularly useful for Snowflake because we're writing all the values as STRINGS into TSV format. // This function guarantees backwards compatibility. func (d *Decimal) String() string { - return d.value.Text('f', d.scale) + return d.value.Text('f') } func (d *Decimal) Value() any { @@ -62,9 +63,14 @@ func (d *Decimal) Value() any { } // Depending on the precision, we will want to convert value to STRING or keep as a FLOAT. - return d.value + // TODO: [Value] is only called in one place, look into calling [String] instead. + if out, ok := new(big.Float).SetString(d.String()); ok { + return out + } + slog.Error("Failed to convert apd.Decimal to big.Float", slog.String("value", d.String())) + return d.String() } func (d *Decimal) Details() DecimalDetails { - return DecimalDetails{scale: d.scale, precision: d.precision} + return DecimalDetails{scale: d.Scale(), precision: d.precision} } diff --git a/lib/typing/decimal/decimal_test.go b/lib/typing/decimal/decimal_test.go new file mode 100644 index 000000000..666761b20 --- /dev/null +++ b/lib/typing/decimal/decimal_test.go @@ -0,0 +1,55 @@ +package decimal + +import ( + "testing" + + "github.com/artie-labs/transfer/lib/numbers" + "github.com/artie-labs/transfer/lib/ptr" + "github.com/stretchr/testify/assert" +) + +func TestNewDecimal(t *testing.T) { + // Nil precision: + assert.Equal(t, "0", NewDecimal(nil, numbers.MustParseDecimal("0")).String()) + // Precision = -1 (PrecisionNotSpecified): + assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt(-1)}, NewDecimal(ptr.ToInt(-1), numbers.MustParseDecimal("12.34")).Details()) + // Precision = scale: + assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt(2)}, NewDecimal(ptr.ToInt(2), numbers.MustParseDecimal("12.34")).Details()) + // Precision < scale: + assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt(3)}, NewDecimal(ptr.ToInt(1), numbers.MustParseDecimal("12.34")).Details()) + // Precision > scale: + assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt(4)}, NewDecimal(ptr.ToInt(4), numbers.MustParseDecimal("12.34")).Details()) +} + +func TestDecimal_Scale(t *testing.T) { + assert.Equal(t, 0, NewDecimal(nil, numbers.MustParseDecimal("0")).Scale()) + assert.Equal(t, 0, NewDecimal(nil, numbers.MustParseDecimal("12345")).Scale()) + assert.Equal(t, 0, NewDecimal(nil, numbers.MustParseDecimal("12300")).Scale()) + assert.Equal(t, 1, NewDecimal(nil, numbers.MustParseDecimal("12300.0")).Scale()) + assert.Equal(t, 2, NewDecimal(nil, numbers.MustParseDecimal("12300.00")).Scale()) + assert.Equal(t, 2, NewDecimal(nil, numbers.MustParseDecimal("12345.12")).Scale()) + assert.Equal(t, 3, NewDecimal(nil, numbers.MustParseDecimal("-12345.123")).Scale()) +} + +func TestDecimal_Details(t *testing.T) { + // Nil precision: + assert.Equal(t, DecimalDetails{scale: 0}, NewDecimal(nil, numbers.MustParseDecimal("0")).Details()) + assert.Equal(t, DecimalDetails{scale: 0}, NewDecimal(nil, numbers.MustParseDecimal("12345")).Details()) + assert.Equal(t, DecimalDetails{scale: 0}, NewDecimal(nil, numbers.MustParseDecimal("-12")).Details()) + assert.Equal(t, DecimalDetails{scale: 2}, NewDecimal(nil, numbers.MustParseDecimal("12345.12")).Details()) + assert.Equal(t, DecimalDetails{scale: 3}, NewDecimal(nil, numbers.MustParseDecimal("-12345.123")).Details()) + + // -1 precision (PrecisionNotSpecified): + assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt(-1)}, NewDecimal(ptr.ToInt(-1), numbers.MustParseDecimal("0")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt(-1)}, NewDecimal(ptr.ToInt(-1), numbers.MustParseDecimal("12345")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt(-1)}, NewDecimal(ptr.ToInt(-1), numbers.MustParseDecimal("-12")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt(-1)}, NewDecimal(ptr.ToInt(-1), numbers.MustParseDecimal("12345.12")).Details()) + assert.Equal(t, DecimalDetails{scale: 3, precision: ptr.ToInt(-1)}, NewDecimal(ptr.ToInt(-1), numbers.MustParseDecimal("-12345.123")).Details()) + + // 10 precision: + assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt(10)}, NewDecimal(ptr.ToInt(10), numbers.MustParseDecimal("0")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt(10)}, NewDecimal(ptr.ToInt(10), numbers.MustParseDecimal("12345")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt(10)}, NewDecimal(ptr.ToInt(10), numbers.MustParseDecimal("-12")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt(10)}, NewDecimal(ptr.ToInt(10), numbers.MustParseDecimal("12345.12")).Details()) + assert.Equal(t, DecimalDetails{scale: 3, precision: ptr.ToInt(10)}, NewDecimal(ptr.ToInt(10), numbers.MustParseDecimal("-12345.123")).Details()) +} diff --git a/lib/typing/values/string_test.go b/lib/typing/values/string_test.go index 46dff4e5f..945a8f756 100644 --- a/lib/typing/values/string_test.go +++ b/lib/typing/values/string_test.go @@ -1,13 +1,13 @@ package values import ( - "math/big" "testing" "time" "github.com/stretchr/testify/assert" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/numbers" "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" @@ -123,7 +123,7 @@ func TestToString(t *testing.T) { assert.Equal(t, "123.45", val) // Decimals - value := decimal.NewDecimal(ptr.ToInt(38), 2, big.NewFloat(585692791691858.25)) + value := decimal.NewDecimal(ptr.ToInt(38), numbers.MustParseDecimal("585692791691858.25")) val, err = ToString(value, columns.Column{KindDetails: typing.EDecimal}, nil) assert.NoError(t, err) assert.Equal(t, "585692791691858.25", val)