From 38259473cd7e57141b0142a7a42c629b27aad2c2 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Mon, 22 Apr 2024 11:43:53 -0700 Subject: [PATCH] WIP --- clients/bigquery/bigquery.go | 7 ++++-- clients/bigquery/merge_test.go | 2 +- clients/bigquery/tableid.go | 20 +++++++-------- clients/bigquery/tableid_test.go | 9 +++---- clients/mssql/store.go | 7 ++++-- clients/mssql/tableid.go | 13 +++++----- clients/mssql/tableid_test.go | 9 +++---- clients/shared/append.go | 3 ++- clients/shared/merge.go | 8 +++--- clients/shared/utils.go | 2 +- lib/destination/ddl/ddl_bq_test.go | 4 +-- lib/destination/ddl/ddl_create_table_test.go | 2 +- lib/destination/ddl/ddl_temp_test.go | 2 +- lib/sql/escape.go | 26 ++++++++++---------- lib/sql/escape_test.go | 2 +- 15 files changed, 57 insertions(+), 59 deletions(-) diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index b370975e7..a9207c548 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -80,7 +80,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo } func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier { - return NewTableIdentifier(s.config.BigQuery.ProjectID, topicConfig.Database, table, s.ShouldUppercaseEscapedNames()) + return NewTableIdentifier(s.config.BigQuery.ProjectID, topicConfig.Database, table) } func (s *Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTableConfig, error) { @@ -112,7 +112,10 @@ func (s *Store) Label() constants.DestinationKind { } func (s *Store) ShouldUppercaseEscapedNames() bool { - return s.config.SharedDestinationConfig.UppercaseEscapedNames + if s.config.SharedDestinationConfig.UppercaseEscapedNames { + panic("UppercaseEscapedNames is not supported for BigQuery") + } + return false } func (s *Store) GetClient(ctx context.Context) *bigquery.Client { diff --git a/clients/bigquery/merge_test.go b/clients/bigquery/merge_test.go index ee48c8dea..e531de9e0 100644 --- a/clients/bigquery/merge_test.go +++ b/clients/bigquery/merge_test.go @@ -9,7 +9,7 @@ import ( ) func (b *BigQueryTestSuite) TestBackfillColumn() { - tableID := NewTableIdentifier("db", "public", "tableName", false) + tableID := NewTableIdentifier("db", "public", "tableName") type _testCase struct { name string col columns.Column diff --git a/clients/bigquery/tableid.go b/clients/bigquery/tableid.go index 0a651a75c..ca3bcb53f 100644 --- a/clients/bigquery/tableid.go +++ b/clients/bigquery/tableid.go @@ -9,18 +9,16 @@ import ( ) type TableIdentifier struct { - projectID string - dataset string - table string - uppercaseEscapedNames bool + projectID string + dataset string + table string } -func NewTableIdentifier(projectID, dataset, table string, uppercaseEscapedNames bool) TableIdentifier { +func NewTableIdentifier(projectID, dataset, table string) TableIdentifier { return TableIdentifier{ - projectID: projectID, - dataset: dataset, - table: table, - uppercaseEscapedNames: uppercaseEscapedNames, + projectID: projectID, + dataset: dataset, + table: table, } } @@ -37,7 +35,7 @@ func (ti TableIdentifier) Table() string { } func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { - return NewTableIdentifier(ti.projectID, ti.dataset, table, ti.uppercaseEscapedNames) + return NewTableIdentifier(ti.projectID, ti.dataset, table) } func (ti TableIdentifier) FullyQualifiedName() string { @@ -47,6 +45,6 @@ func (ti TableIdentifier) FullyQualifiedName() string { "`%s`.`%s`.%s", ti.projectID, ti.dataset, - sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}), + sql.EscapeNameIfNecessary(ti.table, false, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}), ) } diff --git a/clients/bigquery/tableid_test.go b/clients/bigquery/tableid_test.go index 154be275c..c60b1fea5 100644 --- a/clients/bigquery/tableid_test.go +++ b/clients/bigquery/tableid_test.go @@ -7,22 +7,19 @@ import ( ) func TestTableIdentifier_WithTable(t *testing.T) { - tableID := NewTableIdentifier("project", "dataset", "foo", true) + tableID := NewTableIdentifier("project", "dataset", "foo") tableID2 := tableID.WithTable("bar") typedTableID2, ok := tableID2.(TableIdentifier) assert.True(t, ok) assert.Equal(t, "project", typedTableID2.ProjectID()) assert.Equal(t, "dataset", typedTableID2.Dataset()) assert.Equal(t, "bar", tableID2.Table()) - assert.True(t, typedTableID2.uppercaseEscapedNames) } func TestTableIdentifier_FullyQualifiedName(t *testing.T) { // Table name that does not need escaping: - assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo", false).FullyQualifiedName()) - assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo", true).FullyQualifiedName()) + assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo").FullyQualifiedName()) // Table name that needs escaping: - assert.Equal(t, "`project`.`dataset`.`table`", NewTableIdentifier("project", "dataset", "table", false).FullyQualifiedName()) - assert.Equal(t, "`project`.`dataset`.`TABLE`", NewTableIdentifier("project", "dataset", "table", true).FullyQualifiedName()) + assert.Equal(t, "`project`.`dataset`.`table`", NewTableIdentifier("project", "dataset", "table").FullyQualifiedName()) } diff --git a/clients/mssql/store.go b/clients/mssql/store.go index 98490bb3b..b08c8bd14 100644 --- a/clients/mssql/store.go +++ b/clients/mssql/store.go @@ -35,7 +35,10 @@ func (s *Store) Label() constants.DestinationKind { } func (s *Store) ShouldUppercaseEscapedNames() bool { - return s.config.SharedDestinationConfig.UppercaseEscapedNames + if s.config.SharedDestinationConfig.UppercaseEscapedNames { + panic("UppercaseEscapedNames is not supported for MSSQL") + } + return false } func (s *Store) Merge(tableData *optimization.TableData) error { @@ -49,7 +52,7 @@ func (s *Store) Append(tableData *optimization.TableData) error { // specificIdentifierFor returns a MS SQL [TableIdentifier] for a [TopicConfig] + table name. func (s *Store) specificIdentifierFor(topicConfig kafkalib.TopicConfig, table string) TableIdentifier { - return NewTableIdentifier(getSchema(topicConfig.Schema), table, s.ShouldUppercaseEscapedNames()) + return NewTableIdentifier(getSchema(topicConfig.Schema), table) } // IdentifierFor returns a generic [types.TableIdentifier] interface for a [TopicConfig] + table name. diff --git a/clients/mssql/tableid.go b/clients/mssql/tableid.go index 4c44b0c0d..37531185a 100644 --- a/clients/mssql/tableid.go +++ b/clients/mssql/tableid.go @@ -9,13 +9,12 @@ import ( ) type TableIdentifier struct { - schema string - table string - uppercaseEscapedNames bool + schema string + table string } -func NewTableIdentifier(schema, table string, uppercaseEscapedNames bool) TableIdentifier { - return TableIdentifier{schema: schema, table: table, uppercaseEscapedNames: uppercaseEscapedNames} +func NewTableIdentifier(schema, table string) TableIdentifier { + return TableIdentifier{schema: schema, table: table} } func (ti TableIdentifier) Schema() string { @@ -27,13 +26,13 @@ func (ti TableIdentifier) Table() string { } func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { - return NewTableIdentifier(ti.schema, table, ti.uppercaseEscapedNames) + return NewTableIdentifier(ti.schema, table) } func (ti TableIdentifier) FullyQualifiedName() string { return fmt.Sprintf( "%s.%s", ti.schema, - sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}), + sql.EscapeNameIfNecessary(ti.table, false, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}), ) } diff --git a/clients/mssql/tableid_test.go b/clients/mssql/tableid_test.go index 65c70c290..76c69f54a 100644 --- a/clients/mssql/tableid_test.go +++ b/clients/mssql/tableid_test.go @@ -7,21 +7,18 @@ import ( ) func TestTableIdentifier_WithTable(t *testing.T) { - tableID := NewTableIdentifier("schema", "foo", true) + tableID := NewTableIdentifier("schema", "foo") tableID2 := tableID.WithTable("bar") typedTableID2, ok := tableID2.(TableIdentifier) assert.True(t, ok) assert.Equal(t, "schema", typedTableID2.Schema()) assert.Equal(t, "bar", tableID2.Table()) - assert.True(t, typedTableID2.uppercaseEscapedNames) } func TestTableIdentifier_FullyQualifiedName(t *testing.T) { // Table name that does not need escaping: - assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", false).FullyQualifiedName()) - assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", true).FullyQualifiedName()) + assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo").FullyQualifiedName()) // Table name that needs escaping: - assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table", false).FullyQualifiedName()) - assert.Equal(t, `schema."TABLE"`, NewTableIdentifier("schema", "table", true).FullyQualifiedName()) + assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table").FullyQualifiedName()) } diff --git a/clients/shared/append.go b/clients/shared/append.go index 5437efb43..07d38bb74 100644 --- a/clients/shared/append.go +++ b/clients/shared/append.go @@ -9,6 +9,7 @@ import ( "github.com/artie-labs/transfer/lib/destination/ddl" "github.com/artie-labs/transfer/lib/destination/types" "github.com/artie-labs/transfer/lib/optimization" + "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing/columns" ) @@ -35,7 +36,7 @@ func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, cf CreateTable: tableConfig.CreateTable(), ColumnOp: constants.Add, CdcTime: tableData.LatestCDCTs, - UppercaseEscNames: &cfg.SharedDestinationConfig.UppercaseEscapedNames, + UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()), Mode: tableData.Mode(), } diff --git a/clients/shared/merge.go b/clients/shared/merge.go index 6b1533659..735bd6ac5 100644 --- a/clients/shared/merge.go +++ b/clients/shared/merge.go @@ -42,7 +42,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg CreateTable: tableConfig.CreateTable(), ColumnOp: constants.Add, CdcTime: tableData.LatestCDCTs, - UppercaseEscNames: &cfg.SharedDestinationConfig.UppercaseEscapedNames, + UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()), Mode: tableData.Mode(), } @@ -61,7 +61,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg ColumnOp: constants.Delete, ContainOtherOperations: tableData.ContainOtherOperations(), CdcTime: tableData.LatestCDCTs, - UppercaseEscNames: &cfg.SharedDestinationConfig.UppercaseEscapedNames, + UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()), Mode: tableData.Mode(), } @@ -123,11 +123,11 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg TableID: tableID, SubQuery: subQuery, IdempotentKey: tableData.TopicConfig().IdempotentKey, - PrimaryKeys: tableData.PrimaryKeys(cfg.SharedDestinationConfig.UppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: dwh.Label()}), + PrimaryKeys: tableData.PrimaryKeys(dwh.ShouldUppercaseEscapedNames(), &sql.NameArgs{Escape: true, DestKind: dwh.Label()}), Columns: tableData.ReadOnlyInMemoryCols(), SoftDelete: tableData.TopicConfig().SoftDelete, DestKind: dwh.Label(), - UppercaseEscNames: &cfg.SharedDestinationConfig.UppercaseEscapedNames, + UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()), ContainsHardDeletes: ptr.ToBool(tableData.ContainsHardDeletes()), } diff --git a/clients/shared/utils.go b/clients/shared/utils.go index 19f13013d..dcc7bbe0f 100644 --- a/clients/shared/utils.go +++ b/clients/shared/utils.go @@ -32,7 +32,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col return fmt.Errorf("failed to escape default value: %w", err) } - uppercaseEscNames := cfg.SharedDestinationConfig.UppercaseEscapedNames + uppercaseEscNames := dwh.ShouldUppercaseEscapedNames() escapedCol := column.Name(uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: dwh.Label()}) // TODO: This is added because `default` is not technically a column that requires escaping, but it is required when it's in the where clause. diff --git a/lib/destination/ddl/ddl_bq_test.go b/lib/destination/ddl/ddl_bq_test.go index 2f11ffbb9..bdac9a6a4 100644 --- a/lib/destination/ddl/ddl_bq_test.go +++ b/lib/destination/ddl/ddl_bq_test.go @@ -107,7 +107,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() { } func (d *DDLTestSuite) TestAlterTableAddColumns() { - tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols", true) + tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols") fqName := tableID.FullyQualifiedName() ts := time.Now() existingColNameToKindDetailsMap := map[string]typing.KindDetails{ @@ -176,7 +176,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() { } func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() { - tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols", true) + tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols") fqName := tableID.FullyQualifiedName() ts := time.Now() existingColNameToKindDetailsMap := map[string]typing.KindDetails{ diff --git a/lib/destination/ddl/ddl_create_table_test.go b/lib/destination/ddl/ddl_create_table_test.go index 0dda6057c..a9ec46fcd 100644 --- a/lib/destination/ddl/ddl_create_table_test.go +++ b/lib/destination/ddl/ddl_create_table_test.go @@ -22,7 +22,7 @@ import ( ) func (d *DDLTestSuite) Test_CreateTable() { - bqTableID := bigquery.NewTableIdentifier("", "mock_dataset", "mock_table", true) + bqTableID := bigquery.NewTableIdentifier("", "mock_dataset", "mock_table") d.bigQueryStore.GetConfigMap().AddTableToConfig(bqTableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true)) snowflakeTableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table", true) diff --git a/lib/destination/ddl/ddl_temp_test.go b/lib/destination/ddl/ddl_temp_test.go index 0abca55cc..f4e80c2a5 100644 --- a/lib/destination/ddl/ddl_temp_test.go +++ b/lib/destination/ddl/ddl_temp_test.go @@ -96,7 +96,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() { } { // BigQuery - tableID := bigquery.NewTableIdentifier("db", "schema", "tempTableName", false) + tableID := bigquery.NewTableIdentifier("db", "schema", "tempTableName") d.bigQueryStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true)) bqTc := d.bigQueryStore.GetConfigMap().TableConfig(tableID) args := ddl.AlterTableArgs{ diff --git a/lib/sql/escape.go b/lib/sql/escape.go index 925de75dc..f640e7e76 100644 --- a/lib/sql/escape.go +++ b/lib/sql/escape.go @@ -17,19 +17,7 @@ type NameArgs struct { // symbolsToEscape are additional keywords that we need to escape var symbolsToEscape = []string{":"} -func EscapeNameIfNecessary(name string, uppercaseEscNames bool, args *NameArgs) string { - if args == nil || !args.Escape { - return name - } - - if needsEscaping(name, args.DestKind) { - return escapeName(name, uppercaseEscNames, args.DestKind) - } - - return name -} - -func needsEscaping(name string, destKind constants.DestinationKind) bool { +func needsEscaping(name string, uppercaseEscNames bool, destKind constants.DestinationKind) bool { var reservedKeywords []string if destKind == constants.Redshift { reservedKeywords = constants.RedshiftReservedKeywords @@ -61,6 +49,18 @@ func needsEscaping(name string, destKind constants.DestinationKind) bool { return needsEscaping } +func EscapeNameIfNecessary(name string, uppercaseEscNames bool, args *NameArgs) string { + if args == nil || !args.Escape { + return name + } + + if needsEscaping(name, uppercaseEscNames, args.DestKind) { + return escapeName(name, uppercaseEscNames, args.DestKind) + } + + return name +} + func escapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { if uppercaseEscNames { name = strings.ToUpper(name) diff --git a/lib/sql/escape_test.go b/lib/sql/escape_test.go index 7b70201d6..e16473306 100644 --- a/lib/sql/escape_test.go +++ b/lib/sql/escape_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestEscapeNameIfNecessary(t *testing.T) { +func TestEscapeName(t *testing.T) { type _testCase struct { name string nameToEscape string