Skip to content

Commit

Permalink
[snowflake] Always uppercase escaped column names
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Apr 29, 2024
1 parent 42a37a5 commit f7ad5e6
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 36 deletions.
2 changes: 1 addition & 1 deletion clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (s *Store) Label() constants.DestinationKind {
}

func (s *Store) ShouldUppercaseEscapedNames() bool {
return s.config.SharedDestinationConfig.UppercaseEscapedNames
return true
}

func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap {
Expand Down
28 changes: 16 additions & 12 deletions clients/snowflake/snowflake_dedupe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,60 +15,64 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
// Dedupe with one primary key + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
stagingTableName := strings.ToUpper(stagingTableID.Table())

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC) = 2)`, stagingTableID.Table()),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC) = 2)`, stagingTableName),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableName), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableName), parts[2])
}
{
// Dedupe with one primary key + `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "customers")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
stagingTableName := strings.ToUpper(stagingTableID.Table())

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"id"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableName),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING "%s" t2 WHERE t1.id = t2.id`, stagingTableName), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM "%s"`, stagingTableName), parts[2])
}
{
// Dedupe with composite keys + no `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
stagingTableName := strings.ToUpper(stagingTableID.Table())

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.Table()),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableName),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableName), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableName), parts[2])
}
{
// Dedupe with composite keys + `__artie_updated_at` flag.
tableID := NewTableIdentifier("db", "public", "user_settings")
stagingTableID := shared.TempTableID(tableID, strings.ToLower(stringutil.Random(5)))
stagingTableName := strings.ToUpper(stagingTableID.Table())

parts := s.stageStore.generateDedupeQueries(tableID, stagingTableID, []string{"user_id", "settings"}, kafkalib.TopicConfig{IncludeArtieUpdatedAt: true})
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.Table()),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE "%s" AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY by user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableName),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.Table()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableID.Table()), parts[2])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING "%s" t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableName), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM "%s"`, stagingTableName), parts[2])
}
}
8 changes: 0 additions & 8 deletions lib/config/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,3 @@ func (b *BigQuery) DSN() string {

return dsn
}

func (c Config) ValidateBigQuery() error {
if c.SharedDestinationConfig.UppercaseEscapedNames {
return fmt.Errorf("uppercaseEscapedNames is not supported for BigQuery")
}

return nil
}
11 changes: 0 additions & 11 deletions lib/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,6 @@ type Config struct {
// Shared Transfer settings
SharedTransferConfig SharedTransferConfig `yaml:"sharedTransferConfig"`

// Shared destination configuration
SharedDestinationConfig SharedDestinationConfig `yaml:"sharedDestinationConfig"`

// Supported destinations
MSSQL *MSSQL `yaml:"mssql,omitempty"`
BigQuery *BigQuery `yaml:"bigquery,omitempty"`
Expand Down Expand Up @@ -240,10 +237,6 @@ func (c Config) ValidateRedshift() error {
return fmt.Errorf("invalid Redshift port")
}

if c.SharedDestinationConfig.UppercaseEscapedNames {
return fmt.Errorf("uppercaseEscapedNames is not supported for Redshift")
}

return nil
}

Expand All @@ -269,10 +262,6 @@ func (c Config) Validate() error {
}

switch c.Output {
case constants.BigQuery:
if err := c.ValidateBigQuery(); err != nil {
return err
}
case constants.MSSQL:
if err := c.ValidateMSSQL(); err != nil {
return err
Expand Down
4 changes: 0 additions & 4 deletions lib/config/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,5 @@ func (c Config) ValidateMSSQL() error {
return fmt.Errorf("invalid mssql port: %d", c.MSSQL.Port)
}

if c.SharedDestinationConfig.UppercaseEscapedNames {
return fmt.Errorf("uppercaseEscapedNames is not supported for MS SQL")
}

return nil
}

0 comments on commit f7ad5e6

Please sign in to comment.