Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] Remove pointer from Decimal precision #772

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clients/bigquery/storagewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ func TestRowToMessage(t *testing.T) {
"c_float_int32": int32(1234),
"c_float_int64": int64(1234),
"c_float_string": "4444.55555",
"c_numeric": decimal.NewDecimal(nil, numbers.MustParseDecimal("3.14159")),
"c_numeric": decimal.NewDecimal(numbers.MustParseDecimal("3.14159")),
"c_string": "foo bar",
"c_string_decimal": decimal.NewDecimal(nil, numbers.MustParseDecimal("1.61803")),
"c_string_decimal": decimal.NewDecimal(numbers.MustParseDecimal("1.61803")),
"c_time": ext.NewExtendedTime(time.Date(0, 0, 0, 4, 5, 6, 7, time.UTC), ext.TimeKindType, ""),
"c_date": ext.NewExtendedTime(time.Date(2001, 2, 3, 0, 0, 0, 0, time.UTC), ext.DateKindType, ""),
"c_datetime": ext.NewExtendedTime(time.Date(2001, 2, 3, 4, 5, 6, 7, time.UTC), ext.DateTimeKindType, ""),
Expand Down
10 changes: 6 additions & 4 deletions lib/debezium/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/jsonutil"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/typing/decimal"

"github.com/artie-labs/transfer/lib/maputil"
Expand Down Expand Up @@ -237,7 +236,11 @@ func (f Field) DecodeDecimal(encoded []byte) (*decimal.Decimal, error) {
return nil, fmt.Errorf("failed to get scale and/or precision: %w", err)
}
_decimal := DecodeDecimal(encoded, scale)
return decimal.NewDecimal(precision, _decimal), nil
if precision == nil {
return decimal.NewDecimal(_decimal), nil
} else {
return decimal.NewDecimalWithPrecision(_decimal, *precision), nil
}
}

func (f Field) DecodeDebeziumVariableDecimal(value any) (*decimal.Decimal, error) {
Expand All @@ -260,6 +263,5 @@ func (f Field) DecodeDebeziumVariableDecimal(value any) (*decimal.Decimal, error
if err != nil {
return nil, err
}
_decimal := DecodeDecimal(bytes, scale)
return decimal.NewDecimal(ptr.ToInt32(decimal.PrecisionNotSpecified), _decimal), nil
return decimal.NewDecimal(DecodeDecimal(bytes, scale)), nil
}
14 changes: 5 additions & 9 deletions lib/debezium/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,9 @@ func TestField_DecodeDecimal(t *testing.T) {
params: map[string]any{
"scale": "2",
},
expectedValue: "123456.98",
expectNilPtrPrecision: true,
expectedScale: 2,
expectedValue: "123456.98",
expectedPrecision: -1,
expectedScale: 2,
},
}

Expand All @@ -609,11 +609,7 @@ func TestField_DecodeDecimal(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, testCase.expectedValue, dec.String(), testCase.name)

if testCase.expectNilPtrPrecision {
assert.Nil(t, dec.Precision(), testCase.name)
} else {
assert.Equal(t, testCase.expectedPrecision, *dec.Precision(), testCase.name)
}
assert.Equal(t, testCase.expectedPrecision, dec.Precision(), testCase.name)
assert.Equal(t, testCase.expectedScale, dec.Scale(), testCase.name)
}
}
Expand Down Expand Up @@ -699,7 +695,7 @@ func TestField_DecodeDebeziumVariableDecimal(t *testing.T) {
continue
}

assert.Equal(t, int32(-1), *dec.Precision(), testCase.name)
assert.Equal(t, int32(-1), dec.Precision(), testCase.name)
assert.Equal(t, testCase.expectedScale, dec.Scale(), testCase.name)
assert.Equal(t, testCase.expectedValue, dec.String(), testCase.name)
}
Expand Down
3 changes: 1 addition & 2 deletions lib/parquetutil/parse_values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"testing"

"github.com/artie-labs/transfer/lib/numbers"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/typing/ext"

