From afb33f59330cd98f13e65ffeeb1bbeaf2a3843f6 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:18:57 -0700 Subject: [PATCH] [typing] Remove pointer for `DecimalDetails` precision --- clients/bigquery/dialect/dialect_test.go | 4 +-- clients/mssql/dialect/dialect_test.go | 2 +- clients/redshift/dialect/dialect_test.go | 2 +- clients/snowflake/dialect/dialect_test.go | 4 +-- lib/cdc/util/optional_schema_test.go | 19 ++++++------ lib/debezium/schema.go | 9 ++++-- lib/debezium/schema_test.go | 4 +-- lib/optimization/event_update_test.go | 10 +++--- lib/parquetutil/parse_values_test.go | 2 +- lib/typing/decimal/base.go | 17 ++++------ lib/typing/decimal/decimal.go | 6 +++- lib/typing/decimal/decimal_test.go | 38 +++++++++++------------ lib/typing/decimal/details.go | 24 ++++++-------- lib/typing/decimal/details_test.go | 4 +-- lib/typing/numeric.go | 2 +- lib/typing/numeric_test.go | 21 +++++-------- lib/typing/parquet.go | 4 +-- 17 files changed, 82 insertions(+), 90 deletions(-) diff --git a/clients/bigquery/dialect/dialect_test.go b/clients/bigquery/dialect/dialect_test.go index 74fdc34b2..5ce08c510 100644 --- a/clients/bigquery/dialect/dialect_test.go +++ b/clients/bigquery/dialect/dialect_test.go @@ -102,14 +102,14 @@ func TestBigQueryDialect_KindForDataType(t *testing.T) { kd, err := dialect.KindForDataType("numeric(5, 2)", "") assert.NoError(t, err) assert.Equal(t, typing.EDecimal.Kind, kd.Kind) - assert.Equal(t, int32(5), *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(5), kd.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), kd.ExtendedDecimalDetails.Scale()) } { kd, err := dialect.KindForDataType("bignumeric(5, 2)", "") assert.NoError(t, err) assert.Equal(t, typing.EDecimal.Kind, kd.Kind) - assert.Equal(t, int32(5), *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(5), kd.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), kd.ExtendedDecimalDetails.Scale()) } } diff --git a/clients/mssql/dialect/dialect_test.go b/clients/mssql/dialect/dialect_test.go index fe9402a2f..7c4996ea6 100644 --- a/clients/mssql/dialect/dialect_test.go +++ b/clients/mssql/dialect/dialect_test.go @@ -85,7 +85,7 @@ func TestMSSQLDialect_KindForDataType(t *testing.T) { kd, err := dialect.KindForDataType("numeric(5, 2)", "") assert.NoError(t, err) assert.Equal(t, typing.EDecimal.Kind, kd.Kind) - assert.Equal(t, int32(5), *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(5), kd.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), kd.ExtendedDecimalDetails.Scale()) } { diff --git a/clients/redshift/dialect/dialect_test.go b/clients/redshift/dialect/dialect_test.go index 734a1af7f..435401f6f 100644 --- a/clients/redshift/dialect/dialect_test.go +++ b/clients/redshift/dialect/dialect_test.go @@ -145,7 +145,7 @@ func TestRedshiftDialect_KindForDataType(t *testing.T) { kd, err := dialect.KindForDataType("numeric(5,2)", "") assert.NoError(t, err) assert.Equal(t, typing.EDecimal.Kind, kd.Kind) - assert.Equal(t, int32(5), *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(5), kd.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), kd.ExtendedDecimalDetails.Scale()) } } diff --git a/clients/snowflake/dialect/dialect_test.go b/clients/snowflake/dialect/dialect_test.go index a1fdcded5..deae75a8a 100644 --- a/clients/snowflake/dialect/dialect_test.go +++ b/clients/snowflake/dialect/dialect_test.go @@ -84,14 +84,14 @@ func TestSnowflakeDialect_KindForDataType_Floats(t *testing.T) { kd, err := SnowflakeDialect{}.KindForDataType("NUMERIC(38, 2)", "") assert.NoError(t, err) assert.Equal(t, typing.EDecimal.Kind, kd.Kind) - assert.Equal(t, int32(38), *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(38), kd.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), kd.ExtendedDecimalDetails.Scale()) } { kd, err := SnowflakeDialect{}.KindForDataType("NUMBER(38, 2)", "") assert.NoError(t, err) assert.Equal(t, typing.EDecimal.Kind, kd.Kind) - assert.Equal(t, int32(38), *kd.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(38), kd.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), kd.ExtendedDecimalDetails.Scale()) } { diff --git a/lib/cdc/util/optional_schema_test.go b/lib/cdc/util/optional_schema_test.go index 3e739a1cf..d4d045b53 100644 --- a/lib/cdc/util/optional_schema_test.go +++ b/lib/cdc/util/optional_schema_test.go @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/assert" - "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/decimal" ) @@ -67,35 +66,35 @@ func TestGetOptionalSchema(t *testing.T) { "bit_test": typing.Boolean, "numeric_test": { Kind: typing.EDecimal.Kind, - ExtendedDecimalDetails: decimal.NewDecimalDetails(ptr.ToInt32(decimal.PrecisionNotSpecified), decimal.DefaultScale), + ExtendedDecimalDetails: decimal.NewDecimalDetails(decimal.PrecisionNotSpecified, decimal.DefaultScale), }, "numeric_5": { Kind: typing.EDecimal.Kind, - ExtendedDecimalDetails: decimal.NewDecimalDetails(ptr.ToInt32(5), 0), + ExtendedDecimalDetails: decimal.NewDecimalDetails(5, 0), }, "numeric_5_2": { Kind: typing.EDecimal.Kind, - ExtendedDecimalDetails: decimal.NewDecimalDetails(ptr.ToInt32(5), 2), + ExtendedDecimalDetails: decimal.NewDecimalDetails(5, 2), }, "numeric_5_6": { Kind: typing.EDecimal.Kind, - ExtendedDecimalDetails: decimal.NewDecimalDetails(ptr.ToInt32(5), 6), + ExtendedDecimalDetails: decimal.NewDecimalDetails(5, 6), }, "numeric_5_0": { Kind: typing.EDecimal.Kind, - ExtendedDecimalDetails: decimal.NewDecimalDetails(ptr.ToInt32(5), 0), + ExtendedDecimalDetails: decimal.NewDecimalDetails(5, 0), }, "numeric_39_0": { Kind: typing.EDecimal.Kind, - ExtendedDecimalDetails: decimal.NewDecimalDetails(ptr.ToInt32(39), 0), + ExtendedDecimalDetails: decimal.NewDecimalDetails(39, 0), }, "numeric_39_2": { Kind: typing.EDecimal.Kind, - ExtendedDecimalDetails: decimal.NewDecimalDetails(ptr.ToInt32(39), 2), + ExtendedDecimalDetails: decimal.NewDecimalDetails(39, 2), }, "numeric_39_6": { Kind: typing.EDecimal.Kind, - ExtendedDecimalDetails: decimal.NewDecimalDetails(ptr.ToInt32(39), 6), + ExtendedDecimalDetails: decimal.NewDecimalDetails(39, 6), }, }, }, @@ -116,7 +115,7 @@ func TestGetOptionalSchema(t *testing.T) { if expectedValue.ExtendedDecimalDetails != nil || actualVal.ExtendedDecimalDetails != nil { assert.NotNil(t, actualVal.ExtendedDecimalDetails, testMsg) assert.Equal(t, expectedValue.ExtendedDecimalDetails.Scale(), actualVal.ExtendedDecimalDetails.Scale(), testMsg) - assert.Equal(t, *expectedValue.ExtendedDecimalDetails.Precision(), *actualVal.ExtendedDecimalDetails.Precision(), testMsg) + assert.Equal(t, expectedValue.ExtendedDecimalDetails.Precision(), actualVal.ExtendedDecimalDetails.Precision(), testMsg) } else { assert.Nil(t, actualVal.ExtendedDecimalDetails, testMsg) } diff --git a/lib/debezium/schema.go b/lib/debezium/schema.go index d33dc90c6..8fa2ddfa1 100644 --- a/lib/debezium/schema.go +++ b/lib/debezium/schema.go @@ -93,11 +93,16 @@ func (f Field) ToKindDetails() typing.KindDetails { case JSON, GeometryPointType, GeometryType, GeographyType: return typing.Struct case KafkaDecimalType: - scale, precision, err := f.GetScaleAndPrecision() + scale, precisionPtr, err := f.GetScaleAndPrecision() if err != nil { return typing.Invalid } + var precision int32 = decimal.PrecisionNotSpecified + if precisionPtr != nil { + precision = *precisionPtr + } + eDecimal := typing.EDecimal eDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(precision, scale) return eDecimal @@ -106,7 +111,7 @@ func (f Field) ToKindDetails() typing.KindDetails { // This is because scale is not specified at the column level, rather at the row level // It shouldn't matter much anyway since the column type we are creating is `TEXT` to avoid boundary errors. eDecimal := typing.EDecimal - eDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(ptr.ToInt32(decimal.PrecisionNotSpecified), decimal.DefaultScale) + eDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(decimal.PrecisionNotSpecified, decimal.DefaultScale) return eDecimal } diff --git a/lib/debezium/schema_test.go b/lib/debezium/schema_test.go index 41ae840e9..5499c5af9 100644 --- a/lib/debezium/schema_test.go +++ b/lib/debezium/schema_test.go @@ -89,10 +89,10 @@ func TestField_ToKindDetails(t *testing.T) { } eDecimal := typing.EDecimal - eDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(ptr.ToInt32(decimal.PrecisionNotSpecified), decimal.DefaultScale) + eDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(decimal.PrecisionNotSpecified, decimal.DefaultScale) kafkaDecimalType := typing.EDecimal - kafkaDecimalType.ExtendedDecimalDetails = decimal.NewDecimalDetails(ptr.ToInt32(10), 5) + kafkaDecimalType.ExtendedDecimalDetails = decimal.NewDecimalDetails(10, 5) tcs := []_tc{ { diff --git a/lib/optimization/event_update_test.go b/lib/optimization/event_update_test.go index b43d60377..50d30a556 100644 --- a/lib/optimization/event_update_test.go +++ b/lib/optimization/event_update_test.go @@ -42,7 +42,7 @@ func TestTableData_UpdateInMemoryColumnsFromDestination(t *testing.T) { tableDataCols.AddColumn(columns.NewColumn("ext_dec", typing.String)) extDecimalType := typing.EDecimal - extDecimalType.ExtendedDecimalDetails = decimal.NewDecimalDetails(ptr.ToInt32(22), 2) + extDecimalType.ExtendedDecimalDetails = decimal.NewDecimalDetails(22, 2) tableDataCols.AddColumn(columns.NewColumn("ext_dec_filled", extDecimalType)) tableDataCols.AddColumn(columns.NewColumn(strCol, typing.String)) @@ -121,14 +121,14 @@ func TestTableData_UpdateInMemoryColumnsFromDestination(t *testing.T) { assert.Equal(t, typing.String, extDecCol.KindDetails) extDecimal := typing.EDecimal - extDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(ptr.ToInt32(30), 2) + extDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(30, 2) assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("ext_dec", extDecimal))) // Now it should be ext decimal type extDecCol, isOk = tableData.inMemoryColumns.GetColumn("ext_dec") assert.True(t, isOk) assert.Equal(t, typing.EDecimal.Kind, extDecCol.KindDetails.Kind) // Check precision and scale too. - assert.Equal(t, int32(30), *extDecCol.KindDetails.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(30), extDecCol.KindDetails.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), extDecCol.KindDetails.ExtendedDecimalDetails.Scale()) // Testing ext_dec_filled since it's already filled out @@ -136,7 +136,7 @@ func TestTableData_UpdateInMemoryColumnsFromDestination(t *testing.T) { assert.True(t, isOk) assert.Equal(t, typing.EDecimal.Kind, extDecColFilled.KindDetails.Kind) // Check precision and scale too. - assert.Equal(t, int32(22), *extDecColFilled.KindDetails.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(22), extDecColFilled.KindDetails.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), extDecColFilled.KindDetails.ExtendedDecimalDetails.Scale()) assert.NoError(t, tableData.MergeColumnsFromDestination(columns.NewColumn("ext_dec_filled", extDecimal))) @@ -144,7 +144,7 @@ func TestTableData_UpdateInMemoryColumnsFromDestination(t *testing.T) { assert.True(t, isOk) assert.Equal(t, typing.EDecimal.Kind, extDecColFilled.KindDetails.Kind) // Check precision and scale too. - assert.Equal(t, int32(22), *extDecColFilled.KindDetails.ExtendedDecimalDetails.Precision()) + assert.Equal(t, int32(22), extDecColFilled.KindDetails.ExtendedDecimalDetails.Precision()) assert.Equal(t, int32(2), extDecColFilled.KindDetails.ExtendedDecimalDetails.Scale()) } { diff --git a/lib/parquetutil/parse_values_test.go b/lib/parquetutil/parse_values_test.go index 06ed51d48..403af9127 100644 --- a/lib/parquetutil/parse_values_test.go +++ b/lib/parquetutil/parse_values_test.go @@ -15,7 +15,7 @@ import ( func TestParseValue(t *testing.T) { eDecimal := typing.EDecimal - eDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(ptr.ToInt32(30), 5) + eDecimal.ExtendedDecimalDetails = decimal.NewDecimalDetails(30, 5) eTime := typing.ETime eTime.ExtendedTimeDetails = &ext.Time diff --git a/lib/typing/decimal/base.go b/lib/typing/decimal/base.go index bc435ff99..c2f7402fe 100644 --- a/lib/typing/decimal/base.go +++ b/lib/typing/decimal/base.go @@ -7,7 +7,7 @@ import ( ) func (d *DecimalDetails) isNumeric() bool { - if d.precision == nil || *d.precision == PrecisionNotSpecified { + if d.precision == PrecisionNotSpecified { return false } @@ -17,11 +17,11 @@ func (d *DecimalDetails) isNumeric() bool { } // max(1,s) <= p <= s + 29 - return numbers.BetweenEq(max(1, d.scale), d.scale+29, *d.precision) + return numbers.BetweenEq(max(1, d.scale), d.scale+29, d.precision) } func (d *DecimalDetails) isBigNumeric() bool { - if d.precision == nil || *d.precision == -1 { + if d.precision == PrecisionNotSpecified { return false } @@ -31,18 +31,13 @@ func (d *DecimalDetails) isBigNumeric() bool { } // max(1,s) <= p <= s + 38 - return numbers.BetweenEq(max(1, d.scale), d.scale+38, *d.precision) + return numbers.BetweenEq(max(1, d.scale), d.scale+38, d.precision) } func (d *DecimalDetails) toKind(maxPrecision int32, exceededKind string) string { - precision := maxPrecision - if d.precision != nil { - precision = *d.precision - } - - if precision > maxPrecision || precision == -1 { + if d.precision > maxPrecision || d.precision == PrecisionNotSpecified { return exceededKind } - return fmt.Sprintf("NUMERIC(%v, %v)", precision, d.scale) + return fmt.Sprintf("NUMERIC(%v, %v)", d.precision, d.scale) } diff --git a/lib/typing/decimal/decimal.go b/lib/typing/decimal/decimal.go index d53dde072..7bddeacf4 100644 --- a/lib/typing/decimal/decimal.go +++ b/lib/typing/decimal/decimal.go @@ -72,5 +72,9 @@ func (d *Decimal) Value() any { } func (d *Decimal) Details() DecimalDetails { - return DecimalDetails{scale: d.Scale(), precision: d.precision} + var precision int32 = PrecisionNotSpecified + if d.precision != nil { + precision = *d.precision + } + return DecimalDetails{scale: d.Scale(), precision: precision} } diff --git a/lib/typing/decimal/decimal_test.go b/lib/typing/decimal/decimal_test.go index c0403fe23..a2172cbc5 100644 --- a/lib/typing/decimal/decimal_test.go +++ b/lib/typing/decimal/decimal_test.go @@ -12,13 +12,13 @@ func TestNewDecimal(t *testing.T) { // Nil precision: assert.Equal(t, "0", NewDecimal(nil, numbers.MustParseDecimal("0")).String()) // Precision = -1 (PrecisionNotSpecified): - assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt32(-1)}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12.34")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12.34")).Details()) // Precision = scale: - assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt32(2)}, NewDecimal(ptr.ToInt32(2), numbers.MustParseDecimal("12.34")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: 2}, NewDecimal(ptr.ToInt32(2), numbers.MustParseDecimal("12.34")).Details()) // Precision < scale: - assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt32(3)}, NewDecimal(ptr.ToInt32(1), numbers.MustParseDecimal("12.34")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: 3}, NewDecimal(ptr.ToInt32(1), numbers.MustParseDecimal("12.34")).Details()) // Precision > scale: - assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt32(4)}, NewDecimal(ptr.ToInt32(4), numbers.MustParseDecimal("12.34")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: 4}, NewDecimal(ptr.ToInt32(4), numbers.MustParseDecimal("12.34")).Details()) } func TestDecimal_Scale(t *testing.T) { @@ -33,23 +33,23 @@ func TestDecimal_Scale(t *testing.T) { func TestDecimal_Details(t *testing.T) { // Nil precision: - assert.Equal(t, DecimalDetails{scale: 0}, NewDecimal(nil, numbers.MustParseDecimal("0")).Details()) - assert.Equal(t, DecimalDetails{scale: 0}, NewDecimal(nil, numbers.MustParseDecimal("12345")).Details()) - assert.Equal(t, DecimalDetails{scale: 0}, NewDecimal(nil, numbers.MustParseDecimal("-12")).Details()) - assert.Equal(t, DecimalDetails{scale: 2}, NewDecimal(nil, numbers.MustParseDecimal("12345.12")).Details()) - assert.Equal(t, DecimalDetails{scale: 3}, NewDecimal(nil, numbers.MustParseDecimal("-12345.123")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("0")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("12345")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("-12")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("12345.12")).Details()) + assert.Equal(t, DecimalDetails{scale: 3, precision: -1}, NewDecimal(nil, numbers.MustParseDecimal("-12345.123")).Details()) // -1 precision (PrecisionNotSpecified): - assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt32(-1)}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("0")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt32(-1)}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12345")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt32(-1)}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("-12")).Details()) - assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt32(-1)}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12345.12")).Details()) - assert.Equal(t, DecimalDetails{scale: 3, precision: ptr.ToInt32(-1)}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("-12345.123")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("0")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12345")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("-12")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("12345.12")).Details()) + assert.Equal(t, DecimalDetails{scale: 3, precision: -1}, NewDecimal(ptr.ToInt32(-1), numbers.MustParseDecimal("-12345.123")).Details()) // 10 precision: - assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt32(10)}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("0")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt32(10)}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("12345")).Details()) - assert.Equal(t, DecimalDetails{scale: 0, precision: ptr.ToInt32(10)}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("-12")).Details()) - assert.Equal(t, DecimalDetails{scale: 2, precision: ptr.ToInt32(10)}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("12345.12")).Details()) - assert.Equal(t, DecimalDetails{scale: 3, precision: ptr.ToInt32(10)}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("-12345.123")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("0")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("12345")).Details()) + assert.Equal(t, DecimalDetails{scale: 0, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("-12")).Details()) + assert.Equal(t, DecimalDetails{scale: 2, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("12345.12")).Details()) + assert.Equal(t, DecimalDetails{scale: 3, precision: 10}, NewDecimal(ptr.ToInt32(10), numbers.MustParseDecimal("-12345.123")).Details()) } diff --git a/lib/typing/decimal/details.go b/lib/typing/decimal/details.go index 85e93688a..74c9b5019 100644 --- a/lib/typing/decimal/details.go +++ b/lib/typing/decimal/details.go @@ -2,24 +2,20 @@ package decimal import ( "fmt" - - "github.com/artie-labs/transfer/lib/ptr" ) type DecimalDetails struct { scale int32 - precision *int32 + precision int32 } -func NewDecimalDetails(precision *int32, scale int32) *DecimalDetails { - if precision != nil { - if scale > *precision && *precision != -1 { - // Note: -1 precision means it's not specified. +func NewDecimalDetails(precision int32, scale int32) *DecimalDetails { + if scale > precision && precision != -1 { + // Note: -1 precision means it's not specified. - // This is typically not possible, but Postgres has a design flaw that allows you to do things like: NUMERIC(5, 6) which actually equates to NUMERIC(7, 6) - // We are setting precision to be scale + 1 to account for the leading zero for decimal numbers. - precision = ptr.ToInt32(scale + 1) - } + // This is typically not possible, but Postgres has a design flaw that allows you to do things like: NUMERIC(5, 6) which actually equates to NUMERIC(7, 6) + // We are setting precision to be scale + 1 to account for the leading zero for decimal numbers. + precision = scale + 1 } return &DecimalDetails{ @@ -32,7 +28,7 @@ func (d DecimalDetails) Scale() int32 { return d.scale } -func (d DecimalDetails) Precision() *int32 { +func (d DecimalDetails) Precision() int32 { return d.precision } @@ -55,9 +51,9 @@ func (d *DecimalDetails) RedshiftKind() string { // BigQueryKind - is inferring logic from: https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#decimal_types func (d *DecimalDetails) BigQueryKind() string { if d.isNumeric() { - return fmt.Sprintf("NUMERIC(%v, %v)", *d.precision, d.scale) + return fmt.Sprintf("NUMERIC(%v, %v)", d.precision, d.scale) } else if d.isBigNumeric() { - return fmt.Sprintf("BIGNUMERIC(%v, %v)", *d.precision, d.scale) + return fmt.Sprintf("BIGNUMERIC(%v, %v)", d.precision, d.scale) } return "STRING" diff --git a/lib/typing/decimal/details_test.go b/lib/typing/decimal/details_test.go index 20b8fb36f..37b9bcc21 100644 --- a/lib/typing/decimal/details_test.go +++ b/lib/typing/decimal/details_test.go @@ -4,8 +4,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - - "github.com/artie-labs/transfer/lib/ptr" ) func TestDecimalDetailsKind(t *testing.T) { @@ -70,7 +68,7 @@ func TestDecimalDetailsKind(t *testing.T) { } for _, testCase := range testCases { - d := NewDecimalDetails(ptr.ToInt32(testCase.Precision), testCase.Scale) + d := NewDecimalDetails(testCase.Precision, testCase.Scale) assert.Equal(t, testCase.ExpectedSnowflakeKind, d.SnowflakeKind(), testCase.Name) assert.Equal(t, testCase.ExpectedRedshiftKind, d.RedshiftKind(), testCase.Name) assert.Equal(t, testCase.ExpectedBigQueryKind, d.BigQueryKind(), testCase.Name) diff --git a/lib/typing/numeric.go b/lib/typing/numeric.go index 9378d0a09..7651c49f3 100644 --- a/lib/typing/numeric.go +++ b/lib/typing/numeric.go @@ -28,6 +28,6 @@ func ParseNumeric(parts []string) KindDetails { } eDec := EDecimal - eDec.ExtendedDecimalDetails = decimal.NewDecimalDetails(&parsedNumbers[0], parsedNumbers[1]) + eDec.ExtendedDecimalDetails = decimal.NewDecimalDetails(parsedNumbers[0], parsedNumbers[1]) return eDec } diff --git a/lib/typing/numeric_test.go b/lib/typing/numeric_test.go index 3dba2f307..f35a5f1c7 100644 --- a/lib/typing/numeric_test.go +++ b/lib/typing/numeric_test.go @@ -5,8 +5,6 @@ import ( "math" "testing" - "github.com/artie-labs/transfer/lib/ptr" - "github.com/stretchr/testify/assert" ) @@ -14,7 +12,7 @@ func TestParseNumeric(t *testing.T) { type _testCase struct { parameters []string expectedKindDetails KindDetails - expectedPrecision *int32 // Using a pointer to int32 so we can differentiate between unset (nil) and set (0 included) + expectedPrecision int32 expectedScale int32 } @@ -42,37 +40,37 @@ func TestParseNumeric(t *testing.T) { { parameters: []string{"5", " 2"}, expectedKindDetails: EDecimal, - expectedPrecision: ptr.ToInt32(5), + expectedPrecision: 5, expectedScale: 2, }, { parameters: []string{"5", "2"}, expectedKindDetails: EDecimal, - expectedPrecision: ptr.ToInt32(5), + expectedPrecision: 5, expectedScale: 2, }, { parameters: []string{"39", "6"}, expectedKindDetails: EDecimal, - expectedPrecision: ptr.ToInt32(39), + expectedPrecision: 39, expectedScale: 6, }, { parameters: []string{"5"}, expectedKindDetails: Integer, - expectedPrecision: ptr.ToInt32(5), + expectedPrecision: 5, expectedScale: 0, }, { parameters: []string{"5", "0"}, expectedKindDetails: Integer, - expectedPrecision: ptr.ToInt32(5), + expectedPrecision: 5, expectedScale: 0, }, { parameters: []string{fmt.Sprint(math.MaxInt32), fmt.Sprint(math.MaxInt32)}, expectedKindDetails: EDecimal, - expectedPrecision: ptr.ToInt32(math.MaxInt32), + expectedPrecision: math.MaxInt32, expectedScale: math.MaxInt32, }, } @@ -82,10 +80,7 @@ func TestParseNumeric(t *testing.T) { assert.Equal(t, testCase.expectedKindDetails.Kind, result.Kind, testCase.parameters) if result.ExtendedDecimalDetails != nil { assert.Equal(t, testCase.expectedScale, result.ExtendedDecimalDetails.Scale(), testCase.parameters) - - if result.ExtendedDecimalDetails.Precision() != nil { - assert.Equal(t, *testCase.expectedPrecision, *result.ExtendedDecimalDetails.Precision(), testCase.parameters) - } + assert.Equal(t, testCase.expectedPrecision, result.ExtendedDecimalDetails.Precision(), testCase.parameters) } } diff --git a/lib/typing/parquet.go b/lib/typing/parquet.go index 4a0378302..f728e8b38 100644 --- a/lib/typing/parquet.go +++ b/lib/typing/parquet.go @@ -113,7 +113,7 @@ func (k *KindDetails) ParquetAnnotation(colName string) (*Field, error) { }, nil case EDecimal.Kind: precision := k.ExtendedDecimalDetails.Precision() - if precision == nil || *precision == -1 { + if precision == -1 { // This is a variable precision decimal, so we'll just treat it as a string. return &Field{ Tag: FieldTag{ @@ -132,7 +132,7 @@ func (k *KindDetails) ParquetAnnotation(colName string) (*Field, error) { InName: &colName, Type: ptr.ToString("BYTE_ARRAY"), ConvertedType: ptr.ToString("DECIMAL"), - Precision: ptr.ToInt(int(*precision)), + Precision: ptr.ToInt(int(precision)), Scale: ptr.ToInt(int(scale)), }.String(), }, nil