Skip to content

Commit

Permalink
Little cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Jun 26, 2024
1 parent c1bd606 commit 55c4ace
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 85 deletions.
10 changes: 5 additions & 5 deletions lib/debezium/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"math/big"
"slices"

"github.com/artie-labs/transfer/lib/typing/decimal"
"github.com/artie-labs/transfer/lib/numbers"
"github.com/cockroachdb/apd/v3"
)

Expand Down Expand Up @@ -75,12 +75,12 @@ func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) {

// EncodeDecimalWithScale is used to encode an [apd.Decimal] to [org.apache.kafka.connect.data.Decimal]
// using a specific scale.
func EncodeDecimalWithScale(_decimal *apd.Decimal, scale int32) []byte {
func EncodeDecimalWithScale(decimal *apd.Decimal, scale int32) []byte {
targetExponent := -scale // Negate scale since [Decimal.Exponent] is negative.
if _decimal.Exponent != targetExponent {
_decimal = decimal.DecimalWithNewExponent(_decimal, targetExponent)
if decimal.Exponent != targetExponent {
decimal = numbers.DecimalWithNewExponent(decimal, targetExponent)
}
bytes, _ := EncodeDecimal(_decimal)
bytes, _ := EncodeDecimal(decimal)
return bytes
}

Expand Down
39 changes: 39 additions & 0 deletions lib/numbers/decimal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package numbers

import "github.com/cockroachdb/apd/v3"

// MustParseDecimal converts a string into an [apd.Decimal] and panics if the
// string cannot be parsed. It is intended for use in tests.
func MustParseDecimal(value string) *apd.Decimal {
	parsed, _, err := apd.NewFromString(value)
	if err == nil {
		return parsed
	}
	panic(err)
}

// DecimalWithNewExponent returns a new [apd.Decimal] that carries the given exponent,
// leaving the input untouched.
// If the new exponent is less precise than the current one, the extra digits are truncated.
func DecimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decimal {
	// A negative shift means the target exponent is lower, i.e. more digits are kept.
	shift := newExponent - decimal.Exponent
	if shift == 0 {
		// Nothing to rescale; hand back a copy.
		return new(apd.Decimal).Set(decimal)
	}

	scaled := new(apd.BigInt).Set(&decimal.Coeff)
	ten := apd.NewBigInt(10)
	if shift < 0 {
		// Lower exponent: multiply the coefficient by 10^(-shift).
		factor := new(apd.BigInt).Exp(ten, apd.NewBigInt(int64(-shift)), nil)
		scaled.Mul(scaled, factor)
	} else {
		// Higher exponent: integer-divide the coefficient by 10^shift, dropping the low digits.
		factor := new(apd.BigInt).Exp(ten, apd.NewBigInt(int64(shift)), nil)
		scaled.Div(scaled, factor)
	}

	result := new(apd.Decimal)
	result.Form = decimal.Form
	result.Negative = decimal.Negative
	result.Exponent = newExponent
	result.Coeff = *scaled
	return result
}
28 changes: 28 additions & 0 deletions lib/numbers/decimal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package numbers

import (
"testing"

"github.com/cockroachdb/apd/v3"
"github.com/stretchr/testify/assert"
)

func TestDecimalWithNewExponent(t *testing.T) {
assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 0), 0).Text('f'))
assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 1), 1).Text('f'))
assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 100), 0).Text('f'))
assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 0), 1).Text('f'))
assert.Equal(t, "0.0", DecimalWithNewExponent(apd.New(0, 0), -1).Text('f'))

// Same exponent:
assert.Equal(t, "12.349", DecimalWithNewExponent(MustParseDecimal("12.349"), -3).Text('f'))
// More precise exponent:
assert.Equal(t, "12.3490", DecimalWithNewExponent(MustParseDecimal("12.349"), -4).Text('f'))
assert.Equal(t, "12.34900", DecimalWithNewExponent(MustParseDecimal("12.349"), -5).Text('f'))
// Lest precise exponent:
// Extra digits should be truncated rather than rounded.
assert.Equal(t, "12.34", DecimalWithNewExponent(MustParseDecimal("12.349"), -2).Text('f'))
assert.Equal(t, "12.3", DecimalWithNewExponent(MustParseDecimal("12.349"), -1).Text('f'))
assert.Equal(t, "12", DecimalWithNewExponent(MustParseDecimal("12.349"), 0).Text('f'))
assert.Equal(t, "10", DecimalWithNewExponent(MustParseDecimal("12.349"), 1).Text('f'))
}
12 changes: 2 additions & 10 deletions lib/parquetutil/parse_values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,16 @@ package parquetutil
import (
"testing"

"github.com/artie-labs/transfer/lib/numbers"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/typing/ext"
"github.com/cockroachdb/apd/v3"

"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/artie-labs/transfer/lib/typing/decimal"
"github.com/stretchr/testify/assert"
)

