Skip to content

Commit

Permalink
[sql] Add Dialect structs (#522)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored May 1, 2024
1 parent f673f4c commit db5d399
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 41 deletions.
9 changes: 8 additions & 1 deletion clients/bigquery/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ import (
"fmt"

"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.BigQueryDialect{}

type TableIdentifier struct {
projectID string
dataset string
Expand Down Expand Up @@ -39,5 +42,9 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
func (ti TableIdentifier) FullyQualifiedName() string {
// The fully qualified name for BigQuery is: project_id.dataset.tableName.
// We are escaping the project_id, dataset, and table because there could be special characters.
return fmt.Sprintf("`%s`.`%s`.`%s`", ti.projectID, ti.dataset, ti.table)
return fmt.Sprintf("%s.%s.%s",
dialect.QuoteIdentifier(ti.projectID),
dialect.QuoteIdentifier(ti.dataset),
dialect.QuoteIdentifier(ti.table),
)
}
9 changes: 3 additions & 6 deletions clients/mssql/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package mssql
import (
"fmt"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.DefaultDialect{}

type TableIdentifier struct {
schema string
table string
Expand All @@ -30,9 +31,5 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
}

func (ti TableIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(
"%s.%s",
ti.schema,
sql.EscapeName(ti.table, false, constants.MSSQL),
)
return fmt.Sprintf("%s.%s", ti.schema, dialect.QuoteIdentifier(ti.table))
}
9 changes: 3 additions & 6 deletions clients/redshift/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package redshift
import (
"fmt"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.RedshiftDialect{}

type TableIdentifier struct {
schema string
table string
Expand All @@ -32,9 +33,5 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
func (ti TableIdentifier) FullyQualifiedName() string {
// Redshift is Postgres compatible, so when establishing a connection, we'll specify a database.
// Thus, we only need to specify schema and table name here.
return fmt.Sprintf(
"%s.%s",
ti.schema,
sql.EscapeName(ti.table, false, constants.Redshift),
)
return fmt.Sprintf("%s.%s", ti.schema, dialect.QuoteIdentifier(ti.table))
}
10 changes: 3 additions & 7 deletions clients/snowflake/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ package snowflake
import (
"fmt"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.SnowflakeDialect{UppercaseEscNames: true}

type TableIdentifier struct {
database string
schema string
Expand Down Expand Up @@ -39,10 +40,5 @@ func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
}

func (ti TableIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(
"%s.%s.%s",
ti.database,
ti.schema,
sql.EscapeName(ti.table, true, constants.Snowflake),
)
return fmt.Sprintf("%s.%s.%s", ti.database, ti.schema, dialect.QuoteIdentifier(ti.table))
}
48 changes: 48 additions & 0 deletions lib/sql/dialect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package sql

import (
"fmt"
"log/slog"
"strings"
)

type Dialect interface {
QuoteIdentifier(identifier string) string
}

type DefaultDialect struct{}

func (DefaultDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf(`"%s"`, identifier)
}

type BigQueryDialect struct{}

func (BigQueryDialect) QuoteIdentifier(identifier string) string {
// BigQuery needs backticks to quote.
return fmt.Sprintf("`%s`", identifier)
}

type RedshiftDialect struct{}

func (rd RedshiftDialect) QuoteIdentifier(identifier string) string {
// Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted.
return fmt.Sprintf(`"%s"`, strings.ToLower(identifier))
}

type SnowflakeDialect struct {
UppercaseEscNames bool
}

func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {
if sd.UppercaseEscNames {
identifier = strings.ToUpper(identifier)
} else {
slog.Warn("Escaped Snowflake identifier is not being uppercased",
slog.String("name", identifier),
slog.Bool("uppercaseEscapedNames", sd.UppercaseEscNames),
)
}

return fmt.Sprintf(`"%s"`, identifier)
}
40 changes: 40 additions & 0 deletions lib/sql/dialect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package sql

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestDefaultDialect_QuoteIdentifier(t *testing.T) {
dialect := DefaultDialect{}
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
}

func TestBigQueryDialect_QuoteIdentifier(t *testing.T) {
dialect := BigQueryDialect{}
assert.Equal(t, "`foo`", dialect.QuoteIdentifier("foo"))
assert.Equal(t, "`FOO`", dialect.QuoteIdentifier("FOO"))
}

func TestRedshiftDialect_QuoteIdentifier(t *testing.T) {
dialect := RedshiftDialect{}
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("FOO"))
}

func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) {
{
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{UppercaseEscNames: true}
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
}
{
// UppercaseEscNames disabled:
dialect := SnowflakeDialect{UppercaseEscNames: false}
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
}
}
34 changes: 13 additions & 21 deletions lib/sql/escape.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sql

import (
"fmt"
"log/slog"
"slices"
"strconv"
Expand Down Expand Up @@ -55,26 +54,19 @@ func NeedsEscaping(name string, uppercaseEscNames bool, destKind constants.Desti
return false
}

func EscapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string {
if destKind == constants.Snowflake {
if uppercaseEscNames {
name = strings.ToUpper(name)
} else {
slog.Warn("Escaped Snowflake identifier is not being uppercased",
slog.String("name", name),
slog.Bool("uppercaseEscapedNames", uppercaseEscNames),
)
}
} else if destKind == constants.Redshift {
// Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted.
name = strings.ToLower(name)
func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dialect {
switch destKind {
case constants.BigQuery:
return BigQueryDialect{}
case constants.Snowflake:
return SnowflakeDialect{UppercaseEscNames: uppercaseEscNames}
case constants.Redshift:
return RedshiftDialect{}
default:
return DefaultDialect{}
}
}

if destKind == constants.BigQuery {
// BigQuery needs backticks to escape.
return fmt.Sprintf("`%s`", name)
} else {
// Everything else uses quotes.
return fmt.Sprintf(`"%s"`, name)
}
func EscapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string {
return dialectFor(destKind, uppercaseEscNames).QuoteIdentifier(name)
}

0 comments on commit db5d399

Please sign in to comment.