From d54201aa4fe6d668dc5f142ec793dc368af269bb Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Thu, 13 Jun 2024 11:59:29 -0700 Subject: [PATCH] WIP --- clients/bigquery/bigquery.go | 33 ++--------- clients/bigquery/storagewrite.go | 83 +++++++++++++++++++++++++-- clients/bigquery/storagewrite_test.go | 54 +++++++++++++++++ lib/debezium/decimal.go | 4 ++ lib/typing/decimal/decimal.go | 4 ++ 5 files changed, 145 insertions(+), 33 deletions(-) diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index a5648e4cc..236be9e00 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -12,7 +12,6 @@ import ( "cloud.google.com/go/bigquery/storage/managedwriter/adapt" _ "github.com/viant/bigquery" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoreflect" "github.com/artie-labs/transfer/clients/bigquery/dialect" "github.com/artie-labs/transfer/clients/shared" @@ -167,14 +166,13 @@ func (s *Store) putTableViaLegacyAPI(ctx context.Context, tableID TableIdentifie } func (s *Store) putTableViaStorageWriteAPI(ctx context.Context, bqTableID TableIdentifier, tableData *optimization.TableData) error { + columns := tableData.ReadOnlyInMemoryCols().ValidColumns() + // TODO: Think about whether we want to support batching in this method client := s.GetClient(ctx) defer client.Close() - metadata, err := client.Dataset(bqTableID.Dataset()).Table(bqTableID.Table()).Metadata(ctx) - if err != nil { - return fmt.Errorf("failed to fetch table schema: %w", err) - } - messageDescriptor, err := schemaToMessageDescriptor(metadata.Schema) + + messageDescriptor, err := columnsToMessageDescriptor(columns) if err != nil { return err } @@ -203,7 +201,6 @@ func (s *Store) putTableViaStorageWriteAPI(ctx context.Context, bqTableID TableI defer managedStream.Close() rows := tableData.Rows() - columns := tableData.ReadOnlyInMemoryCols().ValidColumns() encoded := make([][]byte, len(rows)) for i, row := range rows { message, err := rowToMessage(row, columns, *messageDescriptor, s.AdditionalDateFormats()) @@ -273,25 +270,3 @@ func LoadBigQuery(cfg config.Config, _store *db.Store) (*Store, error) { config: cfg, }, nil } - -func schemaToMessageDescriptor(schema bigquery.Schema) (*protoreflect.MessageDescriptor, error) { - for _, field := range schema { - if field.Type == bigquery.JSONFieldType { - field.Type = bigquery.StringFieldType - } - } - - storageSchema, err := adapt.BQSchemaToStorageTableSchema(schema) - if err != nil { - return nil, fmt.Errorf("failed to adapt BQ schema to protocol buffer schema: %w", err) - } - descriptor, err := adapt.StorageSchemaToProto2Descriptor(storageSchema, "root") - if err != nil { - return nil, fmt.Errorf("failed to build protocol buffer descriptor: %w", err) - } - messageDescriptor, ok := descriptor.(protoreflect.MessageDescriptor) - if !ok { - return nil, fmt.Errorf("adapted descriptor is not a message descriptor") - } - return &messageDescriptor, nil -} diff --git a/clients/bigquery/storagewrite.go b/clients/bigquery/storagewrite.go index 1f9b13445..617af9e91 100644 --- a/clients/bigquery/storagewrite.go +++ b/clients/bigquery/storagewrite.go @@ -4,6 +4,8 @@ import ( "fmt" "time" + "cloud.google.com/go/bigquery" + "cloud.google.com/go/bigquery/storage/managedwriter/adapt" "github.com/artie-labs/transfer/lib/array" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/columns" @@ -14,6 +16,76 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +func schemaToMessageDescriptor(schema bigquery.Schema) (*protoreflect.MessageDescriptor, error) { + storageSchema, err := adapt.BQSchemaToStorageTableSchema(schema) + if err != nil { + return nil, fmt.Errorf("failed to adapt BigQuery schema to protocol buffer schema: %w", err) + } + descriptor, err := adapt.StorageSchemaToProto2Descriptor(storageSchema, "root") + if err != nil { + return nil, fmt.Errorf("failed to build protocol buffer descriptor: %w", err) + } + messageDescriptor, ok := descriptor.(protoreflect.MessageDescriptor) + if !ok { + return nil, fmt.Errorf("adapted descriptor is not a message descriptor") + } + return &messageDescriptor, nil +} + +func columnToFieldSchema(column columns.Column) (*bigquery.FieldSchema, error) { + var fieldType bigquery.FieldType + var repeated bool + + switch column.KindDetails.Kind { + case typing.Boolean.Kind: + fieldType = bigquery.BooleanFieldType + case typing.Integer.Kind: + fieldType = bigquery.IntegerFieldType + case typing.Float.Kind: + fieldType = bigquery.FloatFieldType + case typing.String.Kind: + fieldType = bigquery.StringFieldType + case typing.EDecimal.Kind: + fieldType = bigquery.NumericFieldType + case typing.ETime.Kind: + switch column.KindDetails.ExtendedTimeDetails.Type { + case ext.DateKindType: + fieldType = bigquery.DateFieldType + case ext.TimeKindType: + fieldType = bigquery.TimeFieldType + case ext.DateTimeKindType: + fieldType = bigquery.TimestampFieldType + default: + return nil, fmt.Errorf("unsupported extended time details type: %s", column.KindDetails.ExtendedTimeDetails.Type) + } + case typing.Struct.Kind: + fieldType = bigquery.StringFieldType + case typing.Array.Kind: + fieldType = bigquery.StringFieldType + repeated = true + default: + return nil, fmt.Errorf("unsupported column kind: %s", column.KindDetails.Kind) + } + + return &bigquery.FieldSchema{ + Name: column.Name(), + Type: fieldType, + Repeated: repeated, + }, nil +} + +func columnsToMessageDescriptor(cols []columns.Column) (*protoreflect.MessageDescriptor, error) { + fields := make([]*bigquery.FieldSchema, len(cols)) + for i, col := range cols { + field, err := columnToFieldSchema(col) + if err != nil { + return nil, err + } + fields[i] = field + } + return schemaToMessageDescriptor(fields) +} + // From https://cloud.google.com/java/docs/reference/google-cloud-bigquerystorage/latest/com.google.cloud.bigquery.storage.v1.CivilTimeEncoder // And https://cloud.google.com/pubsub/docs/bigquery#date_time_int func encodePacked64TimeMicros(value time.Time) int64 { @@ -49,7 +121,7 @@ func rowToMessage(row map[string]any, columns []columns.Column, messageDescripto case typing.Integer.Kind: switch value := value.(type) { case int32: - message.Set(field, protoreflect.ValueOfInt32(value)) + message.Set(field, protoreflect.ValueOfInt64(int64(value))) case int64: message.Set(field, protoreflect.ValueOfInt64(value)) case int: @@ -60,7 +132,7 @@ func rowToMessage(row map[string]any, columns []columns.Column, messageDescripto case typing.Float.Kind: switch value := value.(type) { case float32: - message.Set(field, protoreflect.ValueOfFloat32(value)) + message.Set(field, protoreflect.ValueOfFloat64(float64(value))) case float64: message.Set(field, protoreflect.ValueOfFloat64(value)) default: @@ -74,7 +146,11 @@ func rowToMessage(row map[string]any, columns []columns.Column, messageDescripto } case typing.EDecimal.Kind: if decimalValue, ok := value.(*decimal.Decimal); ok { - message.Set(field, protoreflect.ValueOf(decimalValue.Value())) + bytes, err := decimalValue.Bytes() + if err != nil { + return nil, err + } + message.Set(field, protoreflect.ValueOfBytes(bytes)) } else { return nil, fmt.Errorf("expected *decimal.Decimal received %T with value %v", decimalValue, decimalValue) } @@ -124,7 +200,6 @@ func rowToMessage(row map[string]any, columns []columns.Column, messageDescripto default: return nil, fmt.Errorf("unsupported column kind: %s", column.KindDetails.Kind) } - } return message, nil } diff --git a/clients/bigquery/storagewrite_test.go b/clients/bigquery/storagewrite_test.go index 53bb6a619..9e3be542c 100644 --- a/clients/bigquery/storagewrite_test.go +++ b/clients/bigquery/storagewrite_test.go @@ -1,10 +1,16 @@ package bigquery import ( + "encoding/json" + "math/big" "testing" "time" + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/columns" + "github.com/artie-labs/transfer/lib/typing/decimal" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/encoding/protojson" ) func TestEncodePacked64TimeMicros(t *testing.T) { @@ -19,3 +25,51 @@ func TestEncodePacked64TimeMicros(t *testing.T) { assert.Equal(t, int64(1<<32+1), encodePacked64TimeMicros(epoch.Add(time.Duration(1)*time.Hour+time.Duration(1)*time.Microsecond))) assert.Equal(t, int64(1<<32+1000), encodePacked64TimeMicros(epoch.Add(time.Duration(1)*time.Hour+time.Duration(1)*time.Millisecond))) } + +func TestRowToMessage(t *testing.T) { + row := map[string]any{ + "c_bool": true, + "c_int32": int32(1234), + "c_int64": int32(1234), + "c_float32": float32(1234.567), + "c_float64": float64(1234.567), + "c_numeric": decimal.NewDecimal(nil, 5, big.NewFloat(3.1415926)), + "c_string": "foo bar", + "c_array": []string{"foo", "bar"}, + "c_struct": map[string]any{"baz": []string{"foo", "bar"}}, + } + + columns := []columns.Column{ + columns.NewColumn("c_bool", typing.Boolean), + columns.NewColumn("c_int32", typing.Integer), + columns.NewColumn("c_int64", typing.Integer), + columns.NewColumn("c_float32", typing.Float), + columns.NewColumn("c_float64", typing.Float), + columns.NewColumn("c_numeric", typing.EDecimal), + columns.NewColumn("c_string", typing.String), + columns.NewColumn("c_array", typing.Array), + columns.NewColumn("c_struct", typing.Struct), + } + + desc, err := columnsToMessageDescriptor(columns) + assert.NoError(t, err) + + message, err := rowToMessage(row, columns, *desc, []string{}) + assert.NoError(t, err) + + bytes, err := protojson.Marshal(message) + assert.NoError(t, err) + + var result map[string]any + assert.NoError(t, json.Unmarshal(bytes, &result)) + + assert.Equal(t, map[string]any{ + "cBool": true, + "cFloat32": 1234.5670166015625, + "cFloat64": 1234.567, + "cInt32": "1234", + "cInt64": "1234", + "cString": "foo bar", + "cArray": []any{"foo", "bar"}, + }, result) +} diff --git a/lib/debezium/decimal.go b/lib/debezium/decimal.go index 793c4b814..11c4440e4 100644 --- a/lib/debezium/decimal.go +++ b/lib/debezium/decimal.go @@ -16,6 +16,10 @@ func EncodeDecimal(value string, scale int) ([]byte, error) { } bigFloatValue.Mul(bigFloatValue, new(big.Float).SetInt(scaledValue)) + return EncodeBigFloat(bigFloatValue) +} + +func EncodeBigFloat(bigFloatValue *big.Float) ([]byte, error) { // Extract the scaled integer value. bigIntValue, _ := bigFloatValue.Int(nil) data := bigIntValue.Bytes() diff --git a/lib/typing/decimal/decimal.go b/lib/typing/decimal/decimal.go index dda4b2851..db8a2f710 100644 --- a/lib/typing/decimal/decimal.go +++ b/lib/typing/decimal/decimal.go @@ -55,6 +55,10 @@ func (d *Decimal) String() string { return d.value.Text('f', d.scale) } +func (d *Decimal) Bytes() ([]byte, error) { + return d.value.GobEncode() +} + func (d *Decimal) Value() any { // -1 precision is used for variable scaled decimal // We are opting to emit this as a STRING because the value is technically unbounded (can get to ~1 GB).