Skip to content

Commit

Permalink
[snowflake] Always escape column identifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed May 2, 2024
1 parent 1306702 commit 9b90fac
Show file tree
Hide file tree
Showing 12 changed files with 54 additions and 81 deletions.
6 changes: 3 additions & 3 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (s *Store) Label() constants.DestinationKind {
}

func (s *Store) Dialect() sql.Dialect {
return sql.SnowflakeDialect{UppercaseEscNames: s.config.SharedDestinationConfig.UppercaseEscapedNames}
return sql.SnowflakeDialect{LegacyMode: !s.config.SharedDestinationConfig.UppercaseEscapedNames}
}

func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap {
Expand Down Expand Up @@ -130,12 +130,12 @@ func (s *Store) reestablishConnection() error {
func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string {
var primaryKeysEscaped []string
for _, pk := range primaryKeys {
primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessary(pk, s.Dialect()))
primaryKeysEscaped = append(primaryKeysEscaped, s.Dialect().QuoteIdentifier(pk))
}

orderColsToIterate := primaryKeysEscaped
if topicConfig.IncludeArtieUpdatedAt {
orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessary(constants.UpdateColumnMarker, s.Dialect()))
orderColsToIterate = append(orderColsToIterate, s.Dialect().QuoteIdentifier(constants.UpdateColumnMarker))
}

var orderByCols []string
Expand Down
2 changes: 1 addition & 1 deletion clients/snowflake/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.SnowflakeDialect{UppercaseEscNames: true}
var dialect = sql.SnowflakeDialect{}

type TableIdentifier struct {
database string
Expand Down
6 changes: 3 additions & 3 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down Expand Up @@ -252,7 +252,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down Expand Up @@ -322,7 +322,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down
10 changes: 5 additions & 5 deletions lib/destination/dml/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestMergeStatementSoftDelete(t *testing.T) {
_cols.AddColumn(columns.NewColumn("id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
for _, idempotentKey := range []string{"", "updated_at"} {
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestMergeStatement(t *testing.T) {
subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)",
strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ","))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) {
_cols.AddColumn(columns.NewColumn("id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestMergeStatementCompositeKey(t *testing.T) {
_cols.AddColumn(columns.NewColumn("another_id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -249,7 +249,7 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) {
subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)",
strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ","))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down
2 changes: 1 addition & 1 deletion lib/destination/dml/merge_valid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func TestMergeArgument_Valid(t *testing.T) {
primaryKeys := []columns.Wrapper{
columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{UppercaseEscNames: true}),
columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{}),
}

var cols columns.Columns
Expand Down
47 changes: 19 additions & 28 deletions lib/sql/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@ import (
)

type Dialect interface {
NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything
QuoteIdentifier(identifier string) string
EscapeStruct(value string) string
}

type BigQueryDialect struct{}

func (BigQueryDialect) NeedsEscaping(_ string) bool { return true }

func (BigQueryDialect) QuoteIdentifier(identifier string) string {
// BigQuery needs backticks to quote.
return fmt.Sprintf("`%s`", identifier)
Expand All @@ -31,8 +28,6 @@ func (BigQueryDialect) EscapeStruct(value string) string {

type MSSQLDialect struct{}

func (MSSQLDialect) NeedsEscaping(_ string) bool { return true }

func (MSSQLDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf(`"%s"`, identifier)
}
Expand All @@ -43,8 +38,6 @@ func (MSSQLDialect) EscapeStruct(value string) string {

type RedshiftDialect struct{}

func (RedshiftDialect) NeedsEscaping(_ string) bool { return true }

func (rd RedshiftDialect) QuoteIdentifier(identifier string) string {
// Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted.
return fmt.Sprintf(`"%s"`, strings.ToLower(identifier))
Expand All @@ -55,35 +48,33 @@ func (RedshiftDialect) EscapeStruct(value string) string {
}

type SnowflakeDialect struct {
UppercaseEscNames bool
LegacyMode bool
}

func (sd SnowflakeDialect) NeedsEscaping(name string) bool {
if sd.UppercaseEscNames {
// If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix.
// Since they will be uppercased afer they are escaped then they will result in the same value as if we
// we were to use them in a query without any escaping at all.
func (sd SnowflakeDialect) legacyNeedsEscaping(name string) bool {
if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") {
return true
} else {
if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") {
return true
}
// If it still doesn't need to be escaped, we should check if it's a number.
if _, err := strconv.Atoi(name); err == nil {
return true
}
return false
}
// If it still doesn't need to be escaped, we should check if it's a number.
if _, err := strconv.Atoi(name); err == nil {
return true
}
return false
}

func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {
if sd.UppercaseEscNames {
identifier = strings.ToUpper(identifier)
if sd.LegacyMode {
if sd.legacyNeedsEscaping(identifier) {
// In legacy mode we would have escaped this identifier which would have caused it to be lowercase.
slog.Warn("Escaped Snowflake identifier is not being uppercased",
slog.String("name", identifier),
)
} else {
// Since this identifier wasn't previously escaped it will have been used uppercase.
identifier = strings.ToUpper(identifier)
}
} else {
slog.Warn("Escaped Snowflake identifier is not being uppercased",
slog.String("name", identifier),
slog.Bool("uppercaseEscapedNames", sd.UppercaseEscNames),
)
identifier = strings.ToUpper(identifier)
}

return fmt.Sprintf(`"%s"`, identifier)
Expand Down
40 changes: 15 additions & 25 deletions lib/sql/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,29 @@ func TestRedshiftDialect_QuoteIdentifier(t *testing.T) {
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("FOO"))
}

func TestSnowflakeDialect_NeedsEscaping(t *testing.T) {
{
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{UppercaseEscNames: true}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}

{
// UppercaseEscNames disabled:
dialect := SnowflakeDialect{UppercaseEscNames: false}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.False(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.False(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}
func TestSnowflakeDialect_legacyNeedsEscaping(t *testing.T) {
dialect := SnowflakeDialect{}
assert.True(t, dialect.legacyNeedsEscaping("select")) // name that is reserved
assert.False(t, dialect.legacyNeedsEscaping("foo")) // name that is not reserved
assert.False(t, dialect.legacyNeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.legacyNeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}

func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) {
{
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{UppercaseEscNames: true}
// New mode:
dialect := SnowflakeDialect{LegacyMode: false}
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
}
{
// UppercaseEscNames disabled:
dialect := SnowflakeDialect{UppercaseEscNames: false}
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo"))
// Legacy mode:
dialect := SnowflakeDialect{LegacyMode: true}
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
assert.Equal(t, `"select"`, dialect.QuoteIdentifier("select")) // Reserved name
assert.Equal(t, `"order"`, dialect.QuoteIdentifier("order")) // Reserved name
assert.Equal(t, `"group"`, dialect.QuoteIdentifier("group")) // Reserved name
assert.Equal(t, `"start"`, dialect.QuoteIdentifier("start")) // Reserved name
}
}
8 changes: 0 additions & 8 deletions lib/sql/escape.go

This file was deleted.

2 changes: 1 addition & 1 deletion lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (c *Column) RawName() string {

// Name will give you c.name and escape it if necessary.
func (c *Column) Name(dialect sql.Dialect) string {
return sql.EscapeNameIfNecessary(c.name, dialect)
return dialect.QuoteIdentifier(c.name)
}

type Columns struct {
Expand Down
6 changes: 3 additions & 3 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func TestColumn_Name(t *testing.T) {

assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName)

assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.colName)
assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{}), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, col.Name(sql.BigQueryDialect{}), testCase.colName)
}
}
Expand Down Expand Up @@ -282,7 +282,7 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) {
columns: testCase.cols,
}

assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.name)
assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{}), testCase.name)
assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(sql.BigQueryDialect{}), testCase.name)
}
}
Expand Down Expand Up @@ -486,7 +486,7 @@ func TestColumnsUpdateQuery(t *testing.T) {
{
name: "string and toast",
columns: stringAndToastCols,
dialect: sql.SnowflakeDialect{UppercaseEscNames: true},
dialect: sql.SnowflakeDialect{},
expectedString: `"FOO"= CASE WHEN COALESCE(cc."FOO" != '__debezium_unavailable_value', true) THEN cc."FOO" ELSE c."FOO" END,"BAR"=cc."BAR"`,
},
{
Expand Down
2 changes: 1 addition & 1 deletion lib/typing/columns/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
var dialects = []sql.Dialect{
sql.BigQueryDialect{},
sql.RedshiftDialect{},
sql.SnowflakeDialect{UppercaseEscNames: true},
sql.SnowflakeDialect{},
}

func TestColumn_DefaultValue(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions lib/typing/columns/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestWrapper_Complete(t *testing.T) {

for _, testCase := range testCases {
// Snowflake escape
w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true})
w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{})

assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name)
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)
Expand All @@ -53,7 +53,7 @@ func TestWrapper_Complete(t *testing.T) {
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)

{
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true})
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{})
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)
}
{
Expand Down

0 comments on commit 9b90fac

Please sign in to comment.