From 55c4ace9989866c1c70bb11dd21025f713c84f58 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Tue, 25 Jun 2024 23:57:21 -0700 Subject: [PATCH] Little cleanups --- lib/debezium/decimal.go | 10 ++-- lib/numbers/decimal.go | 39 ++++++++++++++ lib/numbers/decimal_test.go | 28 ++++++++++ lib/parquetutil/parse_values_test.go | 12 +---- lib/test/main.go | 76 ++++++++++++++++++++++++++++ lib/typing/decimal/decimal.go | 34 +------------ lib/typing/decimal/decimal_test.go | 29 ----------- lib/typing/values/string_test.go | 12 ++--- 8 files changed, 155 insertions(+), 85 deletions(-) create mode 100644 lib/numbers/decimal.go create mode 100644 lib/numbers/decimal_test.go create mode 100644 lib/test/main.go diff --git a/lib/debezium/decimal.go b/lib/debezium/decimal.go index ec76333a0..c1976a929 100644 --- a/lib/debezium/decimal.go +++ b/lib/debezium/decimal.go @@ -4,7 +4,7 @@ import ( "math/big" "slices" - "github.com/artie-labs/transfer/lib/typing/decimal" + "github.com/artie-labs/transfer/lib/numbers" "github.com/cockroachdb/apd/v3" ) @@ -75,12 +75,12 @@ func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) { // EncodeDecimalWithScale is used to encode a [apd.Decimal] to [org.apache.kafka.connect.data.Decimal]. // using a specific scale. -func EncodeDecimalWithScale(_decimal *apd.Decimal, scale int32) []byte { +func EncodeDecimalWithScale(decimal *apd.Decimal, scale int32) []byte { targetExponent := -scale // Negate scale since [Decimal.Exponent] is negative. - if _decimal.Exponent != targetExponent { - _decimal = decimal.DecimalWithNewExponent(_decimal, targetExponent) + if decimal.Exponent != targetExponent { + decimal = numbers.DecimalWithNewExponent(decimal, targetExponent) } - bytes, _ := EncodeDecimal(_decimal) + bytes, _ := EncodeDecimal(decimal) return bytes } diff --git a/lib/numbers/decimal.go b/lib/numbers/decimal.go new file mode 100644 index 000000000..e042fe447 --- /dev/null +++ b/lib/numbers/decimal.go @@ -0,0 +1,39 @@ +package numbers + +import "github.com/cockroachdb/apd/v3" + +// MustParseDecimal parses a string to an [apd.Decimal] or panics -- used for tests. +func MustParseDecimal(value string) *apd.Decimal { + decimal, _, err := apd.NewFromString(value) + if err != nil { + panic(err) + } + return decimal +} + +// DecimalWithNewExponent takes a [apd.Decimal] and returns a new [apd.Decimal] with a the given exponent. +// If the new exponent is less precise then the extra digits will be truncated. +func DecimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decimal { + exponentDelta := newExponent - decimal.Exponent // Exponent is negative. + + if exponentDelta == 0 { + return new(apd.Decimal).Set(decimal) + } + + coefficient := new(apd.BigInt).Set(&decimal.Coeff) + + if exponentDelta < 0 { + multiplier := new(apd.BigInt).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(-exponentDelta)), nil) + coefficient.Mul(coefficient, multiplier) + } else if exponentDelta > 0 { + divisor := new(apd.BigInt).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(exponentDelta)), nil) + coefficient.Div(coefficient, divisor) + } + + return &apd.Decimal{ + Form: decimal.Form, + Negative: decimal.Negative, + Exponent: newExponent, + Coeff: *coefficient, + } +} diff --git a/lib/numbers/decimal_test.go b/lib/numbers/decimal_test.go new file mode 100644 index 000000000..a7ec21868 --- /dev/null +++ b/lib/numbers/decimal_test.go @@ -0,0 +1,28 @@ +package numbers + +import ( + "testing" + + "github.com/cockroachdb/apd/v3" + "github.com/stretchr/testify/assert" +) + +func TestDecimalWithNewExponent(t *testing.T) { + assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 0), 0).Text('f')) + assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 1), 1).Text('f')) + assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 100), 0).Text('f')) + assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 0), 1).Text('f')) + assert.Equal(t, "0.0", DecimalWithNewExponent(apd.New(0, 0), -1).Text('f')) + + // Same exponent: + assert.Equal(t, "12.349", DecimalWithNewExponent(MustParseDecimal("12.349"), -3).Text('f')) + // More precise exponent: + assert.Equal(t, "12.3490", DecimalWithNewExponent(MustParseDecimal("12.349"), -4).Text('f')) + assert.Equal(t, "12.34900", DecimalWithNewExponent(MustParseDecimal("12.349"), -5).Text('f')) + // Lest precise exponent: + // Extra digits should be truncated rather than rounded. + assert.Equal(t, "12.34", DecimalWithNewExponent(MustParseDecimal("12.349"), -2).Text('f')) + assert.Equal(t, "12.3", DecimalWithNewExponent(MustParseDecimal("12.349"), -1).Text('f')) + assert.Equal(t, "12", DecimalWithNewExponent(MustParseDecimal("12.349"), 0).Text('f')) + assert.Equal(t, "10", DecimalWithNewExponent(MustParseDecimal("12.349"), 1).Text('f')) +} diff --git a/lib/parquetutil/parse_values_test.go b/lib/parquetutil/parse_values_test.go index 9de756b55..457c99834 100644 --- a/lib/parquetutil/parse_values_test.go +++ b/lib/parquetutil/parse_values_test.go @@ -3,9 +3,9 @@ package parquetutil import ( "testing" + "github.com/artie-labs/transfer/lib/numbers" "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing/ext" - "github.com/cockroachdb/apd/v3" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" @@ -13,14 +13,6 @@ import ( "github.com/stretchr/testify/assert" ) -func mustParseDecimal(value string) *apd.Decimal { - decimal, _, err := apd.NewFromString(value) - if err != nil { - panic(err) - } - return decimal -} - func TestParseValue(t *testing.T) { eDecimal := typing.EDecimal eDecimal.ExtendedDecimalDetails = decimal.NewDecimal(ptr.ToInt(30), 5, nil) @@ -74,7 +66,7 @@ func TestParseValue(t *testing.T) { }, { name: "decimal", - colVal: decimal.NewDecimal(ptr.ToInt(30), 5, mustParseDecimal("5000.2232")), + colVal: decimal.NewDecimal(ptr.ToInt(30), 5, numbers.MustParseDecimal("5000.2232")), colKind: columns.NewColumn("", eDecimal), expectedValue: "5000.22320", }, diff --git a/lib/test/main.go b/lib/test/main.go new file mode 100644 index 000000000..4e060b68b --- /dev/null +++ b/lib/test/main.go @@ -0,0 +1,76 @@ +package main + +import ( + "fmt" + "math/rand" + "strings" + + "github.com/artie-labs/transfer/lib/debezium" + "github.com/cockroachdb/apd/v3" +) + +func mustEncodeAndDecodeDecimal(decimal *apd.Decimal, scale int32) string { + bytes := debezium.EncodeDecimalWithScale(decimal, scale) + return debezium.DecodeDecimal(bytes, scale).Text('f') +} + +func randDigit() (byte, bool) { + offset := rand.Intn(10) + return byte(48 + offset), offset == 0 +} + +func generateNumberWithScale(maxDigitsBefore int, maxDigitsAfter int) (*apd.Decimal, int32) { + out := strings.Builder{} + + var wroteNonZero bool + for range rand.Intn(maxDigitsBefore + 1) { + digit, isZero := randDigit() + if isZero && !wroteNonZero { + continue + } + wroteNonZero = true + out.WriteByte(digit) + } + + if !wroteNonZero { + out.WriteRune('0') + } + + scale := rand.Intn(maxDigitsAfter + 1) + if scale > 0 { + out.WriteRune('.') + + for range scale { + digit, isZero := randDigit() + if !isZero { + wroteNonZero = true + } + out.WriteByte(digit) + } + } + + stringValue := out.String() + + if wroteNonZero && rand.Intn(2) == 1 { + stringValue = "-" + stringValue + } + + decimal, _, err := apd.NewFromString(stringValue) + if err != nil { + panic(err) + } + return decimal, -decimal.Exponent +} + +func main() { + for i := range 1000 { + fmt.Printf("Checking batch %d...\n", i) + for range 1_000_000 { + in, scale := generateNumberWithScale(30, 30) + out := mustEncodeAndDecodeDecimal(in, scale) + if in.Text('f') != out { + panic(fmt.Sprintf("Failed for %s -> %s", in.Text('f'), out)) + } + } + } +} diff --git a/lib/typing/decimal/decimal.go b/lib/typing/decimal/decimal.go index 83111cf15..68b903a90 100644 --- a/lib/typing/decimal/decimal.go +++ b/lib/typing/decimal/decimal.go @@ -3,6 +3,7 @@ package decimal import ( "fmt" + "github.com/artie-labs/transfer/lib/numbers" "github.com/artie-labs/transfer/lib/ptr" "github.com/cockroachdb/apd/v3" ) @@ -55,15 +56,11 @@ func (d *Decimal) String() string { targetExponent := -int32(d.scale) value := d.value if value.Exponent != targetExponent { - value = DecimalWithNewExponent(value, targetExponent) + value = numbers.DecimalWithNewExponent(value, targetExponent) } return value.Text('f') } -func (d *Decimal) Value() *apd.Decimal { - return d.value -} - // SnowflakeKind - is used to determine whether a NUMERIC data type should be a STRING or NUMERIC(p, s). func (d *Decimal) SnowflakeKind() string { return d.toKind(MaxPrecisionBeforeString, "STRING") @@ -90,30 +87,3 @@ func (d *Decimal) BigQueryKind() string { return "STRING" } - -// DecimalWithNewExponent takes a [apd.Decimal] and returns a new [apd.Decimal] with a the given exponent. -// If the new exponent is less precise then the extra digits will be truncated. -func DecimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decimal { - exponentDelta := newExponent - decimal.Exponent // Exponent is negative. - - if exponentDelta == 0 { - return new(apd.Decimal).Set(decimal) - } - - coefficient := new(apd.BigInt).Set(&decimal.Coeff) - - if exponentDelta < 0 { - multiplier := new(apd.BigInt).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(-exponentDelta)), nil) - coefficient.Mul(coefficient, multiplier) - } else if exponentDelta > 0 { - divisor := new(apd.BigInt).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(exponentDelta)), nil) - coefficient.Div(coefficient, divisor) - } - - return &apd.Decimal{ - Form: decimal.Form, - Negative: decimal.Negative, - Exponent: newExponent, - Coeff: *coefficient, - } -} diff --git a/lib/typing/decimal/decimal_test.go b/lib/typing/decimal/decimal_test.go index e0841b2b8..a11b90acb 100644 --- a/lib/typing/decimal/decimal_test.go +++ b/lib/typing/decimal/decimal_test.go @@ -3,7 +3,6 @@ package decimal import ( "testing" - "github.com/cockroachdb/apd/v3" "github.com/stretchr/testify/assert" "github.com/artie-labs/transfer/lib/ptr" @@ -77,31 +76,3 @@ func TestDecimalKind(t *testing.T) { assert.Equal(t, testCase.ExpectedBigQueryKind, d.BigQueryKind(), testCase.Name) } } - -func mustParseDecimal(value string) *apd.Decimal { - decimal, _, err := apd.NewFromString(value) - if err != nil { - panic(err) - } - return decimal -} - -func TestDecimalWithNewExponent(t *testing.T) { - assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 0), 0).Text('f')) - assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 1), 1).Text('f')) - assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 100), 0).Text('f')) - assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 0), 1).Text('f')) - assert.Equal(t, "0.0", DecimalWithNewExponent(apd.New(0, 0), -1).Text('f')) - - // Same exponent: - assert.Equal(t, "12.349", DecimalWithNewExponent(mustParseDecimal("12.349"), -3).Text('f')) - // More precise exponent: - assert.Equal(t, "12.3490", DecimalWithNewExponent(mustParseDecimal("12.349"), -4).Text('f')) - assert.Equal(t, "12.34900", DecimalWithNewExponent(mustParseDecimal("12.349"), -5).Text('f')) - // Lest precise exponent: - // Extra digits should be truncated rather than rounded. - assert.Equal(t, "12.34", DecimalWithNewExponent(mustParseDecimal("12.349"), -2).Text('f')) - assert.Equal(t, "12.3", DecimalWithNewExponent(mustParseDecimal("12.349"), -1).Text('f')) - assert.Equal(t, "12", DecimalWithNewExponent(mustParseDecimal("12.349"), 0).Text('f')) - assert.Equal(t, "10", DecimalWithNewExponent(mustParseDecimal("12.349"), 1).Text('f')) -} diff --git a/lib/typing/values/string_test.go b/lib/typing/values/string_test.go index 3e7e49564..62a7ac9cc 100644 --- a/lib/typing/values/string_test.go +++ b/lib/typing/values/string_test.go @@ -20,14 +20,6 @@ func TestBooleanToBit(t *testing.T) { assert.Equal(t, 0, BooleanToBit(false)) } -func mustParseDecimal(value string) *apd.Decimal { - decimal, _, err := apd.NewFromString(value) - if err != nil { - panic(err) - } - return decimal -} - func TestToString(t *testing.T) { { // Nil value @@ -131,7 +123,9 @@ func TestToString(t *testing.T) { assert.Equal(t, "123.45", val) // Decimals - value := decimal.NewDecimal(ptr.ToInt(38), 2, mustParseDecimal("585692791691858.25")) + _decimal, _, err := apd.NewFromString("585692791691858.25") + assert.NoError(t, err) + value := decimal.NewDecimal(ptr.ToInt(38), 2, _decimal) val, err = ToString(value, columns.Column{KindDetails: typing.EDecimal}, nil) assert.NoError(t, err) assert.Equal(t, "585692791691858.25", val)