Skip to content

Commit

Permalink
Add uppercaseEscapedNames to TableIdentifiers (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Apr 21, 2024
1 parent a0b18a3 commit f4a2045
Show file tree
Hide file tree
Showing 34 changed files with 136 additions and 139 deletions.
6 changes: 3 additions & 3 deletions clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
}

func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier {
return NewTableIdentifier(s.config.BigQuery.ProjectID, topicConfig.Database, table)
return NewTableIdentifier(s.config.BigQuery.ProjectID, topicConfig.Database, table, s.ShouldUppercaseEscapedNames())
}

func (s *Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTableConfig, error) {
Expand Down Expand Up @@ -135,7 +135,7 @@ func tableRelName(fqName string) (string, error) {

func (s *Store) putTable(ctx context.Context, dataset string, tableID types.TableIdentifier, rows []*Row) error {
// TODO: [tableID] has [Dataset] on it, don't need to pass it along.
tableName := tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames())
tableName := tableID.FullyQualifiedName()
// TODO: Can probably do `tableName := tableID.Table()` here.
relTableName, err := tableRelName(tableName)
if err != nil {
Expand All @@ -157,7 +157,7 @@ func (s *Store) putTable(ctx context.Context, dataset string, tableID types.Tabl
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
fqTableName := tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames())
fqTableName := tableID.FullyQualifiedName()
_, err := s.Exec(fmt.Sprintf("CREATE OR REPLACE TABLE %s AS SELECT DISTINCT * FROM %s", fqTableName, fqTableName))
return err
}
Expand Down
2 changes: 1 addition & 1 deletion clients/bigquery/bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ func TestTempTableName(t *testing.T) {
store := &Store{config: config.Config{BigQuery: &config.BigQuery{ProjectID: "123454321"}}}
tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table")
tableID := store.IdentifierFor(tableData.TopicConfig(), tableData.Name())
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(store.ShouldUppercaseEscapedNames())
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName()
assert.Equal(t, "`123454321`.`db`.table___artie_sUfFiX", trimTTL(tempTableName))
}
22 changes: 14 additions & 8 deletions clients/bigquery/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@ import (
)

type TableIdentifier struct {
projectID string
dataset string
table string
projectID string
dataset string
table string
uppercaseEscapedNames bool
}

func NewTableIdentifier(projectID, dataset, table string) TableIdentifier {
return TableIdentifier{projectID: projectID, dataset: dataset, table: table}
func NewTableIdentifier(projectID, dataset, table string, uppercaseEscapedNames bool) TableIdentifier {
return TableIdentifier{
projectID: projectID,
dataset: dataset,
table: table,
uppercaseEscapedNames: uppercaseEscapedNames,
}
}

func (ti TableIdentifier) ProjectID() string {
Expand All @@ -31,16 +37,16 @@ func (ti TableIdentifier) Table() string {
}

func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
return NewTableIdentifier(ti.projectID, ti.dataset, table)
return NewTableIdentifier(ti.projectID, ti.dataset, table, ti.uppercaseEscapedNames)
}

func (ti TableIdentifier) FullyQualifiedName(uppercaseEscNames bool) string {
func (ti TableIdentifier) FullyQualifiedName() string {
// The fully qualified name for BigQuery is: project_id.dataset.tableName.
// We are escaping the project_id and dataset because there could be special characters.
return fmt.Sprintf(
"`%s`.`%s`.%s",
ti.projectID,
ti.dataset,
sql.EscapeName(ti.table, uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}),
sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}),
)
}
22 changes: 9 additions & 13 deletions clients/bigquery/tableid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,22 @@ import (
)

func TestTableIdentifier_WithTable(t *testing.T) {
tableID := NewTableIdentifier("project", "dataset", "foo")
tableID := NewTableIdentifier("project", "dataset", "foo", true)
tableID2 := tableID.WithTable("bar")
typedTableID2, ok := tableID2.(TableIdentifier)
assert.True(t, ok)
assert.Equal(t, "project", typedTableID2.ProjectID())
assert.Equal(t, "dataset", typedTableID2.Dataset())
assert.Equal(t, "bar", tableID2.Table())
assert.True(t, typedTableID2.uppercaseEscapedNames)
}

func TestTableIdentifier_FullyQualifiedName(t *testing.T) {
{
// Table name that does not need escaping:
tableID := NewTableIdentifier("project", "dataset", "foo")
assert.Equal(t, "`project`.`dataset`.foo", tableID.FullyQualifiedName(false), "escaped")
assert.Equal(t, "`project`.`dataset`.foo", tableID.FullyQualifiedName(true), "escaped + upper")
}
{
// Table name that needs escaping:
tableID := NewTableIdentifier("project", "dataset", "table")
assert.Equal(t, "`project`.`dataset`.`table`", tableID.FullyQualifiedName(false), "escaped")
assert.Equal(t, "`project`.`dataset`.`TABLE`", tableID.FullyQualifiedName(true), "escaped + upper")
}
// Table name that does not need escaping:
assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo", false).FullyQualifiedName())
assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo", true).FullyQualifiedName())

// Table name that needs escaping:
assert.Equal(t, "`project`.`dataset`.`table`", NewTableIdentifier("project", "dataset", "table", false).FullyQualifiedName())
assert.Equal(t, "`project`.`dataset`.`TABLE`", NewTableIdentifier("project", "dataset", "table", true).FullyQualifiedName())
}
2 changes: 1 addition & 1 deletion clients/mssql/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
tempTableName := tempTableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames())
tempTableName := tempTableID.FullyQualifiedName()

