From db5d3995d979fd23b91161fdfe85307f5f2b2471 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Wed, 1 May 2024 13:09:44 -0700 Subject: [PATCH] [sql] Add `Dialect` structs (#522) --- clients/bigquery/tableid.go | 9 ++++++- clients/mssql/tableid.go | 9 +++---- clients/redshift/tableid.go | 9 +++---- clients/snowflake/tableid.go | 10 +++----- lib/sql/dialect.go | 48 ++++++++++++++++++++++++++++++++++++ lib/sql/dialect_test.go | 40 ++++++++++++++++++++++++++++++ lib/sql/escape.go | 34 ++++++++++--------------- 7 files changed, 118 insertions(+), 41 deletions(-) create mode 100644 lib/sql/dialect.go create mode 100644 lib/sql/dialect_test.go diff --git a/clients/bigquery/tableid.go b/clients/bigquery/tableid.go index 5c091e1b5..f786524ea 100644 --- a/clients/bigquery/tableid.go +++ b/clients/bigquery/tableid.go @@ -4,8 +4,11 @@ import ( "fmt" "github.com/artie-labs/transfer/lib/destination/types" + "github.com/artie-labs/transfer/lib/sql" ) +var dialect = sql.BigQueryDialect{} + type TableIdentifier struct { projectID string dataset string @@ -39,5 +42,9 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { func (ti TableIdentifier) FullyQualifiedName() string { // The fully qualified name for BigQuery is: project_id.dataset.tableName. // We are escaping the project_id, dataset, and table because there could be special characters. - return fmt.Sprintf("`%s`.`%s`.`%s`", ti.projectID, ti.dataset, ti.table) + return fmt.Sprintf("%s.%s.%s", + dialect.QuoteIdentifier(ti.projectID), + dialect.QuoteIdentifier(ti.dataset), + dialect.QuoteIdentifier(ti.table), + ) } diff --git a/clients/mssql/tableid.go b/clients/mssql/tableid.go index 09b4ee8d2..b5026d723 100644 --- a/clients/mssql/tableid.go +++ b/clients/mssql/tableid.go @@ -3,11 +3,12 @@ package mssql import ( "fmt" - "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/destination/types" "github.com/artie-labs/transfer/lib/sql" ) +var dialect = sql.DefaultDialect{} + type TableIdentifier struct { schema string table string @@ -30,9 +31,5 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { } func (ti TableIdentifier) FullyQualifiedName() string { - return fmt.Sprintf( - "%s.%s", - ti.schema, - sql.EscapeName(ti.table, false, constants.MSSQL), - ) + return fmt.Sprintf("%s.%s", ti.schema, dialect.QuoteIdentifier(ti.table)) } diff --git a/clients/redshift/tableid.go b/clients/redshift/tableid.go index aea58d7e3..cf2c0e929 100644 --- a/clients/redshift/tableid.go +++ b/clients/redshift/tableid.go @@ -3,11 +3,12 @@ package redshift import ( "fmt" - "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/destination/types" "github.com/artie-labs/transfer/lib/sql" ) +var dialect = sql.RedshiftDialect{} + type TableIdentifier struct { schema string table string @@ -32,9 +33,5 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { func (ti TableIdentifier) FullyQualifiedName() string { // Redshift is Postgres compatible, so when establishing a connection, we'll specify a database. // Thus, we only need to specify schema and table name here. - return fmt.Sprintf( - "%s.%s", - ti.schema, - sql.EscapeName(ti.table, false, constants.Redshift), - ) + return fmt.Sprintf("%s.%s", ti.schema, dialect.QuoteIdentifier(ti.table)) } diff --git a/clients/snowflake/tableid.go b/clients/snowflake/tableid.go index 057129194..662b97f75 100644 --- a/clients/snowflake/tableid.go +++ b/clients/snowflake/tableid.go @@ -3,11 +3,12 @@ package snowflake import ( "fmt" - "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/destination/types" "github.com/artie-labs/transfer/lib/sql" ) +var dialect = sql.SnowflakeDialect{UppercaseEscNames: true} + type TableIdentifier struct { database string schema string @@ -39,10 +40,5 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier { } func (ti TableIdentifier) FullyQualifiedName() string { - return fmt.Sprintf( - "%s.%s.%s", - ti.database, - ti.schema, - sql.EscapeName(ti.table, true, constants.Snowflake), - ) + return fmt.Sprintf("%s.%s.%s", ti.database, ti.schema, dialect.QuoteIdentifier(ti.table)) } diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go new file mode 100644 index 000000000..780b72797 --- /dev/null +++ b/lib/sql/dialect.go @@ -0,0 +1,48 @@ +package sql + +import ( + "fmt" + "log/slog" + "strings" +) + +type Dialect interface { + QuoteIdentifier(identifier string) string +} + +type DefaultDialect struct{} + +func (DefaultDialect) QuoteIdentifier(identifier string) string { + return fmt.Sprintf(`"%s"`, identifier) +} + +type BigQueryDialect struct{} + +func (BigQueryDialect) QuoteIdentifier(identifier string) string { + // BigQuery needs backticks to quote. + return fmt.Sprintf("`%s`", identifier) +} + +type RedshiftDialect struct{} + +func (rd RedshiftDialect) QuoteIdentifier(identifier string) string { + // Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted. + return fmt.Sprintf(`"%s"`, strings.ToLower(identifier)) +} + +type SnowflakeDialect struct { + UppercaseEscNames bool +} + +func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string { + if sd.UppercaseEscNames { + identifier = strings.ToUpper(identifier) + } else { + slog.Warn("Escaped Snowflake identifier is not being uppercased", + slog.String("name", identifier), + slog.Bool("uppercaseEscapedNames", sd.UppercaseEscNames), + ) + } + + return fmt.Sprintf(`"%s"`, identifier) +} diff --git a/lib/sql/dialect_test.go b/lib/sql/dialect_test.go new file mode 100644 index 000000000..3a779cf74 --- /dev/null +++ b/lib/sql/dialect_test.go @@ -0,0 +1,40 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultDialect_QuoteIdentifier(t *testing.T) { + dialect := DefaultDialect{} + assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo")) + assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) +} + +func TestBigQueryDialect_QuoteIdentifier(t *testing.T) { + dialect := BigQueryDialect{} + assert.Equal(t, "`foo`", dialect.QuoteIdentifier("foo")) + assert.Equal(t, "`FOO`", dialect.QuoteIdentifier("FOO")) +} + +func TestRedshiftDialect_QuoteIdentifier(t *testing.T) { + dialect := RedshiftDialect{} + assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo")) + assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("FOO")) +} + +func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) { + { + // UppercaseEscNames enabled: + dialect := SnowflakeDialect{UppercaseEscNames: true} + assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo")) + assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) + } + { + // UppercaseEscNames disabled: + dialect := SnowflakeDialect{UppercaseEscNames: false} + assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo")) + assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO")) + } +} diff --git a/lib/sql/escape.go b/lib/sql/escape.go index f6081c43a..28c22d499 100644 --- a/lib/sql/escape.go +++ b/lib/sql/escape.go @@ -1,7 +1,6 @@ package sql import ( - "fmt" "log/slog" "slices" "strconv" @@ -55,26 +54,19 @@ func NeedsEscaping(name string, uppercaseEscNames bool, destKind constants.Desti return false } -func EscapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { - if destKind == constants.Snowflake { - if uppercaseEscNames { - name = strings.ToUpper(name) - } else { - slog.Warn("Escaped Snowflake identifier is not being uppercased", - slog.String("name", name), - slog.Bool("uppercaseEscapedNames", uppercaseEscNames), - ) - } - } else if destKind == constants.Redshift { - // Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted. - name = strings.ToLower(name) +func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dialect { + switch destKind { + case constants.BigQuery: + return BigQueryDialect{} + case constants.Snowflake: + return SnowflakeDialect{UppercaseEscNames: uppercaseEscNames} + case constants.Redshift: + return RedshiftDialect{} + default: + return DefaultDialect{} } +} - if destKind == constants.BigQuery { - // BigQuery needs backticks to escape. - return fmt.Sprintf("`%s`", name) - } else { - // Everything else uses quotes. - return fmt.Sprintf(`"%s"`, name) - } +func EscapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string { + return dialectFor(destKind, uppercaseEscNames).QuoteIdentifier(name) }