From f4a204540f3d56ddd517ed1a345e0ccaa270ec08 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Sun, 21 Apr 2024 15:50:37 -0700 Subject: [PATCH] Add `uppercaseEscapedNames` to `TableIdentifier`s (#477) --- clients/bigquery/bigquery.go | 6 +++--- clients/bigquery/bigquery_test.go | 2 +- clients/bigquery/tableid.go | 22 +++++++++++++------- clients/bigquery/tableid_test.go | 22 ++++++++------------ clients/mssql/staging.go | 2 +- clients/mssql/store.go | 2 +- clients/mssql/store_test.go | 4 ++-- clients/mssql/tableid.go | 15 ++++++------- clients/mssql/tableid_test.go | 22 ++++++++------------ clients/redshift/redshift.go | 6 +++--- clients/redshift/redshift_test.go | 2 +- clients/redshift/staging.go | 2 +- clients/redshift/tableid.go | 15 ++++++------- clients/redshift/tableid_test.go | 22 ++++++++------------ clients/redshift/writes.go | 5 +++-- clients/s3/s3.go | 2 +- clients/s3/tableid.go | 2 +- clients/s3/tableid_test.go | 3 +-- clients/shared/merge.go | 4 ++-- clients/shared/table_config.go | 2 +- clients/shared/table_config_test.go | 2 +- clients/snowflake/snowflake.go | 7 +++---- clients/snowflake/snowflake_test.go | 4 ++-- clients/snowflake/staging.go | 2 +- clients/snowflake/staging_test.go | 6 +++--- clients/snowflake/tableid.go | 22 +++++++++++++------- clients/snowflake/tableid_test.go | 22 ++++++++------------ lib/destination/ddl/ddl.go | 2 +- lib/destination/ddl/ddl_alter_delete_test.go | 6 +++--- lib/destination/ddl/ddl_bq_test.go | 12 +++++------ lib/destination/ddl/ddl_create_table_test.go | 12 +++++------ lib/destination/ddl/ddl_sflk_test.go | 10 ++++----- lib/destination/ddl/ddl_temp_test.go | 4 ++-- lib/destination/types/types.go | 2 +- 34 files changed, 136 insertions(+), 139 deletions(-) diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index 8558e8199..b370975e7 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -80,7 +80,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo } func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier { - return NewTableIdentifier(s.config.BigQuery.ProjectID, topicConfig.Database, table) + return NewTableIdentifier(s.config.BigQuery.ProjectID, topicConfig.Database, table, s.ShouldUppercaseEscapedNames()) } func (s *Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTableConfig, error) { @@ -135,7 +135,7 @@ func tableRelName(fqName string) (string, error) { func (s *Store) putTable(ctx context.Context, dataset string, tableID types.TableIdentifier, rows []*Row) error { // TODO: [tableID] has [Dataset] on it, don't need to pass it along. - tableName := tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) + tableName := tableID.FullyQualifiedName() // TODO: Can probably do `tableName := tableID.Table()` here. relTableName, err := tableRelName(tableName) if err != nil { @@ -157,7 +157,7 @@ func (s *Store) putTable(ctx context.Context, dataset string, tableID types.Tabl } func (s *Store) Dedupe(tableID types.TableIdentifier) error { - fqTableName := tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) + fqTableName := tableID.FullyQualifiedName() _, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName)) return err } diff --git a/clients/bigquery/bigquery_test.go b/clients/bigquery/bigquery_test.go index 051e551d7..c6b6d25c9 100644 --- a/clients/bigquery/bigquery_test.go +++ b/clients/bigquery/bigquery_test.go @@ -47,6 +47,6 @@ func TestTempTableName(t *testing.T) { store := &Store{config: config.Config{BigQuery: &config.BigQuery{ProjectID: "123454321"}}} tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table") tableID := store.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(store.ShouldUppercaseEscapedNames()) + tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName() assert.Equal(t, "`123454321`.`db`.table___artie_sUfFiX", trimTTL(tempTableName)) } diff --git a/clients/bigquery/tableid.go b/clients/bigquery/tableid.go index ad324f447..f63429163 100644 --- a/clients/bigquery/tableid.go +++ b/clients/bigquery/tableid.go @@ -9,13 +9,19 @@ import ( ) type TableIdentifier struct { - projectID string - dataset string - table string + projectID string + dataset string + table string + uppercaseEscapedNames bool } -func NewTableIdentifier(projectID, dataset, table string) TableIdentifier { - return TableIdentifier{projectID: projectID, dataset: dataset, table: table} +func NewTableIdentifier(projectID, dataset, table string, uppercaseEscapedNames bool) TableIdentifier { + return TableIdentifier{ + projectID: projectID, + dataset: dataset, + table: table, + uppercaseEscapedNames: uppercaseEscapedNames, + } } func (ti TableIdentifier) ProjectID() string { @@ -31,16 +37,16 @@ func (ti TableIdentifier) Table() string { } func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { - return NewTableIdentifier(ti.projectID, ti.dataset, table) + return NewTableIdentifier(ti.projectID, ti.dataset, table, ti.uppercaseEscapedNames) } -func (ti TableIdentifier) FullyQualifiedName(uppercaseEscNames bool) string { +func (ti TableIdentifier) FullyQualifiedName() string { // The fully qualified name for BigQuery is: project_id.dataset.tableName. // We are escaping the project_id and dataset because there could be special characters. return fmt.Sprintf( "`%s`.`%s`.%s", ti.projectID, ti.dataset, - sql.EscapeName(ti.table, uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}), + sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}), ) } diff --git a/clients/bigquery/tableid_test.go b/clients/bigquery/tableid_test.go index 1e989c5a8..154be275c 100644 --- a/clients/bigquery/tableid_test.go +++ b/clients/bigquery/tableid_test.go @@ -7,26 +7,22 @@ import ( ) func TestTableIdentifier_WithTable(t *testing.T) { - tableID := NewTableIdentifier("project", "dataset", "foo") + tableID := NewTableIdentifier("project", "dataset", "foo", true) tableID2 := tableID.WithTable("bar") typedTableID2, ok := tableID2.(TableIdentifier) assert.True(t, ok) assert.Equal(t, "project", typedTableID2.ProjectID()) assert.Equal(t, "dataset", typedTableID2.Dataset()) assert.Equal(t, "bar", tableID2.Table()) + assert.True(t, typedTableID2.uppercaseEscapedNames) } func TestTableIdentifier_FullyQualifiedName(t *testing.T) { - { - // Table name that does not need escaping: - tableID := NewTableIdentifier("project", "dataset", "foo") - assert.Equal(t, "`project`.`dataset`.foo", tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, "`project`.`dataset`.foo", tableID.FullyQualifiedName(true), "escaped + upper") - } - { - // Table name that needs escaping: - tableID := NewTableIdentifier("project", "dataset", "table") - assert.Equal(t, "`project`.`dataset`.`table`", tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, "`project`.`dataset`.`TABLE`", tableID.FullyQualifiedName(true), "escaped + upper") - } + // Table name that does not need escaping: + assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo", false).FullyQualifiedName()) + assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo", true).FullyQualifiedName()) + + // Table name that needs escaping: + assert.Equal(t, "`project`.`dataset`.`table`", NewTableIdentifier("project", "dataset", "table", false).FullyQualifiedName()) + assert.Equal(t, "`project`.`dataset`.`TABLE`", NewTableIdentifier("project", "dataset", "table", true).FullyQualifiedName()) } diff --git a/clients/mssql/staging.go b/clients/mssql/staging.go index 93b1facc9..33d69f994 100644 --- a/clients/mssql/staging.go +++ b/clients/mssql/staging.go @@ -13,7 +13,7 @@ import ( ) func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error { - tempTableName := tempTableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) + tempTableName := tempTableID.FullyQualifiedName() if createTempTable { tempAlterTableArgs := ddl.AlterTableArgs{ diff --git a/clients/mssql/store.go b/clients/mssql/store.go index f2ab47a4c..98490bb3b 100644 --- a/clients/mssql/store.go +++ b/clients/mssql/store.go @@ -49,7 +49,7 @@ func (s *Store) Append(tableData *optimization.TableData) error { // specificIdentifierFor returns a MS SQL [TableIdentifier] for a [TopicConfig] + table name. func (s *Store) specificIdentifierFor(topicConfig kafkalib.TopicConfig, table string) TableIdentifier { - return NewTableIdentifier(getSchema(topicConfig.Schema), table) + return NewTableIdentifier(getSchema(topicConfig.Schema), table, s.ShouldUppercaseEscapedNames()) } // IdentifierFor returns a generic [types.TableIdentifier] interface for a [TopicConfig] + table name. diff --git a/clients/mssql/store_test.go b/clients/mssql/store_test.go index 4c8ec0a2a..0a9096084 100644 --- a/clients/mssql/store_test.go +++ b/clients/mssql/store_test.go @@ -28,14 +28,14 @@ func TestTempTableName(t *testing.T) { // Schema is "schema": tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table") tableID := store.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(store.ShouldUppercaseEscapedNames()) + tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName() assert.Equal(t, "schema.table___artie_sUfFiX", trimTTL(tempTableName)) } { // Schema is "public" -> "dbo": tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "public"}, "table") tableID := store.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(store.ShouldUppercaseEscapedNames()) + tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName() assert.Equal(t, "dbo.table___artie_sUfFiX", trimTTL(tempTableName)) } } diff --git a/clients/mssql/tableid.go b/clients/mssql/tableid.go index 9a5d5bf37..9600d62b9 100644 --- a/clients/mssql/tableid.go +++ b/clients/mssql/tableid.go @@ -9,12 +9,13 @@ import ( ) type TableIdentifier struct { - schema string - table string + schema string + table string + uppercaseEscapedNames bool } -func NewTableIdentifier(schema, table string) TableIdentifier { - return TableIdentifier{schema: schema, table: table} +func NewTableIdentifier(schema, table string, uppercaseEscapedNames bool) TableIdentifier { + return TableIdentifier{schema: schema, table: table, uppercaseEscapedNames: uppercaseEscapedNames} } func (ti TableIdentifier) Schema() string { @@ -26,13 +27,13 @@ func (ti TableIdentifier) Table() string { } func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { - return NewTableIdentifier(ti.schema, table) + return NewTableIdentifier(ti.schema, table, ti.uppercaseEscapedNames) } -func (ti TableIdentifier) FullyQualifiedName(uppercaseEscNames bool) string { +func (ti TableIdentifier) FullyQualifiedName() string { return fmt.Sprintf( "%s.%s", ti.schema, - sql.EscapeName(ti.table, uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}), + sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}), ) } diff --git a/clients/mssql/tableid_test.go b/clients/mssql/tableid_test.go index e0b8a0d45..65c70c290 100644 --- a/clients/mssql/tableid_test.go +++ b/clients/mssql/tableid_test.go @@ -7,25 +7,21 @@ import ( ) func TestTableIdentifier_WithTable(t *testing.T) { - tableID := NewTableIdentifier("schema", "foo") + tableID := NewTableIdentifier("schema", "foo", true) tableID2 := tableID.WithTable("bar") typedTableID2, ok := tableID2.(TableIdentifier) assert.True(t, ok) assert.Equal(t, "schema", typedTableID2.Schema()) assert.Equal(t, "bar", tableID2.Table()) + assert.True(t, typedTableID2.uppercaseEscapedNames) } func TestTableIdentifier_FullyQualifiedName(t *testing.T) { - { - // Table name that does not need escaping: - tableID := NewTableIdentifier("schema", "foo") - assert.Equal(t, "schema.foo", tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, "schema.foo", tableID.FullyQualifiedName(true), "escaped + upper") - } - { - // Table name that needs escaping: - tableID := NewTableIdentifier("schema", "table") - assert.Equal(t, `schema."table"`, tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, `schema."TABLE"`, tableID.FullyQualifiedName(true), "escaped + upper") - } + // Table name that does not need escaping: + assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", false).FullyQualifiedName()) + assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", true).FullyQualifiedName()) + + // Table name that needs escaping: + assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table", false).FullyQualifiedName()) + assert.Equal(t, `schema."TABLE"`, NewTableIdentifier("schema", "table", true).FullyQualifiedName()) } diff --git a/clients/redshift/redshift.go b/clients/redshift/redshift.go index cd942f223..81dcab916 100644 --- a/clients/redshift/redshift.go +++ b/clients/redshift/redshift.go @@ -30,7 +30,7 @@ type Store struct { } func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier { - return NewTableIdentifier(topicConfig.Schema, table) + return NewTableIdentifier(topicConfig.Schema, table, s.ShouldUppercaseEscapedNames()) } func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap { @@ -98,8 +98,8 @@ WHERE } func (s *Store) Dedupe(tableID types.TableIdentifier) error { - fqTableName := tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) - stagingTableName := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))).FullyQualifiedName(s.ShouldUppercaseEscapedNames()) + fqTableName := tableID.FullyQualifiedName() + stagingTableName := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))).FullyQualifiedName() query := fmt.Sprintf(` CREATE TABLE %s AS SELECT DISTINCT * FROM %s; diff --git a/clients/redshift/redshift_test.go b/clients/redshift/redshift_test.go index 9fbdcb769..63dff6f64 100644 --- a/clients/redshift/redshift_test.go +++ b/clients/redshift/redshift_test.go @@ -25,6 +25,6 @@ func TestTempTableName(t *testing.T) { tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table") tableID := (&Store{}).IdentifierFor(tableData.TopicConfig(), tableData.Name()) - tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(false) + tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName() assert.Equal(t, "schema.table___artie_sUfFiX", trimTTL(tempTableName)) } diff --git a/clients/redshift/staging.go b/clients/redshift/staging.go index 781ff8e09..2476bfddb 100644 --- a/clients/redshift/staging.go +++ b/clients/redshift/staging.go @@ -17,7 +17,7 @@ import ( ) func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, _ bool) error { - tempTableName := tempTableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) + tempTableName := tempTableID.FullyQualifiedName() // Redshift always creates a temporary table. tempAlterTableArgs := ddl.AlterTableArgs{ diff --git a/clients/redshift/tableid.go b/clients/redshift/tableid.go index 07c522055..5d21ef286 100644 --- a/clients/redshift/tableid.go +++ b/clients/redshift/tableid.go @@ -9,12 +9,13 @@ import ( ) type TableIdentifier struct { - schema string - table string + schema string + table string + uppercaseEscapedNames bool } -func NewTableIdentifier(schema, table string) TableIdentifier { - return TableIdentifier{schema: schema, table: table} +func NewTableIdentifier(schema, table string, uppercaseEscapedNames bool) TableIdentifier { + return TableIdentifier{schema: schema, table: table, uppercaseEscapedNames: uppercaseEscapedNames} } func (ti TableIdentifier) Schema() string { @@ -26,15 +27,15 @@ func (ti TableIdentifier) Table() string { } func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { - return NewTableIdentifier(ti.schema, table) + return NewTableIdentifier(ti.schema, table, ti.uppercaseEscapedNames) } -func (ti TableIdentifier) FullyQualifiedName(uppercaseEscNames bool) string { +func (ti TableIdentifier) FullyQualifiedName() string { // Redshift is Postgres compatible, so when establishing a connection, we'll specify a database. // Thus, we only need to specify schema and table name here. return fmt.Sprintf( "%s.%s", ti.schema, - sql.EscapeName(ti.table, uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: constants.Redshift}), + sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Redshift}), ) } diff --git a/clients/redshift/tableid_test.go b/clients/redshift/tableid_test.go index 14beb97ee..373b31b43 100644 --- a/clients/redshift/tableid_test.go +++ b/clients/redshift/tableid_test.go @@ -7,25 +7,21 @@ import ( ) func TestTableIdentifier_WithTable(t *testing.T) { - tableID := NewTableIdentifier("schema", "foo") + tableID := NewTableIdentifier("schema", "foo", true) tableID2 := tableID.WithTable("bar") typedTableID2, ok := tableID2.(TableIdentifier) assert.True(t, ok) assert.Equal(t, "schema", typedTableID2.Schema()) assert.Equal(t, "bar", tableID2.Table()) + assert.True(t, typedTableID2.uppercaseEscapedNames) } func TestTableIdentifier_FullyQualifiedName(t *testing.T) { - { - // Table name that does not need escaping: - tableID := NewTableIdentifier("schema", "foo") - assert.Equal(t, "schema.foo", tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, "schema.foo", tableID.FullyQualifiedName(true), "escaped + upper") - } - { - // Table name that needs escaping: - tableID := NewTableIdentifier("schema", "table") - assert.Equal(t, `schema."table"`, tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, `schema."TABLE"`, tableID.FullyQualifiedName(true), "escaped + upper") - } + // Table name that does not need escaping: + assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", false).FullyQualifiedName()) + assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", true).FullyQualifiedName()) + + // Table name that needs escaping: + assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table", false).FullyQualifiedName()) + assert.Equal(t, `schema."TABLE"`, NewTableIdentifier("schema", "table", true).FullyQualifiedName()) } diff --git a/clients/redshift/writes.go b/clients/redshift/writes.go index aa1d31964..167e34da2 100644 --- a/clients/redshift/writes.go +++ b/clients/redshift/writes.go @@ -14,12 +14,13 @@ func (s *Store) Append(tableData *optimization.TableData) error { // Redshift is slightly different, we'll load and create the temporary table via shared.Append // Then, we'll invoke `ALTER TABLE target APPEND FROM staging` to combine the diffs. temporaryTableID := shared.TempTableID(tableID, tableData.TempTableSuffix()) - temporaryTableName := temporaryTableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) if err := shared.Append(s, tableData, s.config, types.AppendOpts{TempTableID: temporaryTableID}); err != nil { return err } - _, err := s.Exec(fmt.Sprintf(`ALTER TABLE %s APPEND FROM %s;`, tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()), temporaryTableName)) + _, err := s.Exec( + fmt.Sprintf(`ALTER TABLE %s APPEND FROM %s;`, tableID.FullyQualifiedName(), temporaryTableID.FullyQualifiedName()), + ) return err } diff --git a/clients/s3/s3.go b/clients/s3/s3.go index bed9539a3..6e85c429f 100644 --- a/clients/s3/s3.go +++ b/clients/s3/s3.go @@ -59,7 +59,7 @@ func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) ty // > optionalPrefix/fullyQualifiedTableName/YYYY-MM-DD func (s *Store) ObjectPrefix(tableData *optimization.TableData) string { tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - fqTableName := tableID.FullyQualifiedName(false) + fqTableName := tableID.FullyQualifiedName() yyyyMMDDFormat := tableData.LatestCDCTs.Format(ext.PostgresDateFormat) if len(s.config.S3.OptionalPrefix) > 0 { diff --git a/clients/s3/tableid.go b/clients/s3/tableid.go index bf17fa293..c9e692600 100644 --- a/clients/s3/tableid.go +++ b/clients/s3/tableid.go @@ -32,7 +32,7 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { return NewTableIdentifier(ti.database, ti.schema, table) } -func (ti TableIdentifier) FullyQualifiedName(_ bool) string { +func (ti TableIdentifier) FullyQualifiedName() string { // S3 should be db.schema.tableName, but we don't need to escape, since it's not a SQL db. return fmt.Sprintf("%s.%s.%s", ti.database, ti.schema, ti.table) } diff --git a/clients/s3/tableid_test.go b/clients/s3/tableid_test.go index bc5162948..fd352be18 100644 --- a/clients/s3/tableid_test.go +++ b/clients/s3/tableid_test.go @@ -19,6 +19,5 @@ func TestTableIdentifier_WithTable(t *testing.T) { func TestTableIdentifier_FullyQualifiedName(t *testing.T) { // S3 doesn't escape the table name. tableID := NewTableIdentifier("database", "schema", "table") - assert.Equal(t, "database.schema.table", tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, "database.schema.table", tableID.FullyQualifiedName(true), "escaped + upper") + assert.Equal(t, "database.schema.table", tableID.FullyQualifiedName()) } diff --git a/clients/shared/merge.go b/clients/shared/merge.go index aa2617d22..e5b53fdd9 100644 --- a/clients/shared/merge.go +++ b/clients/shared/merge.go @@ -35,7 +35,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg tableData.TopicConfig().IncludeDatabaseUpdatedAt, tableData.Mode()) tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - fqName := tableID.FullyQualifiedName(dwh.ShouldUppercaseEscapedNames()) + fqName := tableID.FullyQualifiedName() createAlterTableArgs := ddl.AlterTableArgs{ Dwh: dwh, Tc: tableConfig, @@ -73,7 +73,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg tableConfig.AuditColumnsToDelete(srcKeysMissing) tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...) temporaryTableID := TempTableID(dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name()), tableData.TempTableSuffix()) - temporaryTableName := temporaryTableID.FullyQualifiedName(dwh.ShouldUppercaseEscapedNames()) + temporaryTableName := temporaryTableID.FullyQualifiedName() if err = dwh.PrepareTemporaryTable(tableData, tableConfig, temporaryTableID, types.AdditionalSettings{}, true); err != nil { return fmt.Errorf("failed to prepare temporary table: %w", err) } diff --git a/clients/shared/table_config.go b/clients/shared/table_config.go index e82971d93..e86393acf 100644 --- a/clients/shared/table_config.go +++ b/clients/shared/table_config.go @@ -41,7 +41,7 @@ func (g GetTableCfgArgs) ShouldParseComment(comment string) bool { } func (g GetTableCfgArgs) GetTableConfig() (*types.DwhTableConfig, error) { - fqName := g.TableID.FullyQualifiedName(g.Dwh.ShouldUppercaseEscapedNames()) + fqName := g.TableID.FullyQualifiedName() // Check if it already exists in cache tableConfig := g.ConfigMap.TableConfig(fqName) diff --git a/clients/shared/table_config_test.go b/clients/shared/table_config_test.go index 7d2bfcba2..59ea01de0 100644 --- a/clients/shared/table_config_test.go +++ b/clients/shared/table_config_test.go @@ -80,7 +80,7 @@ type MockTableIdentifier struct{ fqName string } func (MockTableIdentifier) Table() string { panic("not implemented") } func (MockTableIdentifier) WithTable(table string) types.TableIdentifier { panic("not implemented") } -func (m MockTableIdentifier) FullyQualifiedName(_ bool) string { return m.fqName } +func (m MockTableIdentifier) FullyQualifiedName() string { return m.fqName } func TestGetTableConfig(t *testing.T) { // Return early because table is found in configMap. diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index 1c2ed2100..1c3152310 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -32,17 +32,16 @@ const ( ) func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier { - return NewTableIdentifier(topicConfig.Database, topicConfig.Schema, table) + return NewTableIdentifier(topicConfig.Database, topicConfig.Schema, table, s.ShouldUppercaseEscapedNames()) } func (s *Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTableConfig, error) { tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - fqName := tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) return shared.GetTableCfgArgs{ Dwh: s, TableID: tableID, ConfigMap: s.configMap, - Query: fmt.Sprintf("DESC TABLE %s;", fqName), + Query: fmt.Sprintf("DESC TABLE %s;", tableID.FullyQualifiedName()), ColumnNameLabel: describeNameCol, ColumnTypeLabel: describeTypeCol, ColumnDescLabel: describeCommentCol, @@ -121,7 +120,7 @@ func (s *Store) reestablishConnection() error { } func (s *Store) Dedupe(tableID types.TableIdentifier) error { - fqTableName := tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) + fqTableName := tableID.FullyQualifiedName() _, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName)) return err } diff --git a/clients/snowflake/snowflake_test.go b/clients/snowflake/snowflake_test.go index db31e475d..3a8fa95fc 100644 --- a/clients/snowflake/snowflake_test.go +++ b/clients/snowflake/snowflake_test.go @@ -24,7 +24,7 @@ import ( func (s *SnowflakeTestSuite) fullyQualifiedName(tableData *optimization.TableData) string { tableID := s.stageStore.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - return tableID.FullyQualifiedName(s.stageStore.config.SharedDestinationConfig.UppercaseEscapedNames) + return tableID.FullyQualifiedName() } func (s *SnowflakeTestSuite) TestExecuteMergeNilEdgeCase() { @@ -307,6 +307,6 @@ func TestTempTableName(t *testing.T) { tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table") tableID := (&Store{}).IdentifierFor(tableData.TopicConfig(), tableData.Name()) - tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(false) + tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName() assert.Equal(t, "db.schema.table___artie_sUfFiX", trimTTL(tempTableName)) } diff --git a/clients/snowflake/staging.go b/clients/snowflake/staging.go index 7c359154a..747fe25b5 100644 --- a/clients/snowflake/staging.go +++ b/clients/snowflake/staging.go @@ -48,7 +48,7 @@ func castColValStaging(colVal any, colKind columns.Column, additionalDateFmts [] } func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error { - tempTableName := tempTableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()) + tempTableName := tempTableID.FullyQualifiedName() if createTempTable { tempAlterTableArgs := ddl.AlterTableArgs{ diff --git a/clients/snowflake/staging_test.go b/clients/snowflake/staging_test.go index e8079d930..16e4f2ab0 100644 --- a/clients/snowflake/staging_test.go +++ b/clients/snowflake/staging_test.go @@ -130,12 +130,12 @@ func generateTableData(rows int) (TableIdentifier, *optimization.TableData) { td.InsertRow(key, rowData, false) } - return NewTableIdentifier("database", "schema", randomTableName), td + return NewTableIdentifier("database", "schema", randomTableName, true), td } func (s *SnowflakeTestSuite) TestPrepareTempTable() { tempTableID, tableData := generateTableData(10) - tempTableName := tempTableID.FullyQualifiedName(true) + tempTableName := tempTableID.FullyQualifiedName() s.stageStore.GetConfigMap().AddTableToConfig(tempTableName, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true)) sflkTc := s.stageStore.GetConfigMap().TableConfig(tempTableName) @@ -170,7 +170,7 @@ func (s *SnowflakeTestSuite) TestPrepareTempTable() { func (s *SnowflakeTestSuite) TestLoadTemporaryTable() { tempTableID, tableData := generateTableData(100) - fp, err := s.stageStore.writeTemporaryTableFile(tableData, tempTableID.FullyQualifiedName(true)) + fp, err := s.stageStore.writeTemporaryTableFile(tableData, tempTableID.FullyQualifiedName()) assert.NoError(s.T(), err) // Read the CSV and confirm. csvfile, err := os.Open(fp) diff --git a/clients/snowflake/tableid.go b/clients/snowflake/tableid.go index 4d1b814b0..8c313d42f 100644 --- a/clients/snowflake/tableid.go +++ b/clients/snowflake/tableid.go @@ -9,13 +9,19 @@ import ( ) type TableIdentifier struct { - database string - schema string - table string + database string + schema string + table string + uppercaseEscapedNames bool } -func NewTableIdentifier(database, schema, table string) TableIdentifier { - return TableIdentifier{database: database, schema: schema, table: table} +func NewTableIdentifier(database, schema, table string, uppercaseEscapedNames bool) TableIdentifier { + return TableIdentifier{ + database: database, + schema: schema, + table: table, + uppercaseEscapedNames: uppercaseEscapedNames, + } } func (ti TableIdentifier) Database() string { @@ -31,14 +37,14 @@ func (ti TableIdentifier) Table() string { } func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { - return NewTableIdentifier(ti.database, ti.schema, table) + return NewTableIdentifier(ti.database, ti.schema, table, ti.uppercaseEscapedNames) } -func (ti TableIdentifier) FullyQualifiedName(uppercaseEscNames bool) string { +func (ti TableIdentifier) FullyQualifiedName() string { return fmt.Sprintf( "%s.%s.%s", ti.database, ti.schema, - sql.EscapeName(ti.table, uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: constants.Snowflake}), + sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Snowflake}), ) } diff --git a/clients/snowflake/tableid_test.go b/clients/snowflake/tableid_test.go index c97efee81..eed38849e 100644 --- a/clients/snowflake/tableid_test.go +++ b/clients/snowflake/tableid_test.go @@ -7,26 +7,22 @@ import ( ) func TestTableIdentifier_WithTable(t *testing.T) { - tableID := NewTableIdentifier("database", "schema", "foo") + tableID := NewTableIdentifier("database", "schema", "foo", true) tableID2 := tableID.WithTable("bar") typedTableID2, ok := tableID2.(TableIdentifier) assert.True(t, ok) assert.Equal(t, "database", typedTableID2.Database()) assert.Equal(t, "schema", typedTableID2.Schema()) assert.Equal(t, "bar", tableID2.Table()) + assert.True(t, typedTableID2.uppercaseEscapedNames) } func TestTableIdentifier_FullyQualifiedName(t *testing.T) { - { - // Table name that does not need escaping: - tableID := NewTableIdentifier("database", "schema", "foo") - assert.Equal(t, "database.schema.foo", tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, "database.schema.foo", tableID.FullyQualifiedName(true), "escaped + upper") - } - { - // Table name that needs escaping: - tableID := NewTableIdentifier("database", "schema", "table") - assert.Equal(t, `database.schema."table"`, tableID.FullyQualifiedName(false), "escaped") - assert.Equal(t, `database.schema."TABLE"`, tableID.FullyQualifiedName(true), "escaped + upper") - } + // Table name that does not need escaping: + assert.Equal(t, "database.schema.foo", NewTableIdentifier("database", "schema", "foo", false).FullyQualifiedName()) + assert.Equal(t, "database.schema.foo", NewTableIdentifier("database", "schema", "foo", true).FullyQualifiedName()) + + // Table name that needs escaping: + assert.Equal(t, `database.schema."table"`, NewTableIdentifier("database", "schema", "table", false).FullyQualifiedName()) + assert.Equal(t, `database.schema."TABLE"`, NewTableIdentifier("database", "schema", "table", true).FullyQualifiedName()) } diff --git a/lib/destination/ddl/ddl.go b/lib/destination/ddl/ddl.go index 1c8ffadac..21554050f 100644 --- a/lib/destination/ddl/ddl.go +++ b/lib/destination/ddl/ddl.go @@ -134,7 +134,7 @@ func (a AlterTableArgs) AlterTable(cols ...columns.Column) error { colSQLParts = append(colSQLParts, pkStatement) } - fqTableName := a.TableID.FullyQualifiedName(a.Dwh.ShouldUppercaseEscapedNames()) + fqTableName := a.TableID.FullyQualifiedName() var err error if a.CreateTable { diff --git a/lib/destination/ddl/ddl_alter_delete_test.go b/lib/destination/ddl/ddl_alter_delete_test.go index f145eee69..dfd2e6f1a 100644 --- a/lib/destination/ddl/ddl_alter_delete_test.go +++ b/lib/destination/ddl/ddl_alter_delete_test.go @@ -35,13 +35,13 @@ func (d *DDLTestSuite) TestAlterDelete_Complete() { originalColumnLength := len(cols.GetColumns()) bqTableID := d.bigQueryStore.IdentifierFor(td.TopicConfig(), td.Name()) - bqName := bqTableID.FullyQualifiedName(false) + bqName := bqTableID.FullyQualifiedName() redshiftTableID := d.redshiftStore.IdentifierFor(td.TopicConfig(), td.Name()) - redshiftName := redshiftTableID.FullyQualifiedName(false) + redshiftName := redshiftTableID.FullyQualifiedName() snowflakeTableID := d.snowflakeStagesStore.IdentifierFor(td.TopicConfig(), td.Name()) - snowflakeName := snowflakeTableID.FullyQualifiedName(false) + snowflakeName := snowflakeTableID.FullyQualifiedName() // Testing 3 scenarios here // 1. DropDeletedColumns = false, ContainOtherOperations = true, don't delete ever. diff --git a/lib/destination/ddl/ddl_bq_test.go b/lib/destination/ddl/ddl_bq_test.go index e564a3639..b5342d7c6 100644 --- a/lib/destination/ddl/ddl_bq_test.go +++ b/lib/destination/ddl/ddl_bq_test.go @@ -47,7 +47,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() { } tableID := d.bigQueryStore.IdentifierFor(td.TopicConfig(), td.Name()) - fqName := tableID.FullyQualifiedName(false) + fqName := tableID.FullyQualifiedName() originalColumnLength := len(cols.GetColumns()) d.bigQueryStore.GetConfigMap().AddTableToConfig(fqName, types.NewDwhTableConfig(&cols, nil, false, true)) tc := d.bigQueryStore.GetConfigMap().TableConfig(fqName) @@ -107,8 +107,8 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() { } func (d *DDLTestSuite) TestAlterTableAddColumns() { - tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols") - fqName := tableID.FullyQualifiedName(true) + tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols", true) + fqName := tableID.FullyQualifiedName() ts := time.Now() existingColNameToKindDetailsMap := map[string]typing.KindDetails{ "foo": typing.String, @@ -176,8 +176,8 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() { } func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() { - tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols") - fqName := tableID.FullyQualifiedName(true) + tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols", true) + fqName := tableID.FullyQualifiedName() ts := time.Now() existingColNameToKindDetailsMap := map[string]typing.KindDetails{ "foo": typing.String, @@ -251,7 +251,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuerySafety() { } tableID := d.bigQueryStore.IdentifierFor(td.TopicConfig(), td.Name()) - fqName := tableID.FullyQualifiedName(false) + fqName := tableID.FullyQualifiedName() originalColumnLength := len(columnNameToKindDetailsMap) d.bigQueryStore.GetConfigMap().AddTableToConfig(fqName, types.NewDwhTableConfig(&cols, nil, false, false)) tc := d.bigQueryStore.GetConfigMap().TableConfig(fqName) diff --git a/lib/destination/ddl/ddl_create_table_test.go b/lib/destination/ddl/ddl_create_table_test.go index 2a85b883f..fd5bcbc45 100644 --- a/lib/destination/ddl/ddl_create_table_test.go +++ b/lib/destination/ddl/ddl_create_table_test.go @@ -22,12 +22,12 @@ import ( ) func (d *DDLTestSuite) Test_CreateTable() { - bqTableID := bigquery.NewTableIdentifier("", "mock_dataset", "mock_table") - bqFqName := bqTableID.FullyQualifiedName(true) + bqTableID := bigquery.NewTableIdentifier("", "mock_dataset", "mock_table", true) + bqFqName := bqTableID.FullyQualifiedName() d.bigQueryStore.GetConfigMap().AddTableToConfig(bqFqName, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true)) - snowflakeTableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table") - snowflakeFqName := snowflakeTableID.FullyQualifiedName(true) + snowflakeTableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table", true) + snowflakeFqName := snowflakeTableID.FullyQualifiedName() d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(snowflakeFqName, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true)) type dwhToTableConfig struct { @@ -68,7 +68,7 @@ func (d *DDLTestSuite) Test_CreateTable() { assert.Equal(d.T(), 1, dwhTc._fakeStore.ExecCallCount()) query, _ := dwhTc._fakeStore.ExecArgsForCall(0) - assert.Equal(d.T(), query, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (name string)", dwhTc._tableID.FullyQualifiedName(true)), query) + assert.Equal(d.T(), query, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (name string)", dwhTc._tableID.FullyQualifiedName()), query) assert.Equal(d.T(), false, dwhTc._tableConfig.CreateTable()) } } @@ -116,7 +116,7 @@ func (d *DDLTestSuite) TestCreateTable() { } for index, testCase := range testCases { - tableID := snowflake.NewTableIdentifier("demo", "public", "experiments") + tableID := snowflake.NewTableIdentifier("demo", "public", "experiments", false) fqTable := "demo.public.experiments" d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(fqTable, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true)) tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(fqTable) diff --git a/lib/destination/ddl/ddl_sflk_test.go b/lib/destination/ddl/ddl_sflk_test.go index 9257edd49..8568fc7c8 100644 --- a/lib/destination/ddl/ddl_sflk_test.go +++ b/lib/destination/ddl/ddl_sflk_test.go @@ -31,7 +31,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() { columns.NewColumn("select", typing.String), } - tableID := snowflake.NewTableIdentifier("shop", "public", "complex_columns") + tableID := snowflake.NewTableIdentifier("shop", "public", "complex_columns", true) fqTable := "shop.public.complex_columns" d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(fqTable, types.NewDwhTableConfig(&columns.Columns{}, nil, false, true)) tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(fqTable) @@ -67,7 +67,7 @@ func (d *DDLTestSuite) TestAlterIdempotency() { columns.NewColumn("start", typing.String), } - tableID := snowflake.NewTableIdentifier("shop", "public", "orders") + tableID := snowflake.NewTableIdentifier("shop", "public", "orders", true) fqTable := "shop.public.orders" d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(fqTable, types.NewDwhTableConfig(&columns.Columns{}, nil, false, true)) tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(fqTable) @@ -99,7 +99,7 @@ func (d *DDLTestSuite) TestAlterTableAdd() { columns.NewColumn("start", typing.String), } - tableID := snowflake.NewTableIdentifier("shop", "public", "orders") + tableID := snowflake.NewTableIdentifier("shop", "public", "orders", true) fqTable := "shop.public.orders" d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(fqTable, types.NewDwhTableConfig(&columns.Columns{}, nil, false, true)) tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(fqTable) @@ -143,7 +143,7 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() { columns.NewColumn("start", typing.String), } - tableID := snowflake.NewTableIdentifier("shop", "public", "users") + tableID := snowflake.NewTableIdentifier("shop", "public", "users", true) fqTable := "shop.public.users" d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(fqTable, types.NewDwhTableConfig(&columns.Columns{}, nil, false, true)) tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(fqTable) @@ -203,7 +203,7 @@ func (d *DDLTestSuite) TestAlterTableDelete() { columns.NewColumn("start", typing.String), } - tableID := snowflake.NewTableIdentifier("shop", "public", "users1") + tableID := snowflake.NewTableIdentifier("shop", "public", "users1", true) fqTable := "shop.public.users1" d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(fqTable, types.NewDwhTableConfig(&columns.Columns{}, map[string]time.Time{ diff --git a/lib/destination/ddl/ddl_temp_test.go b/lib/destination/ddl/ddl_temp_test.go index 72ba24579..a8439dafc 100644 --- a/lib/destination/ddl/ddl_temp_test.go +++ b/lib/destination/ddl/ddl_temp_test.go @@ -34,7 +34,7 @@ func (d *DDLTestSuite) TestValidate_AlterTableArgs() { } func (d *DDLTestSuite) TestCreateTemporaryTable_Errors() { - tableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table") + tableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table", false) fqName := "mock_dataset.mock_table" d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(fqName, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true)) snowflakeTc := d.snowflakeStagesStore.GetConfigMap().TableConfig(fqName) @@ -67,7 +67,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable_Errors() { } func (d *DDLTestSuite) TestCreateTemporaryTable() { - tableID := snowflake.NewTableIdentifier("db", "schema", "tempTableName") + tableID := snowflake.NewTableIdentifier("db", "schema", "tempTableName", false) fqName := "db.schema.tempTableName" // Snowflake Stage d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(fqName, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true)) diff --git a/lib/destination/types/types.go b/lib/destination/types/types.go index 77b7e2d4c..dbd3ca61e 100644 --- a/lib/destination/types/types.go +++ b/lib/destination/types/types.go @@ -53,5 +53,5 @@ type AppendOpts struct { type TableIdentifier interface { Table() string WithTable(table string) TableIdentifier - FullyQualifiedName(uppercaseEscNames bool) string + FullyQualifiedName() string }