"github.com/artie-labs/transfer/lib/typing"
Expand Down Expand Up @@ -66,7 +65,7 @@ func TestParseValue(t *testing.T) {
},
{
name: "decimal",
colVal: decimal.NewDecimal(ptr.ToInt32(30), numbers.MustParseDecimal("5000.22320")),
colVal: decimal.NewDecimalWithPrecision(numbers.MustParseDecimal("5000.22320"), 30),
colKind: columns.NewColumn("", eDecimal),
expectedValue: "5000.22320",
},
Expand Down
35 changes: 18 additions & 17 deletions lib/typing/decimal/decimal.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package decimal

import (
"github.com/artie-labs/transfer/lib/ptr"
"github.com/cockroachdb/apd/v3"
)

// Decimal is Artie's wrapper around [*apd.Decimal] which can store large numbers w/ no precision loss.
type Decimal struct {
precision *int32
precision int32
value *apd.Decimal
}

Expand All @@ -19,16 +18,14 @@ const (
MaxPrecisionBeforeString int32 = 38
)

func NewDecimal(precision *int32, value *apd.Decimal) *Decimal {
if precision != nil {
scale := -value.Exponent
if scale > *precision && *precision != PrecisionNotSpecified {
// Note: -1 precision means it's not specified.
func NewDecimalWithPrecision(value *apd.Decimal, precision int32) *Decimal {
scale := -value.Exponent
if scale > precision && precision != PrecisionNotSpecified {
// Note: -1 precision means it's not specified.

// This is typically not possible, but Postgres has a design flaw that allows you to do things like: NUMERIC(5, 6) which actually equates to NUMERIC(7, 6)
// We are setting precision to be scale + 1 to account for the leading zero for decimal numbers.
precision = ptr.ToInt32(scale + 1)
}
// This is typically not possible, but Postgres has a design flaw that allows you to do things like: NUMERIC(5, 6) which actually equates to NUMERIC(7, 6)
// We are setting precision to be scale + 1 to account for the leading zero for decimal numbers.
precision = scale + 1
}

return &Decimal{
Expand All @@ -37,14 +34,22 @@ func NewDecimal(precision *int32, value *apd.Decimal) *Decimal {
}
}

func NewDecimal(value *apd.Decimal) *Decimal {
return NewDecimalWithPrecision(value, PrecisionNotSpecified)
}

func (d *Decimal) Scale() int32 {
return -d.value.Exponent
}

func (d *Decimal) Precision() *int32 {
func (d *Decimal) Precision() int32 {
return d.precision
}

func (d *Decimal) Value() *apd.Decimal {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this in case we want to inspect the value.

return d.value
}

// String() is used to override fmt.Sprint(val), where val type is *decimal.Decimal
// This is particularly useful for Snowflake because we're writing all the values as STRINGS into TSV format.
// This function guarantees backwards compatibility.
Expand All @@ -53,9 +58,5 @@ func (d *Decimal) String() string {
}

func (d *Decimal) Details() DecimalDetails {
precision := PrecisionNotSpecified
if d.precision != nil {
precision = *d.precision
}
return DecimalDetails{scale: d.Scale(), precision: precision}
return DecimalDetails{scale: d.Scale(), precision: d.precision}
}
56 changes: 25 additions & 31 deletions lib/typing/decimal/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,52 +4,46 @@ import (
"testing"

"github.com/artie-labs/transfer/lib/numbers"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/stretchr/testify/assert"
)

func TestNewDecimal(t *testing.T) {
// Nil precision:
assert.Equal(t, "0", NewDecimal(nil, numbers.MustParseDecimal("0")).String())
assert.Equal(t, "0", NewDecimal(numbers.MustParseDecimal("0")).String())
}

func TestNewDecimalWithPrecision(t *testing.T) {
// Precision = -1 (PrecisionNotSpecified):
assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12.34")).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimalWithPrecision(numbers.MustParseDecimal("12.34"), PrecisionNotSpecified).Details())
// Precision = scale:
assert.Equal(t, DecimalDetails{scale: 2, precision: 2}, NewDecimal(ptr.ToInt32(2), numbers.MustParseDecimal("12.34")).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: 2}, NewDecimalWithPrecision(numbers.MustParseDecimal("12.34"), 2).Details())
// Precision < scale:
assert.Equal(t, DecimalDetails{scale: 2, precision: 3}, NewDecimal(ptr.ToInt32(1), numbers.MustParseDecimal("12.34")).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: 3}, NewDecimalWithPrecision(numbers.MustParseDecimal("12.34"), 1).Details())
// Precision > scale:
assert.Equal(t, DecimalDetails{scale: 2, precision: 4}, NewDecimal(ptr.ToInt32(4), numbers.MustParseDecimal("12.34")).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: 4}, NewDecimalWithPrecision(numbers.MustParseDecimal("12.34"), 4).Details())
}

func TestDecimal_Scale(t *testing.T) {
assert.Equal(t, int32(0), NewDecimal(nil, numbers.MustParseDecimal("0")).Scale())
assert.Equal(t, int32(0), NewDecimal(nil, numbers.MustParseDecimal("12345")).Scale())
assert.Equal(t, int32(0), NewDecimal(nil, numbers.MustParseDecimal("12300")).Scale())
assert.Equal(t, int32(1), NewDecimal(nil, numbers.MustParseDecimal("12300.0")).Scale())
assert.Equal(t, int32(2), NewDecimal(nil, numbers.MustParseDecimal("12300.00")).Scale())
assert.Equal(t, int32(2), NewDecimal(nil, numbers.MustParseDecimal("12345.12")).Scale())
assert.Equal(t, int32(3), NewDecimal(nil, numbers.MustParseDecimal("-12345.123")).Scale())
assert.Equal(t, int32(0), NewDecimal(numbers.MustParseDecimal("0")).Scale())
assert.Equal(t, int32(0), NewDecimal(numbers.MustParseDecimal("12345")).Scale())
assert.Equal(t, int32(0), NewDecimal(numbers.MustParseDecimal("12300")).Scale())
assert.Equal(t, int32(1), NewDecimal(numbers.MustParseDecimal("12300.0")).Scale())
assert.Equal(t, int32(2), NewDecimal(numbers.MustParseDecimal("12300.00")).Scale())
assert.Equal(t, int32(2), NewDecimal(numbers.MustParseDecimal("12345.12")).Scale())
assert.Equal(t, int32(3), NewDecimal(numbers.MustParseDecimal("-12345.123")).Scale())
}

func TestDecimal_Details(t *testing.T) {
// Nil precision:
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("0")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("12345")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("-12")).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("12345.12")).Details())
assert.Equal(t, DecimalDetails{scale: 3, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("-12345.123")).Details())

