Skip to content

Commit

Permalink
More assertions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Aug 14, 2024
1 parent 363384e commit c46929f
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 22 deletions.
8 changes: 5 additions & 3 deletions clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"os"
"strings"

"github.com/artie-labs/transfer/lib/typing"

"cloud.google.com/go/bigquery"
"cloud.google.com/go/bigquery/storage/managedwriter"
"cloud.google.com/go/bigquery/storage/managedwriter/adapt"
Expand Down Expand Up @@ -94,9 +96,9 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
}
}

bqTempTableID, ok := tempTableID.(TableIdentifier)
if !ok {
return fmt.Errorf("unable to cast tempTableID to BigQuery TableIdentifier")
bqTempTableID, err := typing.AssertType[TableIdentifier](tempTableID)
if err != nil {
return err
}

return s.putTable(context.Background(), bqTempTableID, tableData)
Expand Down
17 changes: 10 additions & 7 deletions clients/bigquery/storagewrite.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ func columnsToMessageDescriptor(cols []columns.Column) (*protoreflect.MessageDes
if err != nil {
return nil, fmt.Errorf("failed to build proto descriptor: %w", err)
}
messageDescriptor, ok := descriptor.(protoreflect.MessageDescriptor)
if !ok {
return nil, fmt.Errorf("proto descriptor is not a message descriptor")

messageDescriptor, err := typing.AssertType[protoreflect.MessageDescriptor](descriptor)
if err != nil {
return nil, err
}

return &messageDescriptor, nil
}

Expand Down Expand Up @@ -154,11 +156,12 @@ func rowToMessage(row map[string]any, columns []columns.Column, messageDescripto
return nil, fmt.Errorf("expected float32/float64/int32/int64/string received %T with value %v", value, value)
}
case typing.EDecimal.Kind:
if decimalValue, ok := value.(*decimal.Decimal); ok {
message.Set(field, protoreflect.ValueOfString(decimalValue.String()))
} else {
return nil, fmt.Errorf("expected *decimal.Decimal received %T with value %v", decimalValue, decimalValue)
decimalValue, err := typing.AssertType[*decimal.Decimal](value)
if err != nil {
return nil, err
}

message.Set(field, protoreflect.ValueOfString(decimalValue.String()))
case typing.String.Kind:
var stringValue string
switch castedValue := value.(type) {
Expand Down
8 changes: 4 additions & 4 deletions clients/shared/default_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ func DefaultValue(column columns.Column, dialect sql.Dialect, additionalDateFmts
return sql.QuoteLiteral(extTime.String(column.KindDetails.ExtendedTimeDetails.Format)), nil
}
case typing.EDecimal.Kind:
val, isOk := column.DefaultValue().(*decimal.Decimal)
if !isOk {
return nil, fmt.Errorf("colVal is not type *decimal.Decimal")
decimalValue, err := typing.AssertType[*decimal.Decimal](column.DefaultValue())
if err != nil {
return nil, err
}

return val.String(), nil
return decimalValue.String(), nil
case typing.String.Kind:
return sql.QuoteLiteral(fmt.Sprint(column.DefaultValue())), nil
}
Expand Down
21 changes: 21 additions & 0 deletions clients/shared/default_value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"testing"
"time"

"github.com/artie-labs/transfer/lib/numbers"
"github.com/artie-labs/transfer/lib/typing/decimal"

"github.com/stretchr/testify/assert"

bigQueryDialect "github.com/artie-labs/transfer/clients/bigquery/dialect"
Expand Down Expand Up @@ -105,4 +108,22 @@ func TestColumn_DefaultValue(t *testing.T) {
assert.Equal(t, expectedValue, actualValue, fmt.Sprintf("%s %s", testCase.name, dialect))
}
}

{
// Decimal value
{
// Type *decimal.Decimal
decimalValue := decimal.NewDecimal(numbers.MustParseDecimal("3.14159"))
col := columns.NewColumnWithDefaultValue("", typing.EDecimal, decimalValue)
value, err := DefaultValue(col, redshiftDialect.RedshiftDialect{}, nil)
assert.NoError(t, err)
assert.Equal(t, "3.14159", value)
}
{
// Wrong type (string)
col := columns.NewColumnWithDefaultValue("", typing.EDecimal, "hello")
_, err := DefaultValue(col, redshiftDialect.RedshiftDialect{}, nil)
assert.ErrorContains(t, err, "expected type *decimal.Decimal, got string")
}
}
}
8 changes: 5 additions & 3 deletions lib/crypto/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"encoding/pem"
"fmt"
"os"

"github.com/artie-labs/transfer/lib/typing"
)

func LoadRSAKey(filePath string) (*rsa.PrivateKey, error) {
Expand All @@ -28,9 +30,9 @@ func ParseRSAPrivateKey(keyBytes []byte) (*rsa.PrivateKey, error) {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}

rsaKey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("not an RSA private key, rather: %T", key)
rsaKey, err := typing.AssertType[*rsa.PrivateKey](key)
if err != nil {
return nil, err
}

return rsaKey, nil
Expand Down
9 changes: 4 additions & 5 deletions lib/parquetutil/parse_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ func ParseValue(colVal any, colKind columns.Column, additionalDateFmts []string)
}
}
case typing.Array.Kind:
var err error
arrayString, err := array.InterfaceToArrayString(colVal, true)
if err != nil {
return nil, err
Expand All @@ -67,12 +66,12 @@ func ParseValue(colVal any, colKind columns.Column, additionalDateFmts []string)

return arrayString, nil
case typing.EDecimal.Kind:
val, isOk := colVal.(*decimal.Decimal)
if !isOk {
return "", fmt.Errorf("colVal is not *decimal.Decimal type")
decimalValue, err := typing.AssertType[*decimal.Decimal](colVal)
if err != nil {
return nil, err
}

return val.String(), nil
return decimalValue.String(), nil
}

return colVal, nil
Expand Down

0 comments on commit c46929f

Please sign in to comment.