From 9b90faccc1cf51a7cca66d0fa55fd9f0d496ff07 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 19:21:40 -0700 Subject: [PATCH 1/6] [snowflake] Always escape column identifiers --- clients/snowflake/snowflake.go | 6 ++-- clients/snowflake/tableid.go | 2 +- lib/destination/dml/merge.go | 6 ++-- lib/destination/dml/merge_test.go | 10 +++--- lib/destination/dml/merge_valid_test.go | 2 +- lib/sql/dialect.go | 47 ++++++++++--------------- lib/sql/dialect_test.go | 40 ++++++++------------- lib/sql/escape.go | 8 ----- lib/typing/columns/columns.go | 2 +- lib/typing/columns/columns_test.go | 6 ++-- lib/typing/columns/default_test.go | 2 +- lib/typing/columns/wrapper_test.go | 4 +-- 12 files changed, 54 insertions(+), 81 deletions(-) delete mode 100644 lib/sql/escape.go diff --git a/clients/snowflake/snowflake.go b/clients/snowflake/snowflake.go index 642f8c5a2..8c62c08be 100644 --- a/clients/snowflake/snowflake.go +++ b/clients/snowflake/snowflake.go @@ -78,7 +78,7 @@ func (s *Store) Label() constants.DestinationKind { } func (s *Store) Dialect() sql.Dialect { - return sql.SnowflakeDialect{UppercaseEscNames: s.config.SharedDestinationConfig.UppercaseEscapedNames} + return sql.SnowflakeDialect{LegacyMode: !s.config.SharedDestinationConfig.UppercaseEscapedNames} } func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap { @@ -130,12 +130,12 @@ func (s *Store) reestablishConnection() error { func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string { var primaryKeysEscaped []string for _, pk := range primaryKeys { - primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessary(pk, s.Dialect())) + primaryKeysEscaped = append(primaryKeysEscaped, s.Dialect().QuoteIdentifier(pk)) } orderColsToIterate := primaryKeysEscaped if topicConfig.IncludeArtieUpdatedAt { - orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessary(constants.UpdateColumnMarker, s.Dialect())) + orderColsToIterate = append(orderColsToIterate, s.Dialect().QuoteIdentifier(constants.UpdateColumnMarker)) } var orderByCols []string diff --git a/clients/snowflake/tableid.go b/clients/snowflake/tableid.go index 662b97f75..ec9cbb2db 100644 --- a/clients/snowflake/tableid.go +++ b/clients/snowflake/tableid.go @@ -7,7 +7,7 @@ import ( "github.com/artie-labs/transfer/lib/sql" ) -var dialect = sql.SnowflakeDialect{UppercaseEscNames: true} +var dialect = sql.SnowflakeDialect{} type TableIdentifier struct { database string diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index 428154877..2ca9cb393 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -129,7 +129,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { // We also need to remove __artie flags since it does not exist in the destination table var removed bool for idx, col := range cols { - if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) { + if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) { cols = append(cols[:idx], cols[idx+1:]...) removed = true break @@ -252,7 +252,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);` // We also need to remove __artie flags since it does not exist in the destination table var removed bool for idx, col := range cols { - if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) { + if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) { cols = append(cols[:idx], cols[idx+1:]...) removed = true break @@ -322,7 +322,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`, // We also need to remove __artie flags since it does not exist in the destination table var removed bool for idx, col := range cols { - if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) { + if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) { cols = append(cols[:idx], cols[idx+1:]...) removed = true break diff --git a/lib/destination/dml/merge_test.go b/lib/destination/dml/merge_test.go index 5987fedd7..2cd93ae9e 100644 --- a/lib/destination/dml/merge_test.go +++ b/lib/destination/dml/merge_test.go @@ -56,7 +56,7 @@ func TestMergeStatementSoftDelete(t *testing.T) { _cols.AddColumn(columns.NewColumn("id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} for _, idempotentKey := range []string{"", "updated_at"} { mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, @@ -107,7 +107,7 @@ func TestMergeStatement(t *testing.T) { subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, @@ -156,7 +156,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) { _cols.AddColumn(columns.NewColumn("id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, @@ -199,7 +199,7 @@ func TestMergeStatementCompositeKey(t *testing.T) { _cols.AddColumn(columns.NewColumn("another_id", typing.String)) _cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, @@ -249,7 +249,7 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) { subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - dialect := sql.SnowflakeDialect{UppercaseEscNames: true} + dialect := sql.SnowflakeDialect{} mergeArg := MergeArgument{ TableID: MockTableIdentifier{fqTable}, SubQuery: subQuery, diff --git a/lib/destination/dml/merge_valid_test.go b/lib/destination/dml/merge_valid_test.go index cd3fe22d2..71101833b 100644 --- a/lib/destination/dml/merge_valid_test.go +++ b/lib/destination/dml/merge_valid_test.go @@ -14,7 +14,7 @@ import ( func TestMergeArgument_Valid(t *testing.T) { primaryKeys := []columns.Wrapper{ - columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{UppercaseEscNames: true}), + columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{}), } var cols columns.Columns diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index b19de75b1..d70d479cc 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -11,15 +11,12 @@ import ( ) type Dialect interface { - NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything QuoteIdentifier(identifier string) string EscapeStruct(value string) string } type BigQueryDialect struct{} -func (BigQueryDialect) NeedsEscaping(_ string) bool { return true } - func (BigQueryDialect) QuoteIdentifier(identifier string) string { // BigQuery needs backticks to quote. return fmt.Sprintf("`%s`", identifier) @@ -31,8 +28,6 @@ func (BigQueryDialect) EscapeStruct(value string) string { type MSSQLDialect struct{} -func (MSSQLDialect) NeedsEscaping(_ string) bool { return true } - func (MSSQLDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, identifier) } @@ -43,8 +38,6 @@ func (MSSQLDialect) EscapeStruct(value string) string { type RedshiftDialect struct{} -func (RedshiftDialect) NeedsEscaping(_ string) bool { return true } - func (rd RedshiftDialect) QuoteIdentifier(identifier string) string { // Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted. return fmt.Sprintf(`"%s"`, strings.ToLower(identifier)) @@ -55,35 +48,33 @@ func (RedshiftDialect) EscapeStruct(value string) string { } type SnowflakeDialect struct { - UppercaseEscNames bool + LegacyMode bool } -func (sd SnowflakeDialect) NeedsEscaping(name string) bool { - if sd.UppercaseEscNames { - // If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix. - // Since they will be uppercased afer they are escaped then they will result in the same value as if we - // we were to use them in a query without any escaping at all. +func (sd SnowflakeDialect) legacyNeedsEscaping(name string) bool { + if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") { return true - } else { - if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") { - return true - } - // If it still doesn't need to be escaped, we should check if it's a number. - if _, err := strconv.Atoi(name); err == nil { - return true - } - return false } + // If it still doesn't need to be escaped, we should check if it's a number. + if _, err := strconv.Atoi(name); err == nil { + return true + } + return false } func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string { - if sd.UppercaseEscNames { - identifier = strings.ToUpper(identifier) + if sd.LegacyMode { + if sd.legacyNeedsEscaping(identifier) { + // In legacy mode we would have escaped this identifier which would have caused it to be lowercase. + slog.Warn("Escaped Snowflake identifier is not being uppercased", + slog.String("name", identifier), + ) + } else { + // Since this identifier wasn't previously escaped it will have been used uppercase. + identifier = strings.ToUpper(identifier) + } } else { - slog.Warn("Escaped Snowflake identifier is not being uppercased", - slog.String("name", identifier), - slog.Bool("uppercaseEscapedNames", sd.UppercaseEscNames), - ) + identifier = strings.ToUpper(identifier) } return fmt.Sprintf(`"%s"`, identifier) diff --git a/lib/sql/dialect_test.go b/lib/sql/dialect_test.go index adee0a481..744e033c7 100644 --- a/lib/sql/dialect_test.go +++ b/lib/sql/dialect_test.go @@ -24,39 +24,29 @@ func TestRedshiftDialect_QuoteIdentifier(t *testing.T) { assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("FOO")) } -func TestSnowflakeDialect_NeedsEscaping(t *testing.T) { - { - // UppercaseEscNames enabled: - dialect := SnowflakeDialect{UppercaseEscNames: true} - - assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved - assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved - assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix - assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol - } - - { - // UppercaseEscNames disabled: - dialect := SnowflakeDialect{UppercaseEscNames: false} - - assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved - assert.False(t, dialect.NeedsEscaping("foo")) // name that is not reserved - assert.False(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix - assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol - } +func TestSnowflakeDialect_legacyNeedsEscaping(t *testing.T) { + dialect := SnowflakeDialect{} + assert.True(t, dialect.legacyNeedsEscaping("select")) // name that is reserved + assert.False(t, dialect.legacyNeedsEscaping("foo")) // name that is not reserved + assert.False(t, dialect.legacyNeedsEscaping("__artie_foo")) // Artie prefix + assert.True(t, dialect.legacyNeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol } func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) { { - // UppercaseEscNames enabled: - dialect := SnowflakeDialect{UppercaseEscNames: true} + // New mode: + dialect := SnowflakeDialect{LegacyMode: false} assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo")) assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) } { - // UppercaseEscNames disabled: - dialect := SnowflakeDialect{UppercaseEscNames: false} - assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo")) + // Legacy mode: + dialect := SnowflakeDialect{LegacyMode: true} + assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo")) assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) + assert.Equal(t, `"select"`, dialect.QuoteIdentifier("select")) // Reserved name + assert.Equal(t, `"order"`, dialect.QuoteIdentifier("order")) // Reserved name + assert.Equal(t, `"group"`, dialect.QuoteIdentifier("group")) // Reserved name + assert.Equal(t, `"start"`, dialect.QuoteIdentifier("start")) // Reserved name } } diff --git a/lib/sql/escape.go b/lib/sql/escape.go deleted file mode 100644 index c1f34bda7..000000000 --- a/lib/sql/escape.go +++ /dev/null @@ -1,8 +0,0 @@ -package sql - -func EscapeNameIfNecessary(name string, dialect Dialect) string { - if dialect.NeedsEscaping(name) { - return dialect.QuoteIdentifier(name) - } - return name -} diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index 6623d61c6..a6b752a5c 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -85,7 +85,7 @@ func (c *Column) RawName() string { // Name will give you c.name and escape it if necessary. func (c *Column) Name(dialect sql.Dialect) string { - return sql.EscapeNameIfNecessary(c.name, dialect) + return dialect.QuoteIdentifier(c.name) } type Columns struct { diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index 88d6b15ea..80e392b47 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -170,7 +170,7 @@ func TestColumn_Name(t *testing.T) { assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName) - assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.colName) + assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{}), testCase.colName) assert.Equal(t, testCase.expectedNameEscBq, col.Name(sql.BigQueryDialect{}), testCase.colName) } } @@ -282,7 +282,7 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) { columns: testCase.cols, } - assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.name) + assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{}), testCase.name) assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(sql.BigQueryDialect{}), testCase.name) } } @@ -486,7 +486,7 @@ func TestColumnsUpdateQuery(t *testing.T) { { name: "string and toast", columns: stringAndToastCols, - dialect: sql.SnowflakeDialect{UppercaseEscNames: true}, + dialect: sql.SnowflakeDialect{}, expectedString: `"FOO"= CASE WHEN COALESCE(cc."FOO" != '__debezium_unavailable_value', true) THEN cc."FOO" ELSE c."FOO" END,"BAR"=cc."BAR"`, }, { diff --git a/lib/typing/columns/default_test.go b/lib/typing/columns/default_test.go index 9abbdc6a2..0da9f8cfe 100644 --- a/lib/typing/columns/default_test.go +++ b/lib/typing/columns/default_test.go @@ -17,7 +17,7 @@ import ( var dialects = []sql.Dialect{ sql.BigQueryDialect{}, sql.RedshiftDialect{}, - sql.SnowflakeDialect{UppercaseEscNames: true}, + sql.SnowflakeDialect{}, } func TestColumn_DefaultValue(t *testing.T) { diff --git a/lib/typing/columns/wrapper_test.go b/lib/typing/columns/wrapper_test.go index 8f7e1f317..95a54a649 100644 --- a/lib/typing/columns/wrapper_test.go +++ b/lib/typing/columns/wrapper_test.go @@ -41,7 +41,7 @@ func TestWrapper_Complete(t *testing.T) { for _, testCase := range testCases { // Snowflake escape - w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true}) + w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{}) assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) @@ -53,7 +53,7 @@ func TestWrapper_Complete(t *testing.T) { assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) { - w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true}) + w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{}) assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name) } { From 6647e6ae527fcd737acc74fd9efc80fbd77c73d6 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 19:26:49 -0700 Subject: [PATCH 2/6] Test --- lib/sql/dialect_test.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/lib/sql/dialect_test.go b/lib/sql/dialect_test.go index 744e033c7..dbdb07f9f 100644 --- a/lib/sql/dialect_test.go +++ b/lib/sql/dialect_test.go @@ -44,9 +44,10 @@ func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) { dialect := SnowflakeDialect{LegacyMode: true} assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo")) assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) - assert.Equal(t, `"select"`, dialect.QuoteIdentifier("select")) // Reserved name - assert.Equal(t, `"order"`, dialect.QuoteIdentifier("order")) // Reserved name - assert.Equal(t, `"group"`, dialect.QuoteIdentifier("group")) // Reserved name - assert.Equal(t, `"start"`, dialect.QuoteIdentifier("start")) // Reserved name + assert.Equal(t, `"abc:def"`, dialect.QuoteIdentifier("abc:def")) // Symbol + assert.Equal(t, `"select"`, dialect.QuoteIdentifier("select")) // Reserved name + assert.Equal(t, `"order"`, dialect.QuoteIdentifier("order")) // Reserved name + assert.Equal(t, `"group"`, dialect.QuoteIdentifier("group")) // Reserved name + assert.Equal(t, `"start"`, dialect.QuoteIdentifier("start")) // Reserved name } } From 7285b30e5af674cf444ba1ee7e3f8eebc2336032 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 19:28:07 -0700 Subject: [PATCH 3/6] Caps --- lib/sql/dialect_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/sql/dialect_test.go b/lib/sql/dialect_test.go index dbdb07f9f..cbdfbe7e2 100644 --- a/lib/sql/dialect_test.go +++ b/lib/sql/dialect_test.go @@ -24,7 +24,7 @@ func TestRedshiftDialect_QuoteIdentifier(t *testing.T) { assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("FOO")) } -func TestSnowflakeDialect_legacyNeedsEscaping(t *testing.T) { +func TestSnowflakeDialect_LegacyNeedsEscaping(t *testing.T) { dialect := SnowflakeDialect{} assert.True(t, dialect.legacyNeedsEscaping("select")) // name that is reserved assert.False(t, dialect.legacyNeedsEscaping("foo")) // name that is not reserved From 05ad468a3a2c00b1a77ff261a9ccc54ff74f22b4 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 19:29:14 -0700 Subject: [PATCH 4/6] More test cases --- lib/sql/dialect_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/sql/dialect_test.go b/lib/sql/dialect_test.go index cbdfbe7e2..66648d96f 100644 --- a/lib/sql/dialect_test.go +++ b/lib/sql/dialect_test.go @@ -38,6 +38,8 @@ func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) { dialect := SnowflakeDialect{LegacyMode: false} assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo")) assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) + assert.Equal(t, `"SELECT"`, dialect.QuoteIdentifier("select")) + assert.Equal(t, `"GROUP"`, dialect.QuoteIdentifier("group")) } { // Legacy mode: From 05c4e3958ae2a605643ca8a12056ab969b111b56 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 19:29:52 -0700 Subject: [PATCH 5/6] Caps --- lib/sql/dialect_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/sql/dialect_test.go b/lib/sql/dialect_test.go index 66648d96f..c1fc565c5 100644 --- a/lib/sql/dialect_test.go +++ b/lib/sql/dialect_test.go @@ -46,10 +46,10 @@ func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) { dialect := SnowflakeDialect{LegacyMode: true} assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo")) assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) - assert.Equal(t, `"abc:def"`, dialect.QuoteIdentifier("abc:def")) // Symbol - assert.Equal(t, `"select"`, dialect.QuoteIdentifier("select")) // Reserved name - assert.Equal(t, `"order"`, dialect.QuoteIdentifier("order")) // Reserved name - assert.Equal(t, `"group"`, dialect.QuoteIdentifier("group")) // Reserved name - assert.Equal(t, `"start"`, dialect.QuoteIdentifier("start")) // Reserved name + assert.Equal(t, `"abc:def"`, dialect.QuoteIdentifier("abc:def")) // symbol + assert.Equal(t, `"select"`, dialect.QuoteIdentifier("select")) // reserved name + assert.Equal(t, `"order"`, dialect.QuoteIdentifier("order")) // reserved name + assert.Equal(t, `"group"`, dialect.QuoteIdentifier("group")) // reserved name + assert.Equal(t, `"start"`, dialect.QuoteIdentifier("start")) // reserved name } } From 3339aefd36ecaf665850a376539e9e3ba3546482 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 19:31:11 -0700 Subject: [PATCH 6/6] Simplify --- lib/sql/dialect.go | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index d70d479cc..f2d1e97cb 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -4,7 +4,6 @@ import ( "fmt" "log/slog" "slices" - "strconv" "strings" "github.com/artie-labs/transfer/lib/config/constants" @@ -52,14 +51,7 @@ type SnowflakeDialect struct { } func (sd SnowflakeDialect) legacyNeedsEscaping(name string) bool { - if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") { - return true - } - // If it still doesn't need to be escaped, we should check if it's a number. - if _, err := strconv.Atoi(name); err == nil { - return true - } - return false + return slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") } func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {