Skip to content

Commit

Permalink
[sql] Split EscapeName into three functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Apr 22, 2024
1 parent 98612c4 commit 45f0c04
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 25 deletions.
2 changes: 1 addition & 1 deletion clients/bigquery/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ func (ti TableIdentifier) FullyQualifiedName() string {
"`%s`.`%s`.%s",
ti.projectID,
ti.dataset,
sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}),
sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}),
)
}
2 changes: 1 addition & 1 deletion clients/mssql/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ func (ti TableIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(
"%s.%s",
ti.schema,
sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}),
sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}),
)
}
2 changes: 1 addition & 1 deletion clients/redshift/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ func (ti TableIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(
"%s.%s",
ti.schema,
sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Redshift}),
sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Redshift}),
)
}
2 changes: 1 addition & 1 deletion clients/snowflake/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ func (ti TableIdentifier) FullyQualifiedName() string {
"%s.%s.%s",
ti.database,
ti.schema,
sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Snowflake}),
sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Snowflake}),
)
}
44 changes: 26 additions & 18 deletions lib/sql/escape.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,11 @@ type NameArgs struct {
// symbolsToEscape are additional keywords that we need to escape
var symbolsToEscape = []string{":"}

func EscapeName(name string, uppercaseEscNames bool, args *NameArgs) string {
if args == nil || !args.Escape {
return name
}

func needsEscaping(name string, uppercaseEscNames bool, destKind constants.DestinationKind) bool {
var reservedKeywords []string
if args.DestKind == constants.Redshift {
if destKind == constants.Redshift {
reservedKeywords = constants.RedshiftReservedKeywords
} else if args.DestKind == constants.MSSQL {
} else if destKind == constants.MSSQL {
reservedKeywords = constants.MSSQLReservedKeywords
} else {
reservedKeywords = constants.ReservedKeywords
Expand All @@ -50,19 +46,31 @@ func EscapeName(name string, uppercaseEscNames bool, args *NameArgs) string {
}
}

if needsEscaping {
if uppercaseEscNames {
name = strings.ToUpper(name)
}
return needsEscaping
}

if args.DestKind == constants.BigQuery {
// BigQuery needs backticks to escape.
return fmt.Sprintf("`%s`", name)
} else {
// Snowflake uses quotes.
return fmt.Sprintf(`"%s"`, name)
}
func EscapeNameIfNecessary(name string, uppercaseEscNames bool, args *NameArgs) string {
if args == nil || !args.Escape {
return name
}

if needsEscaping(name, uppercaseEscNames, args.DestKind) {
return escapeName(name, uppercaseEscNames, args.DestKind)
}

return name
}

func escapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string {
if uppercaseEscNames {
name = strings.ToUpper(name)
}

if destKind == constants.BigQuery {
// BigQuery needs backticks to escape.
return fmt.Sprintf("`%s`", name)
} else {
// Snowflake uses quotes.
return fmt.Sprintf(`"%s"`, name)
}
}
4 changes: 2 additions & 2 deletions lib/sql/escape_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,10 @@ func TestEscapeName(t *testing.T) {
}

for _, testCase := range testCases {
actualName := EscapeName(testCase.nameToEscape, false, testCase.args)
actualName := EscapeNameIfNecessary(testCase.nameToEscape, false, testCase.args)
assert.Equal(t, testCase.expectedName, actualName, testCase.name)

actualUpperName := EscapeName(testCase.nameToEscape, true, testCase.args)
actualUpperName := EscapeNameIfNecessary(testCase.nameToEscape, true, testCase.args)
assert.Equal(t, testCase.expectedNameWhenUpperCfg, actualUpperName, testCase.name)
}
}
2 changes: 1 addition & 1 deletion lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (c *Column) RawName() string {
// However, if you pass in escape, we will escape if the column name is part of the reserved words from destinations.
// If so, it'll change from `start` => `"start"` as suggested by Snowflake.
func (c *Column) Name(uppercaseEscNames bool, args *sql.NameArgs) string {
return sql.EscapeName(c.name, uppercaseEscNames, args)
return sql.EscapeNameIfNecessary(c.name, uppercaseEscNames, args)
}

type Columns struct {
Expand Down

0 comments on commit 45f0c04

Please sign in to comment.