if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Expand Down
2 changes: 1 addition & 1 deletion clients/mssql/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (s *Store) Append(tableData *optimization.TableData) error {

// specificIdentifierFor returns a MS SQL [TableIdentifier] for a [TopicConfig] + table name.
func (s *Store) specificIdentifierFor(topicConfig kafkalib.TopicConfig, table string) TableIdentifier {
return NewTableIdentifier(getSchema(topicConfig.Schema), table)
return NewTableIdentifier(getSchema(topicConfig.Schema), table, s.ShouldUppercaseEscapedNames())
}

// IdentifierFor returns a generic [types.TableIdentifier] interface for a [TopicConfig] + table name.
Expand Down
4 changes: 2 additions & 2 deletions clients/mssql/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ func TestTempTableName(t *testing.T) {
// Schema is "schema":
tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table")
tableID := store.IdentifierFor(tableData.TopicConfig(), tableData.Name())
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(store.ShouldUppercaseEscapedNames())
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName()
assert.Equal(t, "schema.table___artie_sUfFiX", trimTTL(tempTableName))
}
{
// Schema is "public" -> "dbo":
tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "public"}, "table")
tableID := store.IdentifierFor(tableData.TopicConfig(), tableData.Name())
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(store.ShouldUppercaseEscapedNames())
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName()
assert.Equal(t, "dbo.table___artie_sUfFiX", trimTTL(tempTableName))
}
}
15 changes: 8 additions & 7 deletions clients/mssql/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
)

type TableIdentifier struct {
schema string
table string
schema string
table string
uppercaseEscapedNames bool
}

func NewTableIdentifier(schema, table string) TableIdentifier {
return TableIdentifier{schema: schema, table: table}
func NewTableIdentifier(schema, table string, uppercaseEscapedNames bool) TableIdentifier {
return TableIdentifier{schema: schema, table: table, uppercaseEscapedNames: uppercaseEscapedNames}
}

func (ti TableIdentifier) Schema() string {
Expand All @@ -26,13 +27,13 @@ func (ti TableIdentifier) Table() string {
}

func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
return NewTableIdentifier(ti.schema, table)
return NewTableIdentifier(ti.schema, table, ti.uppercaseEscapedNames)
}

func (ti TableIdentifier) FullyQualifiedName(uppercaseEscNames bool) string {
func (ti TableIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(
"%s.%s",
ti.schema,
sql.EscapeName(ti.table, uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}),
sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}),
)
}
22 changes: 9 additions & 13 deletions clients/mssql/tableid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,21 @@ import (
)

func TestTableIdentifier_WithTable(t *testing.T) {
tableID := NewTableIdentifier("schema", "foo")
tableID := NewTableIdentifier("schema", "foo", true)
tableID2 := tableID.WithTable("bar")
typedTableID2, ok := tableID2.(TableIdentifier)
assert.True(t, ok)
assert.Equal(t, "schema", typedTableID2.Schema())
assert.Equal(t, "bar", tableID2.Table())
assert.True(t, typedTableID2.uppercaseEscapedNames)
}

func TestTableIdentifier_FullyQualifiedName(t *testing.T) {
{
// Table name that does not need escaping:
tableID := NewTableIdentifier("schema", "foo")
assert.Equal(t, "schema.foo", tableID.FullyQualifiedName(false), "escaped")
assert.Equal(t, "schema.foo", tableID.FullyQualifiedName(true), "escaped + upper")
}
{
// Table name that needs escaping:
tableID := NewTableIdentifier("schema", "table")
assert.Equal(t, `schema."table"`, tableID.FullyQualifiedName(false), "escaped")
assert.Equal(t, `schema."TABLE"`, tableID.FullyQualifiedName(true), "escaped + upper")
}
// Table name that does not need escaping:
assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", false).FullyQualifiedName())
assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", true).FullyQualifiedName())

// Table name that needs escaping:
assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table", false).FullyQualifiedName())
assert.Equal(t, `schema."TABLE"`, NewTableIdentifier("schema", "table", true).FullyQualifiedName())
}
6 changes: 3 additions & 3 deletions clients/redshift/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type Store struct {
}

