diff --git a/clients/snowflake/snowflake_dedupe_test.go b/clients/snowflake/snowflake_dedupe_test.go index b9b5aed83..549d39f17 100644 --- a/clients/snowflake/snowflake_dedupe_test.go +++ b/clients/snowflake/snowflake_dedupe_test.go @@ -20,10 +20,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() { assert.Len(s.T(), parts, 3) assert.Equal( s.T(), - fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`, stagingTableID.FullyQualifiedName()), + fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY "ID" ORDER BY "ID" ASC) = 2)`, stagingTableID.FullyQualifiedName()), parts[0], ) - assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1."ID" = t2."ID"`, stagingTableID.FullyQualifiedName()), parts[1]) assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2]) } { @@ -35,10 +35,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() { assert.Len(s.T(), parts, 3) assert.Equal( s.T(), - fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()), + fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY "ID" ORDER BY "ID" ASC, "__ARTIE_UPDATED_AT" ASC) = 2)`, stagingTableID.FullyQualifiedName()), parts[0], ) - assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1."ID" = t2."ID"`, stagingTableID.FullyQualifiedName()), parts[1]) assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2]) } { @@ -50,10 +50,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() { assert.Len(s.T(), parts, 3) assert.Equal( s.T(), - fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.FullyQualifiedName()), + fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY "USER_ID", "SETTINGS" ORDER BY "USER_ID" ASC, "SETTINGS" ASC) = 2)`, stagingTableID.FullyQualifiedName()), parts[0], ) - assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1."USER_ID" = t2."USER_ID" AND t1."SETTINGS" = t2."SETTINGS"`, stagingTableID.FullyQualifiedName()), parts[1]) assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2]) } { @@ -65,10 +65,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() { assert.Len(s.T(), parts, 3) assert.Equal( s.T(), - fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()), + fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY "USER_ID", "SETTINGS" ORDER BY "USER_ID" ASC, "SETTINGS" ASC, "__ARTIE_UPDATED_AT" ASC) = 2)`, stagingTableID.FullyQualifiedName()), parts[0], ) - assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1]) + assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1."USER_ID" = t2."USER_ID" AND t1."SETTINGS" = t2."SETTINGS"`, stagingTableID.FullyQualifiedName()), parts[1]) assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2]) } } diff --git a/clients/snowflake/staging_test.go b/clients/snowflake/staging_test.go index 8d9c09566..f75bb9a90 100644 --- a/clients/snowflake/staging_test.go +++ b/clients/snowflake/staging_test.go @@ -79,14 +79,14 @@ func (s *SnowflakeTestSuite) TestBackfillColumn() { { name: "col that has default value that needs to be backfilled", col: needsBackfillCol, - backfillSQL: `UPDATE db.public."TABLENAME" SET foo = true WHERE foo IS NULL;`, - commentSQL: `COMMENT ON COLUMN db.public."TABLENAME".foo IS '{"backfilled": true}';`, + backfillSQL: `UPDATE db.public."TABLENAME" SET "FOO" = true WHERE "FOO" IS NULL;`, + commentSQL: `COMMENT ON COLUMN db.public."TABLENAME"."FOO" IS '{"backfilled": true}';`, }, { name: "default col that has default value that needs to be backfilled", col: needsBackfillColDefault, - backfillSQL: `UPDATE db.public."TABLENAME" SET default = true WHERE "DEFAULT" IS NULL;`, - commentSQL: `COMMENT ON COLUMN db.public."TABLENAME".default IS '{"backfilled": true}';`, + backfillSQL: `UPDATE db.public."TABLENAME" SET "DEFAULT" = true WHERE "DEFAULT" IS NULL;`, + commentSQL: `COMMENT ON COLUMN db.public."TABLENAME"."DEFAULT" IS '{"backfilled": true}';`, }, } @@ -147,7 +147,7 @@ func (s *SnowflakeTestSuite) TestPrepareTempTable() { createQuery, _ := s.fakeStageStore.ExecArgsForCall(0) prefixQuery := fmt.Sprintf( - `CREATE TABLE IF NOT EXISTS %s (user_id string,first_name string,last_name string,dusty string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`, tempTableName) + `CREATE TABLE IF NOT EXISTS %s ("USER_ID" string,"FIRST_NAME" string,"LAST_NAME" string,"DUSTY" string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`, tempTableName) containsPrefix := strings.HasPrefix(createQuery, prefixQuery) assert.True(s.T(), containsPrefix, fmt.Sprintf("createQuery:%v, prefixQuery:%s", createQuery, prefixQuery)) resourceName := addPrefixToTableName(tempTableID, "%") @@ -157,7 +157,7 @@ func (s *SnowflakeTestSuite) TestPrepareTempTable() { assert.Contains(s.T(), putQuery, fmt.Sprintf("@%s AUTO_COMPRESS=TRUE", resourceName)) // Third call is a COPY INTO copyQuery, _ := s.fakeStageStore.ExecArgsForCall(2) - assert.Equal(s.T(), fmt.Sprintf(`COPY INTO %s (user_id,first_name,last_name,dusty) FROM (SELECT $1,$2,$3,$4 FROM @%s)`, + assert.Equal(s.T(), fmt.Sprintf(`COPY INTO %s ("USER_ID","FIRST_NAME","LAST_NAME","DUSTY") FROM (SELECT $1,$2,$3,$4 FROM @%s)`, tempTableName, resourceName), copyQuery) } { diff --git a/clients/snowflake/tableid.go b/clients/snowflake/tableid.go index 662b97f75..ec9cbb2db 100644 --- a/clients/snowflake/tableid.go +++ b/clients/snowflake/tableid.go @@ -7,7 +7,7 @@ import ( "github.com/artie-labs/transfer/lib/sql" ) -var dialect = sql.SnowflakeDialect{UppercaseEscNames: true} +var dialect = sql.SnowflakeDialect{} type TableIdentifier struct { database string diff --git a/lib/destination/ddl/ddl_suite_test.go b/lib/destination/ddl/ddl_suite_test.go index 299ea10ad..f9b100188 100644 --- a/lib/destination/ddl/ddl_suite_test.go +++ b/lib/destination/ddl/ddl_suite_test.go @@ -46,9 +46,6 @@ func (d *DDLTestSuite) SetupTest() { snowflakeStagesStore := db.Store(d.fakeSnowflakeStagesStore) snowflakeCfg := config.Config{ Snowflake: &config.Snowflake{}, - SharedDestinationConfig: config.SharedDestinationConfig{ - UppercaseEscapedNames: true, - }, } d.snowflakeStagesStore, err = snowflake.LoadSnowflake(snowflakeCfg, &snowflakeStagesStore) assert.NoError(d.T(), err) diff --git a/lib/destination/dml/merge_test.go b/lib/destination/dml/merge_test.go index 5987fedd7..2cd93ae9e 100644 --- a/lib/destination/dml/merge_test.go +++ b/lib/destination/dml/merge_test.go @@ -56,7 +56,7 @@ func TestMergeStatementSoftDelete(t *testing.T) { _cols.AddColumn(columns.NewColumn("id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} for _, idempotentKey := range []string{"", "updated_at"} { mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, @@ -107,7 +107,7 @@ func TestMergeStatement(t *testing.T) { subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, @@ -156,7 +156,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) { _cols.AddColumn(columns.NewColumn("id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, @@ -199,7 +199,7 @@ func TestMergeStatementCompositeKey(t *testing.T) { _cols.AddColumn(columns.NewColumn("another_id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, @@ -249,7 +249,7 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) { subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, diff --git a/lib/destination/dml/merge_valid_test.go b/lib/destination/dml/merge_valid_test.go index cd3fe22d2..71101833b 100644 --- a/lib/destination/dml/merge_valid_test.go +++ b/lib/destination/dml/merge_valid_test.go @@ -14,7 +14,7 @@ import ( func TestMergeArgument_Valid(t *testing.T) { primaryKeys := []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{UppercaseEscNames: true}), + columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{}), } var cols columns.Columns diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index ae39b254e..30f3b9ebf 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -2,12 +2,8 @@ package sql import ( "fmt" - "log/slog" - "slices" - "strconv" "strings" - "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/stringutil" ) @@ -55,39 +51,14 @@ func (RedshiftDialect) EscapeStruct(value any) string { return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false)) } -type SnowflakeDialect struct { - UppercaseEscNames bool -} +type SnowflakeDialect struct{} func (sd SnowflakeDialect) NeedsEscaping(name string) bool { - if sd.UppercaseEscNames { - // If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix. - // Since they will be uppercased afer they are escaped then they will result in the same value as if we - // we were to use them in a query without any escaping at all. - return true - } else { - if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") { - return true - } - // If it still doesn't need to be escaped, we should check if it's a number. - if _, err := strconv.Atoi(name); err == nil { - return true - } - return false - } + return true } func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string { - if sd.UppercaseEscNames { - identifier = strings.ToUpper(identifier) - } else { - slog.Warn("Escaped Snowflake identifier is not being uppercased", - slog.String("name", identifier), - slog.Bool("uppercaseEscapedNames", sd.UppercaseEscNames), - ) - } - - return fmt.Sprintf(`"%s"`, identifier) + return fmt.Sprintf(`"%s"`, strings.ToUpper(identifier)) } func (SnowflakeDialect) EscapeStruct(value any) string { diff --git a/lib/sql/dialect_test.go b/lib/sql/dialect_test.go index adee0a481..0cff59c52 100644 --- a/lib/sql/dialect_test.go +++ b/lib/sql/dialect_test.go @@ -25,38 +25,18 @@ func TestRedshiftDialect_QuoteIdentifier(t *testing.T) { } func TestSnowflakeDialect_NeedsEscaping(t *testing.T) { - { - // UppercaseEscNames enabled: - dialect := SnowflakeDialect{UppercaseEscNames: true} - - assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved - assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved - assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix - assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol - } - - { - // UppercaseEscNames disabled: - dialect := SnowflakeDialect{UppercaseEscNames: false} - - assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved - assert.False(t, dialect.NeedsEscaping("foo")) // name that is not reserved - assert.False(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix - assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol - } + // UppercaseEscNames enabled: + dialect := SnowflakeDialect{} + + assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved + assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved + assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix + assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol } func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) { - { - // UppercaseEscNames enabled: - dialect := SnowflakeDialect{UppercaseEscNames: true} - assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo")) - assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) - } - { - // UppercaseEscNames disabled: - dialect := SnowflakeDialect{UppercaseEscNames: false} - assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo")) - assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) - } + // UppercaseEscNames enabled: + dialect := SnowflakeDialect{} + assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo")) + assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) } diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index 88d6b15ea..80e392b47 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -170,7 +170,7 @@ func TestColumn_Name(t *testing.T) { assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName) - assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.colName) + assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{}), testCase.colName) assert.Equal(t, testCase.expectedNameEscBq, col.Name(sql.BigQueryDialect{}), testCase.colName) } } @@ -282,7 +282,7 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) { columns: testCase.cols, } - assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.name) + assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{}), testCase.name) assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(sql.BigQueryDialect{}), testCase.name) } } @@ -486,7 +486,7 @@ func TestColumnsUpdateQuery(t *testing.T) { { name: "string and toast", columns: stringAndToastCols, - dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, + dialect: sql.SnowflakeDialect{}, expectedString: `"FOO"= CASE WHEN COALESCE(cc."FOO" != '__debezium_unavailable_value', true) THEN cc."FOO" ELSE c."FOO" END,"BAR"=cc."BAR"`, }, { diff --git a/lib/typing/columns/default_test.go b/lib/typing/columns/default_test.go index 9abbdc6a2..0da9f8cfe 100644 --- a/lib/typing/columns/default_test.go +++ b/lib/typing/columns/default_test.go @@ -17,7 +17,7 @@ import ( var dialects = []sql.Dialect{ sql.BigQueryDialect{}, sql.RedshiftDialect{}, - sql.SnowflakeDialect{UppercaseEscNames: true}, + sql.SnowflakeDialect{}, } func TestColumn_DefaultValue(t *testing.T) { diff --git a/lib/typing/columns/wrapper_test.go b/lib/typing/columns/wrapper_test.go index 8f7e1f317..95a54a649 100644 --- a/lib/typing/columns/wrapper_test.go +++ b/lib/typing/columns/wrapper_test.go @@ -41,7 +41,7 @@ func TestWrapper_Complete(t *testing.T) { for _, testCase := range testCases { // Snowflake escape - w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true}) + w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{}) assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) @@ -53,7 +53,7 @@ func TestWrapper_Complete(t *testing.T) { assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) { - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true}) + w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{}) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) } {