diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index c9d51bd94..a246e0ddc 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -7,6 +7,8 @@ import ( "os" "strings" + "github.com/artie-labs/transfer/lib/typing" + "cloud.google.com/go/bigquery" "cloud.google.com/go/bigquery/storage/managedwriter" "cloud.google.com/go/bigquery/storage/managedwriter/adapt" @@ -94,9 +96,9 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo } } - bqTempTableID, ok := tempTableID.(TableIdentifier) - if !ok { - return fmt.Errorf("unable to cast tempTableID to BigQuery TableIdentifier") + bqTempTableID, err := typing.AssertType[TableIdentifier](tempTableID) + if err != nil { + return err } return s.putTable(context.Background(), bqTempTableID, tableData) diff --git a/clients/bigquery/storagewrite.go b/clients/bigquery/storagewrite.go index 67524f2fe..0b432b5c8 100644 --- a/clients/bigquery/storagewrite.go +++ b/clients/bigquery/storagewrite.go @@ -80,10 +80,12 @@ func columnsToMessageDescriptor(cols []columns.Column) (*protoreflect.MessageDes if err != nil { return nil, fmt.Errorf("failed to build proto descriptor: %w", err) } - messageDescriptor, ok := descriptor.(protoreflect.MessageDescriptor) - if !ok { - return nil, fmt.Errorf("proto descriptor is not a message descriptor") + + messageDescriptor, err := typing.AssertType[protoreflect.MessageDescriptor](descriptor) + if err != nil { + return nil, err } + return &messageDescriptor, nil } @@ -154,11 +156,12 @@ func rowToMessage(row map[string]any, columns []columns.Column, messageDescripto return nil, fmt.Errorf("expected float32/float64/int32/int64/string received %T with value %v", value, value) } case typing.EDecimal.Kind: - if decimalValue, ok := value.(*decimal.Decimal); ok { - message.Set(field, protoreflect.ValueOfString(decimalValue.String())) - } else { - return nil, fmt.Errorf("expected *decimal.Decimal received %T with value %v", decimalValue, decimalValue) + decimalValue, err := typing.AssertType[*decimal.Decimal](value) + if err != nil { + return nil, err } + + message.Set(field, protoreflect.ValueOfString(decimalValue.String())) case typing.String.Kind: var stringValue string switch castedValue := value.(type) { diff --git a/clients/shared/default_value.go b/clients/shared/default_value.go index e80e17a76..8883c78cc 100644 --- a/clients/shared/default_value.go +++ b/clients/shared/default_value.go @@ -39,12 +39,12 @@ func DefaultValue(column columns.Column, dialect sql.Dialect, additionalDateFmts return sql.QuoteLiteral(extTime.String(column.KindDetails.ExtendedTimeDetails.Format)), nil } case typing.EDecimal.Kind: - val, isOk := column.DefaultValue().(*decimal.Decimal) - if !isOk { - return nil, fmt.Errorf("colVal is not type *decimal.Decimal") + decimalValue, err := typing.AssertType[*decimal.Decimal](column.DefaultValue()) + if err != nil { + return nil, err } - return val.String(), nil + return decimalValue.String(), nil case typing.String.Kind: return sql.QuoteLiteral(fmt.Sprint(column.DefaultValue())), nil } diff --git a/clients/shared/default_value_test.go b/clients/shared/default_value_test.go index 4c2490344..bb4f9c46a 100644 --- a/clients/shared/default_value_test.go +++ b/clients/shared/default_value_test.go @@ -5,6 +5,9 @@ import ( "testing" "time" + "github.com/artie-labs/transfer/lib/numbers" + "github.com/artie-labs/transfer/lib/typing/decimal" + "github.com/stretchr/testify/assert" bigQueryDialect "github.com/artie-labs/transfer/clients/bigquery/dialect" @@ -105,4 +108,22 @@ func TestColumn_DefaultValue(t *testing.T) { assert.Equal(t, expectedValue, actualValue, fmt.Sprintf("%s %s", testCase.name, dialect)) } } + + { + // Decimal value + { + // Type *decimal.Decimal + decimalValue := decimal.NewDecimal(numbers.MustParseDecimal("3.14159")) + col := columns.NewColumnWithDefaultValue("", typing.EDecimal, decimalValue) + value, err := DefaultValue(col, redshiftDialect.RedshiftDialect{}, nil) + assert.NoError(t, err) + assert.Equal(t, "3.14159", value) + } + { + // Wrong type (string) + col := columns.NewColumnWithDefaultValue("", typing.EDecimal, "hello") + _, err := DefaultValue(col, redshiftDialect.RedshiftDialect{}, nil) + assert.ErrorContains(t, err, "expected type *decimal.Decimal, got string") + } + } } diff --git a/lib/crypto/rsa.go b/lib/crypto/rsa.go index 22f031374..45ab19b70 100644 --- a/lib/crypto/rsa.go +++ b/lib/crypto/rsa.go @@ -6,6 +6,8 @@ import ( "encoding/pem" "fmt" "os" + + "github.com/artie-labs/transfer/lib/typing" ) func LoadRSAKey(filePath string) (*rsa.PrivateKey, error) { @@ -28,9 +30,9 @@ func ParseRSAPrivateKey(keyBytes []byte) (*rsa.PrivateKey, error) { return nil, fmt.Errorf("failed to parse private key: %v", err) } - rsaKey, ok := key.(*rsa.PrivateKey) - if !ok { - return nil, fmt.Errorf("not an RSA private key, rather: %T", key) + rsaKey, err := typing.AssertType[*rsa.PrivateKey](key) + if err != nil { + return nil, err } return rsaKey, nil diff --git a/lib/parquetutil/parse_values.go b/lib/parquetutil/parse_values.go index 63db680c3..aa57df883 100644 --- a/lib/parquetutil/parse_values.go +++ b/lib/parquetutil/parse_values.go @@ -55,7 +55,6 @@ func ParseValue(colVal any, colKind columns.Column, additionalDateFmts []string) } } case typing.Array.Kind: - var err error arrayString, err := array.InterfaceToArrayString(colVal, true) if err != nil { return nil, err @@ -67,12 +66,12 @@ func ParseValue(colVal any, colKind columns.Column, additionalDateFmts []string) return arrayString, nil case typing.EDecimal.Kind: - val, isOk := colVal.(*decimal.Decimal) - if !isOk { - return "", fmt.Errorf("colVal is not *decimal.Decimal type") + decimalValue, err := typing.AssertType[*decimal.Decimal](colVal) + if err != nil { + return nil, err } - return val.String(), nil + return decimalValue.String(), nil } return colVal, nil