From 45f0c04b98e2a874d21b8e26ed5da51b3b0886c5 Mon Sep 17 00:00:00 2001 From: Nathan Villaescusa Date: Mon, 22 Apr 2024 10:48:35 -0700 Subject: [PATCH] [sql] Split `EscapeName` into three functions --- clients/bigquery/tableid.go | 2 +- clients/mssql/tableid.go | 2 +- clients/redshift/tableid.go | 2 +- clients/snowflake/tableid.go | 2 +- lib/sql/escape.go | 44 +++++++++++++++++++++-------------- lib/sql/escape_test.go | 4 ++-- lib/typing/columns/columns.go | 2 +- 7 files changed, 33 insertions(+), 25 deletions(-) diff --git a/clients/bigquery/tableid.go b/clients/bigquery/tableid.go index f63429163..0a651a75c 100644 --- a/clients/bigquery/tableid.go +++ b/clients/bigquery/tableid.go @@ -47,6 +47,6 @@ func (ti TableIdentifier) FullyQualifiedName() string { "`%s`.`%s`.%s", ti.projectID, ti.dataset, - sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}), + sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}), ) } diff --git a/clients/mssql/tableid.go b/clients/mssql/tableid.go index 9600d62b9..4c44b0c0d 100644 --- a/clients/mssql/tableid.go +++ b/clients/mssql/tableid.go @@ -34,6 +34,6 @@ func (ti TableIdentifier) FullyQualifiedName() string { return fmt.Sprintf( "%s.%s", ti.schema, - sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}), + sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}), ) } diff --git a/clients/redshift/tableid.go b/clients/redshift/tableid.go index 5d21ef286..6edf8ca3a 100644 --- a/clients/redshift/tableid.go +++ b/clients/redshift/tableid.go @@ -36,6 +36,6 @@ func (ti TableIdentifier) FullyQualifiedName() string { return fmt.Sprintf( "%s.%s", ti.schema, - sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Redshift}), + sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Redshift}), ) } diff --git a/clients/snowflake/tableid.go b/clients/snowflake/tableid.go index 8c313d42f..20064ffef 100644 --- a/clients/snowflake/tableid.go +++ b/clients/snowflake/tableid.go @@ -45,6 +45,6 @@ func (ti TableIdentifier) FullyQualifiedName() string { "%s.%s.%s", ti.database, ti.schema, - sql.EscapeName(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Snowflake}), + sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.Snowflake}), ) } diff --git a/lib/sql/escape.go b/lib/sql/escape.go index 2e6507209..f640e7e76 100644 --- a/lib/sql/escape.go +++ b/lib/sql/escape.go @@ -17,15 +17,11 @@ type NameArgs struct { // symbolsToEscape are additional keywords that we need to escape var symbolsToEscape = []string{":"} -func EscapeName(name string, uppercaseEscNames bool, args *NameArgs) string { - if args == nil || !args.Escape { - return name - } - +func needsEscaping(name string, uppercaseEscNames bool, destKind constants.DestinationKind) bool { var reservedKeywords []string - if args.DestKind == constants.Redshift { + if destKind == constants.Redshift { reservedKeywords = constants.RedshiftReservedKeywords - } else if args.DestKind == constants.MSSQL { + } else if destKind == constants.MSSQL { reservedKeywords = constants.MSSQLReservedKeywords } else { reservedKeywords = constants.ReservedKeywords @@ -50,19 +46,31 @@ func EscapeName(name string, uppercaseEscNames bool, args *NameArgs) string { } } - if needsEscaping { - if uppercaseEscNames { - name = strings.ToUpper(name) - } + return needsEscaping +} - if args.DestKind == constants.BigQuery { - // BigQuery needs backticks to escape. - return fmt.Sprintf("`%s`", name) - } else { - // Snowflake uses quotes. - return fmt.Sprintf(`"%s"`, name) - } +func EscapeNameIfNecessary(name string, uppercaseEscNames bool, args *NameArgs) string { + if args == nil || !args.Escape { + return name + } + + if needsEscaping(name, uppercaseEscNames, args.DestKind) { + return escapeName(name, uppercaseEscNames, args.DestKind) } return name } + +func escapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { + if uppercaseEscNames { + name = strings.ToUpper(name) + } + + if destKind == constants.BigQuery { + // BigQuery needs backticks to escape. + return fmt.Sprintf("`%s`", name) + } else { + // Snowflake uses quotes. + return fmt.Sprintf(`"%s"`, name) + } +} diff --git a/lib/sql/escape_test.go b/lib/sql/escape_test.go index 4f44329c5..e16473306 100644 --- a/lib/sql/escape_test.go +++ b/lib/sql/escape_test.go @@ -133,10 +133,10 @@ func TestEscapeName(t *testing.T) { } for _, testCase := range testCases { - actualName := EscapeName(testCase.nameToEscape, false, testCase.args) + actualName := EscapeNameIfNecessary(testCase.nameToEscape, false, testCase.args) assert.Equal(t, testCase.expectedName, actualName, testCase.name) - actualUpperName := EscapeName(testCase.nameToEscape, true, testCase.args) + actualUpperName := EscapeNameIfNecessary(testCase.nameToEscape, true, testCase.args) assert.Equal(t, testCase.expectedNameWhenUpperCfg, actualUpperName, testCase.name) } } diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index 6bc6c39f6..ba238958b 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -87,7 +87,7 @@ func (c *Column) RawName() string { // However, if you pass in escape, we will escape if the column name is part of the reserved words from destinations. // If so, it'll change from `start` => `"start"` as suggested by Snowflake. func (c *Column) Name(uppercaseEscNames bool, args *sql.NameArgs) string { - return sql.EscapeName(c.name, uppercaseEscNames, args) + return sql.EscapeNameIfNecessary(c.name, uppercaseEscNames, args) } type Columns struct {