// -1 precision (PrecisionNotSpecified):
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("0")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12345")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("-12")).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12345.12")).Details())
assert.Equal(t, DecimalDetails{scale: 3, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("-12345.123")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(numbers.MustParseDecimal("0")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(numbers.MustParseDecimal("12345")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(numbers.MustParseDecimal("-12")).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(numbers.MustParseDecimal("12345.12")).Details())
assert.Equal(t, DecimalDetails{scale: 3, precision: -1}, NewDecimal(numbers.MustParseDecimal("-12345.123")).Details())

// 10 precision:
assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("0")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("12345")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("-12")).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("12345.12")).Details())
assert.Equal(t, DecimalDetails{scale: 3, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("-12345.123")).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("0"), 10).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("12345"), 10).Details())
assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("-12"), 10).Details())
assert.Equal(t, DecimalDetails{scale: 2, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("12345.12"), 10).Details())
assert.Equal(t, DecimalDetails{scale: 3, precision: 10}, NewDecimalWithPrecision(numbers.MustParseDecimal("-12345.123"), 10).Details())
}
3 changes: 1 addition & 2 deletions lib/typing/values/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/numbers"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/artie-labs/transfer/lib/typing/decimal"
Expand Down Expand Up @@ -123,7 +122,7 @@ func TestToString(t *testing.T) {
assert.Equal(t, "123.45", val)

// Decimals
value := decimal.NewDecimal(ptr.ToInt32(38), numbers.MustParseDecimal("585692791691858.25"))
value := decimal.NewDecimalWithPrecision(numbers.MustParseDecimal("585692791691858.25"), 38)
val, err = ToString(value, columns.Column{KindDetails: typing.EDecimal}, nil)
assert.NoError(t, err)
assert.Equal(t, "585692791691858.25", val)
Expand Down