Skip to content

Commit

Permalink
[snowflake] Always uppercase escaped table names (#495)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Apr 23, 2024
1 parent 91f728e commit 3d104d2
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 37 deletions.
4 changes: 2 additions & 2 deletions clients/snowflake/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
)

func (s *SnowflakeTestSuite) TestMutateColumnsWithMemoryCacheDeletions() {
tableID := NewTableIdentifier("coffee_shop", "public", "orders", true)
tableID := NewTableIdentifier("coffee_shop", "public", "orders")

var cols columns.Columns
for colName, kindDetails := range map[string]typing.KindDetails{
Expand Down Expand Up @@ -51,7 +51,7 @@ func (s *SnowflakeTestSuite) TestMutateColumnsWithMemoryCacheDeletions() {
}

func (s *SnowflakeTestSuite) TestShouldDeleteColumn() {
tableID := NewTableIdentifier("coffee_shop", "orders", "public", true)
tableID := NewTableIdentifier("coffee_shop", "orders", "public")
var cols columns.Columns
for colName, kindDetails := range map[string]typing.KindDetails{
"id": typing.Integer,
Expand Down
2 changes: 1 addition & 1 deletion clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const (
)

func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier {
return NewTableIdentifier(topicConfig.Database, topicConfig.Schema, table, s.ShouldUppercaseEscapedNames())
return NewTableIdentifier(topicConfig.Database, topicConfig.Schema, table)
}

func (s *Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTableConfig, error) {
Expand Down
4 changes: 2 additions & 2 deletions clients/snowflake/staging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (s *SnowflakeTestSuite) TestCastColValStaging() {
}

func (s *SnowflakeTestSuite) TestBackfillColumn() {
tableID := NewTableIdentifier("db", "public", "tableName", true)
tableID := NewTableIdentifier("db", "public", "tableName")

backfilledCol := columns.NewColumn("foo", typing.Boolean)
backfilledCol.SetDefaultValue(true)
Expand Down Expand Up @@ -130,7 +130,7 @@ func generateTableData(rows int) (TableIdentifier, *optimization.TableData) {
td.InsertRow(key, rowData, false)
}

return NewTableIdentifier("database", "schema", randomTableName, true), td
return NewTableIdentifier("database", "schema", randomTableName), td
}

func (s *SnowflakeTestSuite) TestPrepareTempTable() {
Expand Down
20 changes: 9 additions & 11 deletions clients/snowflake/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,16 @@ import (
)

type TableIdentifier struct {
database string
schema string
table string
uppercaseEscapedNames bool
database string
schema string
table string
}

func NewTableIdentifier(database, schema, table string, uppercaseEscapedNames bool) TableIdentifier {
func NewTableIdentifier(database, schema, table string) TableIdentifier {
return TableIdentifier{
database: database,
schema: schema,
table: table,
uppercaseEscapedNames: uppercaseEscapedNames,
database: database,
schema: schema,
table: table,
}
}

Expand All @@ -37,14 +35,14 @@ func (ti TableIdentifier) Table() string {
}

func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
return NewTableIdentifier(ti.database, ti.schema, table, ti.uppercaseEscapedNames)
return NewTableIdentifier(ti.database, ti.schema, table)
}

func (ti TableIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(
"%s.%s.%s",
ti.database,
ti.schema,
sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, constants.Snowflake),
sql.EscapeNameIfNecessary(ti.table, true, constants.Snowflake),
)
}
13 changes: 5 additions & 8 deletions clients/snowflake/tableid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,19 @@ import (
)

func TestTableIdentifier_WithTable(t *testing.T) {
tableID := NewTableIdentifier("database", "schema", "foo", true)
tableID := NewTableIdentifier("database", "schema", "foo")
tableID2 := tableID.WithTable("bar")
typedTableID2, ok := tableID2.(TableIdentifier)
assert.True(t, ok)
assert.Equal(t, "database", typedTableID2.Database())
assert.Equal(t, "schema", typedTableID2.Schema())
assert.Equal(t, "bar", tableID2.Table())
assert.True(t, typedTableID2.uppercaseEscapedNames)
}

func TestTableIdentifier_FullyQualifiedName(t *testing.T) {
// Table name that does not need escaping:
assert.Equal(t, "database.schema.foo", NewTableIdentifier("database", "schema", "foo", false).FullyQualifiedName())
assert.Equal(t, "database.schema.foo", NewTableIdentifier("database", "schema", "foo", true).FullyQualifiedName())
// Table name that is not a reserved word:
assert.Equal(t, "database.schema.foo", NewTableIdentifier("database", "schema", "foo").FullyQualifiedName())

// Table name that needs escaping:
assert.Equal(t, `database.schema."table"`, NewTableIdentifier("database", "schema", "table", false).FullyQualifiedName())
assert.Equal(t, `database.schema."TABLE"`, NewTableIdentifier("database", "schema", "table", true).FullyQualifiedName())
// Table name that is a reserved word:
assert.Equal(t, `database.schema."TABLE"`, NewTableIdentifier("database", "schema", "table").FullyQualifiedName())
}
8 changes: 4 additions & 4 deletions clients/snowflake/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ func TestAddPrefixToTableName(t *testing.T) {
testCases := []_testCase{
{
name: "happy path",
tableID: NewTableIdentifier("database", "schema", "tableName", true),
tableID: NewTableIdentifier("database", "schema", "tableName"),
expectedFqTableName: "database.schema.%tableName",
},
{
name: "tableName only",
tableID: NewTableIdentifier("", "", "orders", true),
tableID: NewTableIdentifier("", "", "orders"),
expectedFqTableName: "..%orders",
},
{
name: "schema and tableName only",
tableID: NewTableIdentifier("", "public", "orders", true),
tableID: NewTableIdentifier("", "public", "orders"),
expectedFqTableName: ".public.%orders",
},
{
name: "db and tableName only",
tableID: NewTableIdentifier("db", "", "tableName", true),
tableID: NewTableIdentifier("db", "", "tableName"),
expectedFqTableName: "db..%tableName",
},
}
Expand Down
4 changes: 2 additions & 2 deletions lib/destination/ddl/ddl_create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (d *DDLTestSuite) Test_CreateTable() {
bqTableID := bigquery.NewTableIdentifier("", "mock_dataset", "mock_table")
d.bigQueryStore.GetConfigMap().AddTableToConfig(bqTableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))

snowflakeTableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table", true)
snowflakeTableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table")
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(snowflakeTableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))

type dwhToTableConfig struct {
Expand Down Expand Up @@ -114,7 +114,7 @@ func (d *DDLTestSuite) TestCreateTable() {
}

for index, testCase := range testCases {
tableID := snowflake.NewTableIdentifier("demo", "public", "experiments", false)
tableID := snowflake.NewTableIdentifier("demo", "public", "experiments")
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))
tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(tableID)

Expand Down
10 changes: 5 additions & 5 deletions lib/destination/ddl/ddl_sflk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (d *DDLTestSuite) TestAlterComplexObjects() {
columns.NewColumn("select", typing.String),
}

tableID := snowflake.NewTableIdentifier("shop", "public", "complex_columns", true)
tableID := snowflake.NewTableIdentifier("shop", "public", "complex_columns")
fqTable := "shop.public.complex_columns"
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, false, true))
tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(tableID)
Expand Down Expand Up @@ -64,7 +64,7 @@ func (d *DDLTestSuite) TestAlterIdempotency() {
columns.NewColumn("start", typing.String),
}

tableID := snowflake.NewTableIdentifier("shop", "public", "orders", true)
tableID := snowflake.NewTableIdentifier("shop", "public", "orders")
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, false, true))
tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(tableID)

Expand Down Expand Up @@ -95,7 +95,7 @@ func (d *DDLTestSuite) TestAlterTableAdd() {
columns.NewColumn("start", typing.String),
}

tableID := snowflake.NewTableIdentifier("shop", "public", "orders", true)
tableID := snowflake.NewTableIdentifier("shop", "public", "orders")
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, false, true))
tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(tableID)

