Skip to content

Commit

Permalink
Move MustParseDecimal function (#766)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Jun 26, 2024
1 parent b5dcc5a commit 61aa7cb
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
6 changes: 3 additions & 3 deletions lib/debezium/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func decodeBigInt(data []byte) *big.Int {
return bigInt
}

// decimalWithNewExponent takes a [apd.Decimal] and returns a new [apd.Decimal] with a the given exponent.
// decimalWithNewExponent takes a [*apd.Decimal] and returns a new [*apd.Decimal] with a the given exponent.
// If the new exponent is less precise then the extra digits will be truncated.
func decimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decimal {
exponentDelta := newExponent - decimal.Exponent // Exponent is negative.
Expand All @@ -89,7 +89,7 @@ func decimalWithNewExponent(decimal *apd.Decimal, newExponent int32) *apd.Decima
}
}

// EncodeDecimal is used to encode a [apd.Decimal] to [org.apache.kafka.connect.data.Decimal].
// EncodeDecimal is used to encode a [*apd.Decimal] to [org.apache.kafka.connect.data.Decimal].
// The scale of the value (which is the negated exponent of the decimal) is returned as the second argument.
func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) {
bigIntValue := decimal.Coeff.MathBigInt()
Expand All @@ -100,7 +100,7 @@ func EncodeDecimal(decimal *apd.Decimal) ([]byte, int32) {
return encodeBigInt(bigIntValue), -decimal.Exponent
}

// EncodeDecimalWithScale is used to encode a [apd.Decimal] to [org.apache.kafka.connect.data.Decimal]
// EncodeDecimalWithScale is used to encode a [*apd.Decimal] to [org.apache.kafka.connect.data.Decimal].
// using a specific scale.
func EncodeDecimalWithScale(decimal *apd.Decimal, scale int32) []byte {
targetExponent := -scale // Negate scale since [Decimal.Exponent] is negative.
Expand Down
27 changes: 10 additions & 17 deletions lib/debezium/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math/big"
"testing"

"github.com/artie-labs/transfer/lib/numbers"
"github.com/cockroachdb/apd/v3"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -41,14 +42,6 @@ func TestDecodeBigInt(t *testing.T) {
}
}

func mustParseDecimal(value string) *apd.Decimal {
decimal, _, err := apd.NewFromString(value)
if err != nil {
panic(err)
}
return decimal
}

func TestDecimalWithNewExponent(t *testing.T) {
assert.Equal(t, "0", decimalWithNewExponent(apd.New(0, 0), 0).Text('f'))
assert.Equal(t, "00", decimalWithNewExponent(apd.New(0, 1), 1).Text('f'))
Expand All @@ -57,21 +50,21 @@ func TestDecimalWithNewExponent(t *testing.T) {
assert.Equal(t, "0.0", decimalWithNewExponent(apd.New(0, 0), -1).Text('f'))

// Same exponent:
assert.Equal(t, "12.349", decimalWithNewExponent(mustParseDecimal("12.349"), -3).Text('f'))
assert.Equal(t, "12.349", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -3).Text('f'))
// More precise exponent:
assert.Equal(t, "12.3490", decimalWithNewExponent(mustParseDecimal("12.349"), -4).Text('f'))
assert.Equal(t, "12.34900", decimalWithNewExponent(mustParseDecimal("12.349"), -5).Text('f'))
assert.Equal(t, "12.3490", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -4).Text('f'))
assert.Equal(t, "12.34900", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -5).Text('f'))
// Lest precise exponent:
// Extra digits should be truncated rather than rounded.
assert.Equal(t, "12.34", decimalWithNewExponent(mustParseDecimal("12.349"), -2).Text('f'))
assert.Equal(t, "12.3", decimalWithNewExponent(mustParseDecimal("12.349"), -1).Text('f'))
assert.Equal(t, "12", decimalWithNewExponent(mustParseDecimal("12.349"), 0).Text('f'))
assert.Equal(t, "10", decimalWithNewExponent(mustParseDecimal("12.349"), 1).Text('f'))
assert.Equal(t, "12.34", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -2).Text('f'))
assert.Equal(t, "12.3", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), -1).Text('f'))
assert.Equal(t, "12", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), 0).Text('f'))
assert.Equal(t, "10", decimalWithNewExponent(numbers.MustParseDecimal("12.349"), 1).Text('f'))
}

func TestEncodeDecimal(t *testing.T) {
testEncodeDecimal := func(value string, expectedScale int32) {
bytes, scale := EncodeDecimal(mustParseDecimal(value))
bytes, scale := EncodeDecimal(numbers.MustParseDecimal(value))
actual := DecodeDecimal(bytes, nil, int(scale)).String()
assert.Equal(t, value, actual, value)
assert.Equal(t, expectedScale, scale, value)
Expand All @@ -91,7 +84,7 @@ func TestEncodeDecimal(t *testing.T) {

func TestEncodeDecimalWithScale(t *testing.T) {
mustEncodeAndDecodeDecimal := func(value string, scale int32) string {
bytes := EncodeDecimalWithScale(mustParseDecimal(value), scale)
bytes := EncodeDecimalWithScale(numbers.MustParseDecimal(value), scale)
return DecodeDecimal(bytes, nil, int(scale)).String()
}

Expand Down
11 changes: 11 additions & 0 deletions lib/numbers/numbers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
package numbers

import "github.com/cockroachdb/apd/v3"

// BetweenEq - Looks something like this. start <= number <= end
func BetweenEq[T int | int32 | int64](start, end, number T) bool {
return number >= start && number <= end
}

// MustParseDecimal parses a string to a [*apd.Decimal] or panics -- used for tests.
func MustParseDecimal(value string) *apd.Decimal {
decimal, _, err := apd.NewFromString(value)
if err != nil {
panic(err)
}
return decimal
}

0 comments on commit 61aa7cb

Please sign in to comment.