From 06ccd1d4bdb84573d769c3a0fa9446159666ea18 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 13:23:59 -0700 Subject: [PATCH 1/2] [sql] Move `NeedsEscaping` to `Dialect` --- lib/sql/dialect.go | 35 ++++++++++++++++++++++++++++++ lib/sql/dialect_test.go | 22 +++++++++++++++++++ lib/sql/escape.go | 47 +++++------------------------------------ lib/sql/escape_test.go | 33 ----------------------------- 4 files changed, 62 insertions(+), 75 deletions(-) diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index 780b72797..b444e2401 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -3,21 +3,30 @@ package sql import ( "fmt" "log/slog" + "slices" + "strconv" "strings" + + "github.com/artie-labs/transfer/lib/config/constants" ) type Dialect interface { + NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything QuoteIdentifier(identifier string) string } type DefaultDialect struct{} +func (DefaultDialect) NeedsEscaping(_ string) bool { return true } + func (DefaultDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, identifier) } type BigQueryDialect struct{} +func (BigQueryDialect) NeedsEscaping(_ string) bool { return true } + func (BigQueryDialect) QuoteIdentifier(identifier string) string { // BigQuery needs backticks to quote. return fmt.Sprintf("`%s`", identifier) @@ -25,6 +34,8 @@ func (BigQueryDialect) QuoteIdentifier(identifier string) string { type RedshiftDialect struct{} +func (RedshiftDialect) NeedsEscaping(_ string) bool { return true } + func (rd RedshiftDialect) QuoteIdentifier(identifier string) string { // Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted. return fmt.Sprintf(`"%s"`, strings.ToLower(identifier)) @@ -34,6 +45,30 @@ type SnowflakeDialect struct { UppercaseEscNames bool } +func (sd SnowflakeDialect) NeedsEscaping(name string) bool { + if sd.UppercaseEscNames { + // If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix. + // Since they will be uppercased afer they are escaped then they will result in the same value as if we + // we were to use them in a query without any escaping at all. + return true + } else { + if slices.Contains(constants.ReservedKeywords, name) { + return true + } + // If it does not contain any reserved words, does it contain any symbols that need to be escaped? + for _, symbol := range symbolsToEscape { + if strings.Contains(name, symbol) { + return true + } + } + // If it still doesn't need to be escaped, we should check if it's a number. + if _, err := strconv.Atoi(name); err == nil { + return true + } + return false + } +} + func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string { if sd.UppercaseEscNames { identifier = strings.ToUpper(identifier) diff --git a/lib/sql/dialect_test.go b/lib/sql/dialect_test.go index 3a779cf74..9f1088077 100644 --- a/lib/sql/dialect_test.go +++ b/lib/sql/dialect_test.go @@ -24,6 +24,28 @@ func TestRedshiftDialect_QuoteIdentifier(t *testing.T) { assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("FOO")) } +func TestSnowflakeDialect_NeedsEscaping(t *testing.T) { + { + // UppercaseEscNames enabled: + dialect := SnowflakeDialect{UppercaseEscNames: true} + + assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved + assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved + assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix + assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol + } + + { + // UppercaseEscNames disabled: + dialect := SnowflakeDialect{UppercaseEscNames: false} + + assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved + assert.False(t, dialect.NeedsEscaping("foo")) // name that is not reserved + assert.False(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix + assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol + } +} + func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) { { // UppercaseEscNames enabled: diff --git a/lib/sql/escape.go b/lib/sql/escape.go index 28c22d499..f8d1e62dc 100644 --- a/lib/sql/escape.go +++ b/lib/sql/escape.go @@ -1,11 +1,6 @@ package sql import ( - "log/slog" - "slices" - "strconv" - "strings" - "github.com/artie-labs/transfer/lib/config/constants" ) @@ -13,45 +8,12 @@ import ( var symbolsToEscape = []string{":"} func EscapeNameIfNecessary(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { - if NeedsEscaping(name, uppercaseEscNames, destKind) { - return EscapeName(name, uppercaseEscNames, destKind) - } - return name -} + var dialect = dialectFor(destKind, uppercaseEscNames) -func NeedsEscaping(name string, uppercaseEscNames bool, destKind constants.DestinationKind) bool { - switch destKind { - case constants.BigQuery, constants.MSSQL, constants.Redshift: - return true - case constants.S3: - return false - case constants.Snowflake: - if uppercaseEscNames { - // If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix. - // Since they will be uppercased afer they are escaped then they will result in the same value as if we - // we were to use them in a query without any escaping at all. - return true - } else { - if slices.Contains(constants.ReservedKeywords, name) { - return true - } - // If it does not contain any reserved words, does it contain any symbols that need to be escaped? - for _, symbol := range symbolsToEscape { - if strings.Contains(name, symbol) { - return true - } - } - // If it still doesn't need to be escaped, we should check if it's a number. - if _, err := strconv.Atoi(name); err == nil { - return true - } - } - default: - slog.Error("Unsupported destination kind", slog.String("destKind", string(destKind))) - return true + if destKind != constants.S3 && dialect.NeedsEscaping(name) { + return dialect.QuoteIdentifier(name) } - - return false + return name } func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dialect { @@ -68,5 +30,6 @@ func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dial } func EscapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { + // TODO: This is only used in one place, remove once [Dialect] has beem added to [Store]. return dialectFor(destKind, uppercaseEscNames).QuoteIdentifier(name) } diff --git a/lib/sql/escape_test.go b/lib/sql/escape_test.go index 21e720911..1d5a83911 100644 --- a/lib/sql/escape_test.go +++ b/lib/sql/escape_test.go @@ -7,39 +7,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNeedsEscaping(t *testing.T) { - // BigQuery: - assert.True(t, NeedsEscaping("select", false, constants.BigQuery)) // name that is reserved - assert.True(t, NeedsEscaping("foo", false, constants.BigQuery)) // name that is not reserved - assert.True(t, NeedsEscaping("__artie_foo", false, constants.BigQuery)) // Artie prefix - assert.True(t, NeedsEscaping("__artie_foo:bar", false, constants.MSSQL)) // Artie prefix + symbol - - // MS SQL: - assert.True(t, NeedsEscaping("select", false, constants.MSSQL)) // name that is reserved - assert.True(t, NeedsEscaping("foo", false, constants.MSSQL)) // name that is not reserved - assert.True(t, NeedsEscaping("__artie_foo", false, constants.MSSQL)) // Artie prefix - assert.True(t, NeedsEscaping("__artie_foo:bar", false, constants.MSSQL)) // Artie prefix + symbol - - // Redshift: - assert.True(t, NeedsEscaping("select", false, constants.Redshift)) // name that is reserved - assert.True(t, NeedsEscaping("truncatecolumns", false, constants.Redshift)) // name that is reserved for Redshift - assert.True(t, NeedsEscaping("foo", false, constants.Redshift)) // name that is not reserved - assert.True(t, NeedsEscaping("__artie_foo", false, constants.Redshift)) // Artie prefix - assert.True(t, NeedsEscaping("__artie_foo:bar", false, constants.Redshift)) // Artie prefix + symbol - - // Snowflake (uppercaseEscNames = false): - assert.True(t, NeedsEscaping("select", false, constants.Snowflake)) // name that is reserved - assert.False(t, NeedsEscaping("foo", false, constants.Snowflake)) // name that is not reserved - assert.False(t, NeedsEscaping("__artie_foo", false, constants.Snowflake)) // Artie prefix - assert.True(t, NeedsEscaping("__artie_foo:bar", false, constants.Snowflake)) // Artie prefix + symbol - - // Snowflake (uppercaseEscNames = true): - assert.True(t, NeedsEscaping("select", true, constants.Snowflake)) // name that is reserved - assert.True(t, NeedsEscaping("foo", true, constants.Snowflake)) // name that is not reserved - assert.True(t, NeedsEscaping("__artie_foo", true, constants.Snowflake)) // Artie prefix - assert.True(t, NeedsEscaping("__artie_foo:bar", true, constants.Snowflake)) // Artie prefix + symbol -} - func TestEscapeNameIfNecessary(t *testing.T) { type _testCase struct { name string From b2453ba7c98a6dfc066dec9c66eb5ffc9f9fac31 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Wed, 1 May 2024 13:25:48 -0700 Subject: [PATCH 2/2] Move symbolsToEscape --- lib/sql/dialect.go | 3 +++ lib/sql/escape.go | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index b444e2401..4eee6297d 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -45,6 +45,9 @@ type SnowflakeDialect struct { UppercaseEscNames bool } +// symbolsToEscape are additional keywords that we need to escape +var symbolsToEscape = []string{":"} + func (sd SnowflakeDialect) NeedsEscaping(name string) bool { if sd.UppercaseEscNames { // If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix. diff --git a/lib/sql/escape.go b/lib/sql/escape.go index f8d1e62dc..85a94589a 100644 --- a/lib/sql/escape.go +++ b/lib/sql/escape.go @@ -4,9 +4,6 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" ) -// symbolsToEscape are additional keywords that we need to escape -var symbolsToEscape = []string{":"} - func EscapeNameIfNecessary(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { var dialect = dialectFor(destKind, uppercaseEscNames)