// mustParseDecimal converts a string into an [apd.Decimal] and panics if parsing fails.
func mustParseDecimal(value string) *apd.Decimal {
	parsed, _, err := apd.NewFromString(value)
	if err == nil {
		return parsed
	}
	panic(err)
}

func TestParseValue(t *testing.T) {
eDecimal := typing.EDecimal
eDecimal.ExtendedDecimalDetails = decimal.NewDecimal(ptr.ToInt(30), 5, nil)
Expand Down Expand Up @@ -74,7 +66,7 @@ func TestParseValue(t *testing.T) {
},
{
name: "decimal",
colVal: decimal.NewDecimal(ptr.ToInt(30), 5, mustParseDecimal("5000.2232")),
colVal: decimal.NewDecimal(ptr.ToInt(30), 5, numbers.MustParseDecimal("5000.2232")),
colKind: columns.NewColumn("", eDecimal),
expectedValue: "5000.22320",
},
Expand Down
76 changes: 76 additions & 0 deletions lib/test/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package main

import (
"fmt"
"math/rand"
"strings"

"github.com/artie-labs/transfer/lib/debezium"
"github.com/cockroachdb/apd/v3"
)

// mustEncodeAndDecodeDecimal passes a decimal through debezium.EncodeDecimalWithScale and then
// debezium.DecodeDecimal with the same scale, returning the decoded value formatted via Text('f').
func mustEncodeAndDecodeDecimal(decimal *apd.Decimal, scale int32) string {
	encoded := debezium.EncodeDecimalWithScale(decimal, scale)
	decoded := debezium.DecodeDecimal(encoded, scale)
	return decoded.Text('f')
}

// randDigit returns a random ASCII digit ('0' through '9') together with a flag
// reporting whether the digit is '0'.
func randDigit() (byte, bool) {
	n := rand.Intn(10)
	digit := byte('0' + n)
	return digit, n == 0
}

// generateNumberWithScale builds a random decimal string and parses it into an [apd.Decimal].
// The integer part is drawn from at most maxDigitsBefore random digits (leading zeros are skipped,
// and "0" is written if nothing else was), and the fractional part has at most maxDigitsAfter digits.
// Values containing at least one non-zero digit are made negative half of the time.
// It returns the parsed decimal along with its negated exponent, and panics if parsing fails.
func generateNumberWithScale(maxDigitsBefore int, maxDigitsAfter int) (*apd.Decimal, int32) {
	var sb strings.Builder

	// Integer part: drop zeros until the first non-zero digit has been emitted.
	sawNonZero := false
	integerDigits := rand.Intn(maxDigitsBefore + 1)
	for i := 0; i < integerDigits; i++ {
		digit, isZero := randDigit()
		if sawNonZero || !isZero {
			sawNonZero = true
			sb.WriteByte(digit)
		}
	}
	if !sawNonZero {
		sb.WriteByte('0')
	}

	// Fractional part: every generated digit is kept, zeros included.
	fractionDigits := rand.Intn(maxDigitsAfter + 1)
	if fractionDigits > 0 {
		sb.WriteByte('.')
		for i := 0; i < fractionDigits; i++ {
			digit, isZero := randDigit()
			sawNonZero = sawNonZero || !isZero
			sb.WriteByte(digit)
		}
	}

	text := sb.String()
	// Only values with a non-zero digit get a sign, so "-0" is never generated.
	if sawNonZero && rand.Intn(2) == 1 {
		text = "-" + text
	}

	parsed, _, err := apd.NewFromString(text)
	if err != nil {
		panic(err)
	}
	return parsed, -parsed.Exponent
}

func main() {
for i := range 1000 {
fmt.Printf("Checking batch %d...\n", i)
for range 1_000_000 {
in, scale := generateNumberWithScale(30, 30)
out := mustEncodeAndDecodeDecimal(in, scale)
if in.Text('f') != out {
panic(fmt.Sprintf("Failed for %s -> %s", in.Text('f'), out))
}
}
}
}
34 changes: 2 additions & 32 deletions lib/typing/decimal/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package decimal
import (
"fmt"

"github.com/artie-labs/transfer/lib/numbers"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/cockroachdb/apd/v3"
)
Expand Down Expand Up @@ -55,15 +56,11 @@ func (d *Decimal) String() string {
targetExponent := -int32(d.scale)
value := d.value
if value.Exponent != targetExponent {
value = DecimalWithNewExponent(value, targetExponent)
value = numbers.DecimalWithNewExponent(value, targetExponent)
}
return value.Text('f')
}

// Value returns the [apd.Decimal] held in d's value field.
func (d *Decimal) Value() *apd.Decimal {
return d.value
}

// SnowflakeKind - is used to determine whether a NUMERIC data type should be a STRING or NUMERIC(p, s).
func (d *Decimal) SnowflakeKind() string {
return d.toKind(MaxPrecisionBeforeString, "STRING")
Expand All @@ -90,30 +87,3 @@ func (d *Decimal) BigQueryKind() string {

return "STRING"
}

// DecimalWithNewExponent returns a new [apd.Decimal] that carries the given exponent,
// leaving the input untouched.
// If the new exponent is less precise than the current one, the extra digits are truncated.
func DecimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decimal {
	// A negative shift means the target exponent is lower, i.e. more digits are kept.
	shift := newExponent - decimal.Exponent
	if shift == 0 {
		// Nothing to rescale; hand back a copy.
		return new(apd.Decimal).Set(decimal)
	}

	scaled := new(apd.BigInt).Set(&decimal.Coeff)
	ten := apd.NewBigInt(10)
	if shift < 0 {
		// Lower exponent: multiply the coefficient by 10^(-shift).
		factor := new(apd.BigInt).Exp(ten, apd.NewBigInt(int64(-shift)), nil)
		scaled.Mul(scaled, factor)
	} else {
		// Higher exponent: integer-divide the coefficient by 10^shift, dropping the low digits.
		factor := new(apd.BigInt).Exp(ten, apd.NewBigInt(int64(shift)), nil)
		scaled.Div(scaled, factor)
	}

	result := new(apd.Decimal)
	result.Form = decimal.Form
	result.Negative = decimal.Negative
	result.Exponent = newExponent
	result.Coeff = *scaled
	return result
}
29 changes: 0 additions & 29 deletions lib/typing/decimal/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package decimal
import (
"testing"

"github.com/cockroachdb/apd/v3"
"github.com/stretchr/testify/assert"

"github.com/artie-labs/transfer/lib/ptr"
Expand Down Expand Up @@ -77,31 +76,3 @@ func TestDecimalKind(t *testing.T) {
assert.Equal(t, testCase.ExpectedBigQueryKind, d.BigQueryKind(), testCase.Name)
}
}

// mustParseDecimal converts a string into an [apd.Decimal] and panics if parsing fails.
func mustParseDecimal(value string) *apd.Decimal {
	parsed, _, err := apd.NewFromString(value)
	if err == nil {
		return parsed
	}
	panic(err)
}

func TestDecimalWithNewExponent(t *testing.T) {
assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 0), 0).Text('f'))
assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 1), 1).Text('f'))
assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 100), 0).Text('f'))
assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 0), 1).Text('f'))
assert.Equal(t, "0.0", DecimalWithNewExponent(apd.New(0, 0), -1).Text('f'))

