Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[snowflake] Always escape column identifiers #541

Merged
merged 6 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (s *Store) Label() constants.DestinationKind {
}

func (s *Store) Dialect() sql.Dialect {
return sql.SnowflakeDialect{UppercaseEscNames: s.config.SharedDestinationConfig.UppercaseEscapedNames}
return sql.SnowflakeDialect{LegacyMode: !s.config.SharedDestinationConfig.UppercaseEscapedNames}
}

func (s *Store) GetConfigMap() *types.DwhToTablesConfigMap {
Expand Down Expand Up @@ -130,12 +130,12 @@ func (s *Store) reestablishConnection() error {
func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string {
var primaryKeysEscaped []string
for _, pk := range primaryKeys {
primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessary(pk, s.Dialect()))
primaryKeysEscaped = append(primaryKeysEscaped, s.Dialect().QuoteIdentifier(pk))
}

orderColsToIterate := primaryKeysEscaped
if topicConfig.IncludeArtieUpdatedAt {
orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessary(constants.UpdateColumnMarker, s.Dialect()))
orderColsToIterate = append(orderColsToIterate, s.Dialect().QuoteIdentifier(constants.UpdateColumnMarker))
}

var orderByCols []string
Expand Down
2 changes: 1 addition & 1 deletion clients/snowflake/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.SnowflakeDialect{UppercaseEscNames: true}
var dialect = sql.SnowflakeDialect{}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't using legacy mode because we don't need to escape any more tables, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, we've been escaping all Snowflake tables for some time.


type TableIdentifier struct {
database string
Expand Down
6 changes: 3 additions & 3 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// We also need to remove __artie flags since they do not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down Expand Up @@ -252,7 +252,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// We also need to remove __artie flags since they do not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down Expand Up @@ -322,7 +322,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
// We also need to remove __artie flags since they do not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, m.Dialect) {
if col == m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down
10 changes: 5 additions & 5 deletions lib/destination/dml/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestMergeStatementSoftDelete(t *testing.T) {
_cols.AddColumn(columns.NewColumn("id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
for _, idempotentKey := range []string{"", "updated_at"} {
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestMergeStatement(t *testing.T) {
subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)",
strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ","))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) {
_cols.AddColumn(columns.NewColumn("id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestMergeStatementCompositeKey(t *testing.T) {
_cols.AddColumn(columns.NewColumn("another_id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -249,7 +249,7 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) {
subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)",
strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ","))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down
2 changes: 1 addition & 1 deletion lib/destination/dml/merge_valid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func TestMergeArgument_Valid(t *testing.T) {
primaryKeys := []columns.Wrapper{
columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{UppercaseEscNames: true}),
columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{}),
}

var cols columns.Columns
Expand Down
47 changes: 19 additions & 28 deletions lib/sql/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,12 @@ import (
)

type Dialect interface {
NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything
QuoteIdentifier(identifier string) string
EscapeStruct(value string) string
}

type BigQueryDialect struct{}

func (BigQueryDialect) NeedsEscaping(_ string) bool { return true }

func (BigQueryDialect) QuoteIdentifier(identifier string) string {
// BigQuery needs backticks to quote.
return fmt.Sprintf("`%s`", identifier)
Expand All @@ -31,8 +28,6 @@ func (BigQueryDialect) EscapeStruct(value string) string {

type MSSQLDialect struct{}

func (MSSQLDialect) NeedsEscaping(_ string) bool { return true }

func (MSSQLDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf(`"%s"`, identifier)
}
Expand All @@ -43,8 +38,6 @@ func (MSSQLDialect) EscapeStruct(value string) string {

type RedshiftDialect struct{}

func (RedshiftDialect) NeedsEscaping(_ string) bool { return true }

func (rd RedshiftDialect) QuoteIdentifier(identifier string) string {
// Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted.
return fmt.Sprintf(`"%s"`, strings.ToLower(identifier))
Expand All @@ -55,35 +48,33 @@ func (RedshiftDialect) EscapeStruct(value string) string {
}

type SnowflakeDialect struct {
UppercaseEscNames bool
LegacyMode bool
}

func (sd SnowflakeDialect) NeedsEscaping(name string) bool {
if sd.UppercaseEscNames {
// If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie prefix.
// Since they will be uppercased after they are escaped then they will result in the same value as if
// we were to use them in a query without any escaping at all.
func (sd SnowflakeDialect) legacyNeedsEscaping(name string) bool {
if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") {
return true
} else {
if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") {
return true
}
// If it still doesn't need to be escaped, we should check if it's a number.
if _, err := strconv.Atoi(name); err == nil {
return true
}
return false
}
// If it still doesn't need to be escaped, we should check if it's a number.
if _, err := strconv.Atoi(name); err == nil {
return true
}
return false
}

func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {
if sd.UppercaseEscNames {
identifier = strings.ToUpper(identifier)
if sd.LegacyMode {
if sd.legacyNeedsEscaping(identifier) {
// In legacy mode we would have escaped this identifier which would have caused it to be lowercase.
slog.Warn("Escaped Snowflake identifier is not being uppercased",
slog.String("name", identifier),
)
} else {
// Since this identifier wasn't previously escaped it will have been used uppercase.
identifier = strings.ToUpper(identifier)
}
} else {
slog.Warn("Escaped Snowflake identifier is not being uppercased",
slog.String("name", identifier),
slog.Bool("uppercaseEscapedNames", sd.UppercaseEscNames),
)
identifier = strings.ToUpper(identifier)
}

return fmt.Sprintf(`"%s"`, identifier)
Expand Down
43 changes: 18 additions & 25 deletions lib/sql/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,32 @@ func TestRedshiftDialect_QuoteIdentifier(t *testing.T) {
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("FOO"))
}

func TestSnowflakeDialect_NeedsEscaping(t *testing.T) {
{
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{UppercaseEscNames: true}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}

{
// UppercaseEscNames disabled:
dialect := SnowflakeDialect{UppercaseEscNames: false}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.False(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.False(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}
func TestSnowflakeDialect_LegacyNeedsEscaping(t *testing.T) {
dialect := SnowflakeDialect{}
assert.True(t, dialect.legacyNeedsEscaping("select")) // name that is reserved
assert.False(t, dialect.legacyNeedsEscaping("foo")) // name that is not reserved
assert.False(t, dialect.legacyNeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.legacyNeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}

func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) {
{
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{UppercaseEscNames: true}
// New mode:
dialect := SnowflakeDialect{LegacyMode: false}
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
assert.Equal(t, `"SELECT"`, dialect.QuoteIdentifier("select"))
assert.Equal(t, `"GROUP"`, dialect.QuoteIdentifier("group"))
}
{
// UppercaseEscNames disabled:
dialect := SnowflakeDialect{UppercaseEscNames: false}
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo"))
// Legacy mode:
dialect := SnowflakeDialect{LegacyMode: true}
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
assert.Equal(t, `"abc:def"`, dialect.QuoteIdentifier("abc:def")) // Symbol
assert.Equal(t, `"select"`, dialect.QuoteIdentifier("select")) // Reserved name
assert.Equal(t, `"order"`, dialect.QuoteIdentifier("order")) // Reserved name
assert.Equal(t, `"group"`, dialect.QuoteIdentifier("group")) // Reserved name
assert.Equal(t, `"start"`, dialect.QuoteIdentifier("start")) // Reserved name
}
}
8 changes: 0 additions & 8 deletions lib/sql/escape.go

This file was deleted.

2 changes: 1 addition & 1 deletion lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (c *Column) RawName() string {

// Name returns c.name quoted as an identifier by the given dialect.
func (c *Column) Name(dialect sql.Dialect) string {
return sql.EscapeNameIfNecessary(c.name, dialect)
return dialect.QuoteIdentifier(c.name)
}

type Columns struct {
Expand Down
6 changes: 3 additions & 3 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func TestColumn_Name(t *testing.T) {

assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName)

assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.colName)
assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{}), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, col.Name(sql.BigQueryDialect{}), testCase.colName)
}
}
Expand Down Expand Up @@ -282,7 +282,7 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) {
columns: testCase.cols,
}

assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.name)
assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{}), testCase.name)
assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(sql.BigQueryDialect{}), testCase.name)
}
}
Expand Down Expand Up @@ -486,7 +486,7 @@ func TestColumnsUpdateQuery(t *testing.T) {
{
name: "string and toast",
columns: stringAndToastCols,
dialect: sql.SnowflakeDialect{UppercaseEscNames: true},
dialect: sql.SnowflakeDialect{},
expectedString: `"FOO"= CASE WHEN COALESCE(cc."FOO" != '__debezium_unavailable_value', true) THEN cc."FOO" ELSE c."FOO" END,"BAR"=cc."BAR"`,
},
{
Expand Down
2 changes: 1 addition & 1 deletion lib/typing/columns/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
var dialects = []sql.Dialect{
sql.BigQueryDialect{},
sql.RedshiftDialect{},
sql.SnowflakeDialect{UppercaseEscNames: true},
sql.SnowflakeDialect{},
}

func TestColumn_DefaultValue(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions lib/typing/columns/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestWrapper_Complete(t *testing.T) {

for _, testCase := range testCases {
// Snowflake escape
w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true})
w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{})

assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name)
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)
Expand All @@ -53,7 +53,7 @@ func TestWrapper_Complete(t *testing.T) {
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)

{
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true})
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{})
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)
}
{
Expand Down