Skip to content

Commit

Permalink
Move adp.Decimal functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Jun 26, 2024
1 parent fd08cfd commit 197fd5b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 31 deletions.
34 changes: 3 additions & 31 deletions lib/debezium/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"math/big"
"testing"

"github.com/cockroachdb/apd/v3"
"github.com/artie-labs/transfer/lib/numbers"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -41,37 +41,9 @@ func TestDecodeBigInt(t *testing.T) {
}
}

func mustParseDecimal(value string) *apd.Decimal {
decimal, _, err := apd.NewFromString(value)
if err != nil {
panic(err)
}
return decimal
}

func TestDecimalWithNewExponent(t *testing.T) {
assert.Equal(t, "0", decimalWithNewExponent(apd.New(0, 0), 0).Text('f'))
assert.Equal(t, "00", decimalWithNewExponent(apd.New(0, 1), 1).Text('f'))
assert.Equal(t, "0", decimalWithNewExponent(apd.New(0, 100), 0).Text('f'))
assert.Equal(t, "00", decimalWithNewExponent(apd.New(0, 0), 1).Text('f'))
assert.Equal(t, "0.0", decimalWithNewExponent(apd.New(0, 0), -1).Text('f'))

// Same exponent:
assert.Equal(t, "12.349", decimalWithNewExponent(mustParseDecimal("12.349"), -3).Text('f'))
// More precise exponent:
assert.Equal(t, "12.3490", decimalWithNewExponent(mustParseDecimal("12.349"), -4).Text('f'))
assert.Equal(t, "12.34900", decimalWithNewExponent(mustParseDecimal("12.349"), -5).Text('f'))
// Lest precise exponent:
// Extra digits should be truncated rather than rounded.
assert.Equal(t, "12.34", decimalWithNewExponent(mustParseDecimal("12.349"), -2).Text('f'))
assert.Equal(t, "12.3", decimalWithNewExponent(mustParseDecimal("12.349"), -1).Text('f'))
assert.Equal(t, "12", decimalWithNewExponent(mustParseDecimal("12.349"), 0).Text('f'))
assert.Equal(t, "10", decimalWithNewExponent(mustParseDecimal("12.349"), 1).Text('f'))
}

func TestEncodeDecimal(t *testing.T) {
testEncodeDecimal := func(value string, expectedScale int32) {
bytes, scale := EncodeDecimal(mustParseDecimal(value))
bytes, scale := EncodeDecimal(numbers.MustParseDecimal(value))
actual := DecodeDecimal(bytes, nil, int(scale)).String()
assert.Equal(t, value, actual, value)
assert.Equal(t, expectedScale, scale, value)
Expand All @@ -91,7 +63,7 @@ func TestEncodeDecimal(t *testing.T) {

func TestEncodeDecimalWithScale(t *testing.T) {
mustEncodeAndDecodeDecimal := func(value string, scale int32) string {
bytes := EncodeDecimalWithScale(mustParseDecimal(value), scale)
bytes := EncodeDecimalWithScale(numbers.MustParseDecimal(value), scale)
return DecodeDecimal(bytes, nil, int(scale)).String()
}

Expand Down
39 changes: 39 additions & 0 deletions lib/numbers/decimal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package numbers

import "github.com/cockroachdb/apd/v3"

// MustParseDecimal parses a string to an [apd.Decimal] or panics -- used for tests.
func MustParseDecimal(value string) *apd.Decimal {
decimal, _, err := apd.NewFromString(value)
if err != nil {
panic(err)
}
return decimal
}

// DecimalWithNewExponent takes a [apd.Decimal] and returns a new [apd.Decimal] with a the given exponent.
// If the new exponent is less precise then the extra digits will be truncated.
func DecimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decimal {
exponentDelta := newExponent - decimal.Exponent // Exponent is negative.

if exponentDelta == 0 {
return new(apd.Decimal).Set(decimal)
}

coefficient := new(apd.BigInt).Set(&decimal.Coeff)

if exponentDelta < 0 {
multiplier := new(apd.BigInt).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(-exponentDelta)), nil)
coefficient.Mul(coefficient, multiplier)
} else if exponentDelta > 0 {
divisor := new(apd.BigInt).Exp(apd.NewBigInt(10), apd.NewBigInt(int64(exponentDelta)), nil)
coefficient.Div(coefficient, divisor)
}

return &apd.Decimal{
Form: decimal.Form,
Negative: decimal.Negative,
Exponent: newExponent,
Coeff: *coefficient,
}
}
28 changes: 28 additions & 0 deletions lib/numbers/decimal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package numbers

import (
"testing"

"github.com/cockroachdb/apd/v3"
"github.com/stretchr/testify/assert"
)

func TestDecimalWithNewExponent(t *testing.T) {
assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 0), 0).Text('f'))
assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 1), 1).Text('f'))
assert.Equal(t, "0", DecimalWithNewExponent(apd.New(0, 100), 0).Text('f'))
assert.Equal(t, "00", DecimalWithNewExponent(apd.New(0, 0), 1).Text('f'))
assert.Equal(t, "0.0", DecimalWithNewExponent(apd.New(0, 0), -1).Text('f'))

// Same exponent:
assert.Equal(t, "12.349", DecimalWithNewExponent(MustParseDecimal("12.349"), -3).Text('f'))
// More precise exponent:
assert.Equal(t, "12.3490", DecimalWithNewExponent(MustParseDecimal("12.349"), -4).Text('f'))
assert.Equal(t, "12.34900", DecimalWithNewExponent(MustParseDecimal("12.349"), -5).Text('f'))
// Lest precise exponent:
// Extra digits should be truncated rather than rounded.
assert.Equal(t, "12.34", DecimalWithNewExponent(MustParseDecimal("12.349"), -2).Text('f'))
assert.Equal(t, "12.3", DecimalWithNewExponent(MustParseDecimal("12.349"), -1).Text('f'))
assert.Equal(t, "12", DecimalWithNewExponent(MustParseDecimal("12.349"), 0).Text('f'))
assert.Equal(t, "10", DecimalWithNewExponent(MustParseDecimal("12.349"), 1).Text('f'))
}

0 comments on commit 197fd5b

Please sign in to comment.