func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier {
return NewTableIdentifier(topicConfig.Schema, table)
return NewTableIdentifier(topicConfig.Schema, table, s.ShouldUppercaseEscapedNames())
}

func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap {
Expand Down Expand Up @@ -98,8 +98,8 @@ WHERE
}

func (s *Store) Dedupe(tableID types.TableIdentifier) error {
fqTableName := tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames())
stagingTableName := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))).FullyQualifiedName(s.ShouldUppercaseEscapedNames())
fqTableName := tableID.FullyQualifiedName()
stagingTableName := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5))).FullyQualifiedName()

query := fmt.Sprintf(`
CREATE TABLE %s AS SELECT DISTINCT * FROM %s;
Expand Down
2 changes: 1 addition & 1 deletion clients/redshift/redshift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ func TestTempTableName(t *testing.T) {

tableData := optimization.NewTableData(nil, config.Replication, nil, kafkalib.TopicConfig{Database: "db", Schema: "schema"}, "table")
tableID := (&Store{}).IdentifierFor(tableData.TopicConfig(), tableData.Name())
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName(false)
tempTableName := shared.TempTableID(tableID, "sUfFiX").FullyQualifiedName()
assert.Equal(t, "schema.table___artie_sUfFiX", trimTTL(tempTableName))
}
2 changes: 1 addition & 1 deletion clients/redshift/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
)

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, _ bool) error {
tempTableName := tempTableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames())
tempTableName := tempTableID.FullyQualifiedName()

// Redshift always creates a temporary table.
tempAlterTableArgs := ddl.AlterTableArgs{
Expand Down
15 changes: 8 additions & 7 deletions clients/redshift/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ import (
)

type TableIdentifier struct {
schema string
table string
schema string
table string
uppercaseEscapedNames bool
}

func NewTableIdentifier(schema, table string) TableIdentifier {
return TableIdentifier{schema: schema, table: table}
func NewTableIdentifier(schema, table string, uppercaseEscapedNames bool) TableIdentifier {
return TableIdentifier{schema: schema, table: table, uppercaseEscapedNames: uppercaseEscapedNames}
}

func (ti TableIdentifier) Schema() string {
Expand All @@ -26,15 +27,15 @@ func (ti TableIdentifier) Table() string {
}

func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
return NewTableIdentifier(ti.schema, table)
return NewTableIdentifier(ti.schema, table, ti.uppercaseEscapedNames)
}

func (ti TableIdentifier) FullyQualifiedName(uppercaseEscNames bool) string {
func (ti TableIdentifier) FullyQualifiedName() string {
// Redshift is Postgres compatible, so when establishing a connection, we'll specify a database.
// Thus, we only need to specify schema and table name here.
return fmt.Sprintf(
"%s.%s",
ti.schema,
sql.EscapeName(ti.table, uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: constants.Redshift}),
sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Redshift}),
)
}
22 changes: 9 additions & 13 deletions clients/redshift/tableid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,21 @@ import (
)

func TestTableIdentifier_WithTable(t *testing.T) {
tableID := NewTableIdentifier("schema", "foo")
tableID := NewTableIdentifier("schema", "foo", true)
tableID2 := tableID.WithTable("bar")
typedTableID2, ok := tableID2.(TableIdentifier)
assert.True(t, ok)
assert.Equal(t, "schema", typedTableID2.Schema())
assert.Equal(t, "bar", tableID2.Table())
assert.True(t, typedTableID2.uppercaseEscapedNames)
}

func TestTableIdentifier_FullyQualifiedName(t *testing.T) {
{
// Table name that does not need escaping:
tableID := NewTableIdentifier("schema", "foo")
assert.Equal(t, "schema.foo", tableID.FullyQualifiedName(false), "escaped")
assert.Equal(t, "schema.foo", tableID.FullyQualifiedName(true), "escaped + upper")
}
{
// Table name that needs escaping:
tableID := NewTableIdentifier("schema", "table")
assert.Equal(t, `schema."table"`, tableID.FullyQualifiedName(false), "escaped")
assert.Equal(t, `schema."TABLE"`, tableID.FullyQualifiedName(true), "escaped + upper")
}
// Table name that does not need escaping:
assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", false).FullyQualifiedName())
assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", true).FullyQualifiedName())

// Table name that needs escaping:
assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table", false).FullyQualifiedName())
assert.Equal(t, `schema."TABLE"`, NewTableIdentifier("schema", "table", true).FullyQualifiedName())
}
5 changes: 3 additions & 2 deletions clients/redshift/writes.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ func (s *Store) Append(tableData *optimization.TableData) error {
// Redshift is slightly different, we'll load and create the temporary table via shared.Append
// Then, we'll invoke `ALTER TABLE target APPEND FROM staging` to combine the diffs.
temporaryTableID := shared.TempTableID(tableID, tableData.TempTableSuffix())
temporaryTableName := temporaryTableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames())
if err := shared.Append(s, tableData, s.config, types.AppendOpts{TempTableID: temporaryTableID}); err != nil {
return err
}

_, err := s.Exec(fmt.Sprintf(`ALTER TABLE %s APPEND FROM %s;`, tableID.FullyQualifiedName(s.ShouldUppercaseEscapedNames()), temporaryTableName))
_, err := s.Exec(
fmt.Sprintf(`ALTER TABLE %s APPEND FROM %s;`, tableID.FullyQualifiedName(), temporaryTableID.FullyQualifiedName()),
)
return err
}

Expand Down
2 changes: 1 addition & 1 deletion clients/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) ty
// > optionalPrefix/fullyQualifiedTableName/YYYY-MM-DD
func (s *Store) ObjectPrefix(tableData *optimization.TableData) string {
tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name())
fqTableName := tableID.FullyQualifiedName(false)
fqTableName := tableID.FullyQualifiedName()
yyyyMMDDFormat := tableData.LatestCDCTs.Format(ext.PostgresDateFormat)

if len(s.config.S3.OptionalPrefix) > 0 {
Expand Down
2 changes: 1 addition & 1 deletion clients/s3/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
return NewTableIdentifier(ti.database, ti.schema, table)
}

func (ti TableIdentifier) FullyQualifiedName(_ bool) string {
func (ti TableIdentifier) FullyQualifiedName() string {
// S3 should be db.schema.tableName, but we don't need to escape, since it's not a SQL db.
return fmt.Sprintf("%s.%s.%s", ti.database, ti.schema, ti.table)
}
3 changes: 1 addition & 2 deletions clients/s3/tableid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ func TestTableIdentifier_WithTable(t *testing.T) {
func TestTableIdentifier_FullyQualifiedName(t *testing.T) {
// S3 doesn't escape the table name.
tableID := NewTableIdentifier("database", "schema", "table")
assert.Equal(t, "database.schema.table", tableID.FullyQualifiedName(false), "escaped")
assert.Equal(t, "database.schema.table", tableID.FullyQualifiedName(true), "escaped + upper")
assert.Equal(t, "database.schema.table", tableID.FullyQualifiedName())
}
4 changes: 2 additions & 2 deletions clients/shared/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
tableData.TopicConfig().IncludeDatabaseUpdatedAt, tableData.Mode())

tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name())
fqName := tableID.FullyQualifiedName(dwh.ShouldUppercaseEscapedNames())
fqName := tableID.FullyQualifiedName()
createAlterTableArgs := ddl.AlterTableArgs{
Dwh: dwh,
Tc: tableConfig,
Expand Down Expand Up @@ -73,7 +73,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
tableConfig.AuditColumnsToDelete(srcKeysMissing)
tableData.MergeColumnsFromDestination(tableConfig.Columns().GetColumns()...)
temporaryTableID := TempTableID(dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name()), tableData.TempTableSuffix())
temporaryTableName := temporaryTableID.FullyQualifiedName(dwh.ShouldUppercaseEscapedNames())
temporaryTableName := temporaryTableID.FullyQualifiedName()
if err = dwh.PrepareTemporaryTable(tableData, tableConfig, temporaryTableID, types.AdditionalSettings{}, true); err != nil {
return fmt.Errorf("failed to prepare temporary table: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion clients/shared/table_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (g GetTableCfgArgs) ShouldParseComment(comment string) bool {
}

func (g GetTableCfgArgs) GetTableConfig() (*types.DwhTableConfig, error) {
fqName := g.TableID.FullyQualifiedName(g.Dwh.ShouldUppercaseEscapedNames())
fqName := g.TableID.FullyQualifiedName()

// Check if it already exists in cache
tableConfig := g.ConfigMap.TableConfig(fqName)
Expand Down
Loading

0 comments on commit f4a2045

Please sign in to comment.