Expand Down Expand Up @@ -138,7 +138,7 @@ func (d *DDLTestSuite) TestAlterTableDeleteDryRun() {
columns.NewColumn("start", typing.String),
}

tableID := snowflake.NewTableIdentifier("shop", "public", "users", true)
tableID := snowflake.NewTableIdentifier("shop", "public", "users")
fqTable := "shop.public.users"
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, false, true))
tc := d.snowflakeStagesStore.GetConfigMap().TableConfig(tableID)
Expand Down Expand Up @@ -198,7 +198,7 @@ func (d *DDLTestSuite) TestAlterTableDelete() {
columns.NewColumn("start", typing.String),
}

tableID := snowflake.NewTableIdentifier("shop", "public", "users1", true)
tableID := snowflake.NewTableIdentifier("shop", "public", "users1")

d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, map[string]time.Time{
"col_to_delete": time.Now().Add(-2 * constants.DeletionConfidencePadding),
Expand Down
4 changes: 2 additions & 2 deletions lib/destination/ddl/ddl_temp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (d *DDLTestSuite) TestValidate_AlterTableArgs() {
}

func (d *DDLTestSuite) TestCreateTemporaryTable_Errors() {
tableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table", false)
tableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table")
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))
snowflakeTc := d.snowflakeStagesStore.GetConfigMap().TableConfig(tableID)
args := ddl.AlterTableArgs{
Expand Down Expand Up @@ -69,7 +69,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable_Errors() {
func (d *DDLTestSuite) TestCreateTemporaryTable() {
{
// Snowflake Stage
tableID := snowflake.NewTableIdentifier("db", "schema", "tempTableName", false)
tableID := snowflake.NewTableIdentifier("db", "schema", "tempTableName")

d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))
sflkStageTc := d.snowflakeStagesStore.GetConfigMap().TableConfig(tableID)
Expand Down

0 comments on commit 3d104d2

Please sign in to comment.