diff --git a/clients/bigquery/storagewrite_test.go b/clients/bigquery/storagewrite_test.go index 23f66115f..9a5e688d2 100644 --- a/clients/bigquery/storagewrite_test.go +++ b/clients/bigquery/storagewrite_test.go @@ -129,9 +129,9 @@ func TestRowToMessage(t *testing.T) { "c_float_int32": int32(1234), "c_float_int64": int64(1234), "c_float_string": "4444.55555", - "c_numeric": decimal.NewDecimal(nil, numbers.MustParseDecimal("3.14159")), + "c_numeric": decimal.NewDecimal(numbers.MustParseDecimal("3.14159")), "c_string": "foo bar", - "c_string_decimal": decimal.NewDecimal(nil, numbers.MustParseDecimal("1.61803")), + "c_string_decimal": decimal.NewDecimal(numbers.MustParseDecimal("1.61803")), "c_time": ext.NewExtendedTime(time.Date(0, 0, 0, 4, 5, 6, 7, time.UTC), ext.TimeKindType, ""), "c_date": ext.NewExtendedTime(time.Date(2001, 2, 3, 0, 0, 0, 0, time.UTC), ext.DateKindType, ""), "c_datetime": ext.NewExtendedTime(time.Date(2001, 2, 3, 4, 5, 6, 7, time.UTC), ext.DateTimeKindType, ""), diff --git a/lib/debezium/types.go b/lib/debezium/types.go index f6f20753c..dadc4cbcf 100644 --- a/lib/debezium/types.go +++ b/lib/debezium/types.go @@ -8,7 +8,6 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/jsonutil" - "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing/decimal" "github.com/artie-labs/transfer/lib/maputil" @@ -237,7 +236,11 @@ func (f Field) DecodeDecimal(encoded []byte) (*decimal.Decimal, error) { return nil, fmt.Errorf("failed to get scale and/or precision: %w", err) } _decimal := DecodeDecimal(encoded, scale) - return decimal.NewDecimal(precision, _decimal), nil + if precision == nil { + return decimal.NewDecimal(_decimal), nil + } else { + return decimal.NewDecimalWithPrecision(_decimal, *precision), nil + } } func (f Field) DecodeDebeziumVariableDecimal(value any) (*decimal.Decimal, error) { @@ -260,6 +263,5 @@ func (f Field) DecodeDebeziumVariableDecimal(value any) (*decimal.Decimal, error if err != nil { return nil, err } - _decimal := DecodeDecimal(bytes, scale) - return decimal.NewDecimal(ptr.ToInt32(decimal.PrecisionNotSpecified), _decimal), nil + return decimal.NewDecimal(DecodeDecimal(bytes, scale)), nil } diff --git a/lib/debezium/types_test.go b/lib/debezium/types_test.go index 8b02c1f73..8dfdcbcfa 100644 --- a/lib/debezium/types_test.go +++ b/lib/debezium/types_test.go @@ -586,9 +586,9 @@ func TestField_DecodeDecimal(t *testing.T) { params: map[string]any{ "scale": "2", }, - expectedValue: "123456.98", - expectNilPtrPrecision: true, - expectedScale: 2, + expectedValue: "123456.98", + expectedPrecision: -1, + expectedScale: 2, }, } @@ -609,11 +609,7 @@ func TestField_DecodeDecimal(t *testing.T) { assert.NoError(t, err) assert.Equal(t, testCase.expectedValue, dec.String(), testCase.name) - if testCase.expectNilPtrPrecision { - assert.Nil(t, dec.Precision(), testCase.name) - } else { - assert.Equal(t, testCase.expectedPrecision, *dec.Precision(), testCase.name) - } + assert.Equal(t, testCase.expectedPrecision, dec.Precision(), testCase.name) assert.Equal(t, testCase.expectedScale, dec.Scale(), testCase.name) } } @@ -699,7 +695,7 @@ func TestField_DecodeDebeziumVariableDecimal(t *testing.T) { continue } - assert.Equal(t, int32(-1), *dec.Precision(), testCase.name) + assert.Equal(t, int32(-1), dec.Precision(), testCase.name) assert.Equal(t, testCase.expectedScale, dec.Scale(), testCase.name) assert.Equal(t, testCase.expectedValue, dec.String(), testCase.name) } diff --git a/lib/parquetutil/parse_values_test.go b/lib/parquetutil/parse_values_test.go index 403af9127..ac3cb7183 100644 --- a/lib/parquetutil/parse_values_test.go +++ b/lib/parquetutil/parse_values_test.go @@ -4,7 +4,6 @@ import ( "testing" "github.com/artie-labs/transfer/lib/numbers" - "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing/ext" "github.com/artie-labs/transfer/lib/typing" @@ -66,7 +65,7 @@ func TestParseValue(t *testing.T) { }, { name: "decimal", - colVal: decimal.NewDecimal(ptr.ToInt32(30), numbers.MustParseDecimal("5000.22320")), + colVal: decimal.NewDecimalWithPrecision(numbers.MustParseDecimal("5000.22320"), 30), colKind: columns.NewColumn("", eDecimal), expectedValue: "5000.22320", }, diff --git a/lib/typing/decimal/decimal.go b/lib/typing/decimal/decimal.go index 6fae8c9b6..8e3af1b5b 100644 --- a/lib/typing/decimal/decimal.go +++ b/lib/typing/decimal/decimal.go @@ -1,13 +1,12 @@ package decimal import ( - "github.com/artie-labs/transfer/lib/ptr" "github.com/cockroachdb/apd/v3" ) // Decimal is Artie's wrapper around [*apd.Decimal] which can store large numbers w/ no precision loss. type Decimal struct { - precision *int32 + precision int32 value *apd.Decimal } @@ -19,16 +18,14 @@ const ( MaxPrecisionBeforeString int32 = 38 ) -func NewDecimal(precision *int32, value *apd.Decimal) *Decimal { - if precision != nil { - scale := -value.Exponent - if scale > *precision && *precision != PrecisionNotSpecified { - // Note: -1 precision means it's not specified. +func NewDecimalWithPrecision(value *apd.Decimal, precision int32) *Decimal { + scale := -value.Exponent + if scale > precision && precision != PrecisionNotSpecified { + // Note: -1 precision means it's not specified. - // This is typically not possible, but Postgres has a design flaw that allows you to do things like: NUMERIC(5, 6) which actually equates to NUMERIC(7, 6) - // We are setting precision to be scale + 1 to account for the leading zero for decimal numbers. - precision = ptr.ToInt32(scale + 1) - } + // This is typically not possible, but Postgres has a design flaw that allows you to do things like: NUMERIC(5, 6) which actually equates to NUMERIC(7, 6) + // We are setting precision to be scale + 1 to account for the leading zero for decimal numbers. + precision = scale + 1 } return &Decimal{ @@ -37,14 +34,22 @@ func NewDecimal(precision *int32, value *apd.Decimal) *Decimal { } } +func NewDecimal(value *apd.Decimal) *Decimal { + return NewDecimalWithPrecision(value, PrecisionNotSpecified) +} + func (d *Decimal) Scale() int32 { return -d.value.Exponent } -func (d *Decimal) Precision() *int32 { +func (d *Decimal) Precision() int32 { return d.precision } +func (d *Decimal) Value() *apd.Decimal { + return d.value +} + // String() is used to override fmt.Sprint(val), where val type is *decimal.Decimal // This is particularly useful for Snowflake because we're writing all the values as STRINGS into TSV format. // This function guarantees backwards compatibility. @@ -53,9 +58,5 @@ func (d *Decimal) String() string { } func (d *Decimal) Details() DecimalDetails { - precision := PrecisionNotSpecified - if d.precision != nil { - precision = *d.precision - } - return DecimalDetails{scale: d.Scale(), precision: precision} + return DecimalDetails{scale: d.Scale(), precision: d.precision} } diff --git a/lib/typing/decimal/decimal_test.go b/lib/typing/decimal/decimal_test.go index a2172cbc5..4aba71760 100644 --- a/lib/typing/decimal/decimal_test.go +++ b/lib/typing/decimal/decimal_test.go @@ -4,52 +4,48 @@ import ( "testing" "github.com/artie-labs/transfer/lib/numbers" - "github.com/artie-labs/transfer/lib/ptr" "github.com/stretchr/testify/assert" ) func TestNewDecimal(t *testing.T) { - // Nil precision: - assert.Equal(t, "0", NewDecimal(nil, numbers.MustParseDecimal("0")).String()) + assert.Equal(t, "0", NewDecimal(numbers.MustParseDecimal("0")).String()) + assert.Equal(t, "1", NewDecimal(numbers.MustParseDecimal("1")).String()) + assert.Equal(t, "12.34", NewDecimal(numbers.MustParseDecimal("12.34")).String()) +} + +func TestNewDecimalWithPrecision(t *testing.T) { // Precision = -1 (PrecisionNotSpecified): - assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12.34")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimalWithPrecision(numbers.MustParseDecimal("12.34"), PrecisionNotSpecified).Details()) // Precision = scale: - assert.Equal(t, DecimalDetails{scale: 2, precision: 2}, NewDecimal(ptr.ToInt32(2), numbers.MustParseDecimal("12.34")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: 2}, NewDecimalWithPrecision(numbers.MustParseDecimal("12.34"), 2).Details()) // Precision < scale: - assert.Equal(t, DecimalDetails{scale: 2, precision: 3}, NewDecimal(ptr.ToInt32(1), numbers.MustParseDecimal("12.34")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: 3}, NewDecimalWithPrecision(numbers.MustParseDecimal("12.34"), 1).Details()) // Precision > scale: - assert.Equal(t, DecimalDetails{scale: 2, precision: 4}, NewDecimal(ptr.ToInt32(4), numbers.MustParseDecimal("12.34")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: 4}, NewDecimalWithPrecision(numbers.MustParseDecimal("12.34"), 4).Details()) } func TestDecimal_Scale(t *testing.T) { - assert.Equal(t, int32(0), NewDecimal(nil, numbers.MustParseDecimal("0")).Scale()) - assert.Equal(t, int32(0), NewDecimal(nil, numbers.MustParseDecimal("12345")).Scale()) - assert.Equal(t, int32(0), NewDecimal(nil, numbers.MustParseDecimal("12300")).Scale()) - assert.Equal(t, int32(1), NewDecimal(nil, numbers.MustParseDecimal("12300.0")).Scale()) - assert.Equal(t, int32(2), NewDecimal(nil, numbers.MustParseDecimal("12300.00")).Scale()) - assert.Equal(t, int32(2), NewDecimal(nil, numbers.MustParseDecimal("12345.12")).Scale()) - assert.Equal(t, int32(3), NewDecimal(nil, numbers.MustParseDecimal("-12345.123")).Scale()) + assert.Equal(t, int32(0), NewDecimal(numbers.MustParseDecimal("0")).Scale()) + assert.Equal(t, int32(0), NewDecimal(numbers.MustParseDecimal("12345")).Scale()) + assert.Equal(t, int32(0), NewDecimal(numbers.MustParseDecimal("12300")).Scale()) + assert.Equal(t, int32(1), NewDecimal(numbers.MustParseDecimal("12300.0")).Scale()) + assert.Equal(t, int32(2), NewDecimal(numbers.MustParseDecimal("12300.00")).Scale()) + assert.Equal(t, int32(2), NewDecimal(numbers.MustParseDecimal("12345.12")).Scale()) + assert.Equal(t, int32(3), NewDecimal(numbers.MustParseDecimal("-12345.123")).Scale()) } func TestDecimal_Details(t *testing.T) { - // Nil precision: - assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("0")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("12345")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("-12")).Details()) - assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("12345.12")).Details()) - assert.Equal(t, DecimalDetails{scale: 3, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("-12345.123")).Details()) - // -1 precision (PrecisionNotSpecified): - assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("0")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12345")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("-12")).Details()) - assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12345.12")).Details()) - assert.Equal(t, DecimalDetails{scale: 3, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("-12345.123")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(numbers.MustParseDecimal("0")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(numbers.MustParseDecimal("12345")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(numbers.MustParseDecimal("-12")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(numbers.MustParseDecimal("12345.12")).Details()) + assert.Equal(t, DecimalDetails{scale: 3, precision: -1}, NewDecimal(numbers.MustParseDecimal("-12345.123")).Details()) // 10 precision: - assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("0")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("12345")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("-12")).Details()) - assert.Equal(t, DecimalDetails{scale: 2, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("12345.12")).Details()) - assert.Equal(t, DecimalDetails{scale: 3, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("-12345.123")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("0"), 10).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("12345"), 10).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("-12"), 10).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("12345.12"), 10).Details()) + assert.Equal(t, DecimalDetails{scale: 3, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("-12345.123"), 10).Details()) } diff --git a/lib/typing/values/string_test.go b/lib/typing/values/string_test.go index 7ec8d0cce..85fb782d5 100644 --- a/lib/typing/values/string_test.go +++ b/lib/typing/values/string_test.go @@ -8,7 +8,6 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/numbers" - "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/decimal" @@ -123,7 +122,7 @@ func TestToString(t *testing.T) { assert.Equal(t, "123.45", val) // Decimals - value := decimal.NewDecimal(ptr.ToInt32(38), numbers.MustParseDecimal("585692791691858.25")) + value := decimal.NewDecimalWithPrecision(numbers.MustParseDecimal("585692791691858.25"), 38) val, err = ToString(value, columns.Column{KindDetails: typing.EDecimal}, nil) assert.NoError(t, err) assert.Equal(t, "585692791691858.25", val)