// Same exponent:
assert.Equal(t, "12.349", DecimalWithNewExponent(mustParseDecimal("12.349"), -3).Text('f'))
// More precise exponent:
assert.Equal(t, "12.3490", DecimalWithNewExponent(mustParseDecimal("12.349"), -4).Text('f'))
assert.Equal(t, "12.34900", DecimalWithNewExponent(mustParseDecimal("12.349"), -5).Text('f'))
// Lest precise exponent:
// Extra digits should be truncated rather than rounded.
assert.Equal(t, "12.34", DecimalWithNewExponent(mustParseDecimal("12.349"), -2).Text('f'))
assert.Equal(t, "12.3", DecimalWithNewExponent(mustParseDecimal("12.349"), -1).Text('f'))
assert.Equal(t, "12", DecimalWithNewExponent(mustParseDecimal("12.349"), 0).Text('f'))
assert.Equal(t, "10", DecimalWithNewExponent(mustParseDecimal("12.349"), 1).Text('f'))
}
12 changes: 3 additions & 9 deletions lib/typing/values/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@ func TestBooleanToBit(t *testing.T) {
assert.Equal(t, 0, BooleanToBit(false))
}

// mustParseDecimal converts a string into an [apd.Decimal] and panics if parsing fails.
func mustParseDecimal(value string) *apd.Decimal {
	parsed, _, err := apd.NewFromString(value)
	if err == nil {
		return parsed
	}
	panic(err)
}

func TestToString(t *testing.T) {
{
// Nil value
Expand Down Expand Up @@ -131,7 +123,9 @@ func TestToString(t *testing.T) {
assert.Equal(t, "123.45", val)

// Decimals
value := decimal.NewDecimal(ptr.ToInt(38), 2, mustParseDecimal("585692791691858.25"))
_decimal, _, err := apd.NewFromString("585692791691858.25")
assert.NoError(t, err)
value := decimal.NewDecimal(ptr.ToInt(38), 2, _decimal)
val, err = ToString(value, columns.Column{KindDetails: typing.EDecimal}, nil)
assert.NoError(t, err)
assert.Equal(t, "585692791691858.25", val)
Expand Down

0 comments on commit 55c4ace

Please sign in to comment.