Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sql] Move NeedsEscaping to Dialect #523

Merged
merged 2 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions lib/sql/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,39 @@ package sql
import (
"fmt"
"log/slog"
"slices"
"strconv"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
)

type Dialect interface {
NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything
QuoteIdentifier(identifier string) string
}

type DefaultDialect struct{}

func (DefaultDialect) NeedsEscaping(_ string) bool { return true }

func (DefaultDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf(`"%s"`, identifier)
}

type BigQueryDialect struct{}

func (BigQueryDialect) NeedsEscaping(_ string) bool { return true }

func (BigQueryDialect) QuoteIdentifier(identifier string) string {
// BigQuery needs backticks to quote.
return fmt.Sprintf("`%s`", identifier)
}

type RedshiftDialect struct{}

func (RedshiftDialect) NeedsEscaping(_ string) bool { return true }

func (rd RedshiftDialect) QuoteIdentifier(identifier string) string {
// Preserve the existing behavior of Redshift identifiers being lowercased due to not being quoted.
return fmt.Sprintf(`"%s"`, strings.ToLower(identifier))
Expand All @@ -34,6 +45,33 @@ type SnowflakeDialect struct {
UppercaseEscNames bool
}

// symbolsToEscape are additional keywords that we need to escape
var symbolsToEscape = []string{":"}

func (sd SnowflakeDialect) NeedsEscaping(name string) bool {
if sd.UppercaseEscNames {
// If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix.
// Since they will be uppercased afer they are escaped then they will result in the same value as if we
// we were to use them in a query without any escaping at all.
return true
} else {
if slices.Contains(constants.ReservedKeywords, name) {
return true
}
// If it does not contain any reserved words, does it contain any symbols that need to be escaped?
for _, symbol := range symbolsToEscape {
if strings.Contains(name, symbol) {
return true
}
}
// If it still doesn't need to be escaped, we should check if it's a number.
if _, err := strconv.Atoi(name); err == nil {
return true
}
return false
}
}

func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {
if sd.UppercaseEscNames {
identifier = strings.ToUpper(identifier)
Expand Down
22 changes: 22 additions & 0 deletions lib/sql/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,28 @@ func TestRedshiftDialect_QuoteIdentifier(t *testing.T) {
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("FOO"))
}

func TestSnowflakeDialect_NeedsEscaping(t *testing.T) {
{
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{UppercaseEscNames: true}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}

{
// UppercaseEscNames disabled:
dialect := SnowflakeDialect{UppercaseEscNames: false}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.False(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.False(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}
}

func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) {
{
// UppercaseEscNames enabled:
Expand Down
50 changes: 5 additions & 45 deletions lib/sql/escape.go
Original file line number Diff line number Diff line change
@@ -1,57 +1,16 @@
package sql

import (
"log/slog"
"slices"
"strconv"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
)

// symbolsToEscape are additional keywords that we need to escape
var symbolsToEscape = []string{":"}

func EscapeNameIfNecessary(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string {
if NeedsEscaping(name, uppercaseEscNames, destKind) {
return EscapeName(name, uppercaseEscNames, destKind)
}
return name
}
var dialect = dialectFor(destKind, uppercaseEscNames)

func NeedsEscaping(name string, uppercaseEscNames bool, destKind constants.DestinationKind) bool {
switch destKind {
case constants.BigQuery, constants.MSSQL, constants.Redshift:
return true
case constants.S3:
return false
case constants.Snowflake:
if uppercaseEscNames {
// If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix.
// Since they will be uppercased afer they are escaped then they will result in the same value as if we
// we were to use them in a query without any escaping at all.
return true
} else {
if slices.Contains(constants.ReservedKeywords, name) {
return true
}
// If it does not contain any reserved words, does it contain any symbols that need to be escaped?
for _, symbol := range symbolsToEscape {
if strings.Contains(name, symbol) {
return true
}
}
// If it still doesn't need to be escaped, we should check if it's a number.
if _, err := strconv.Atoi(name); err == nil {
return true
}
}
default:
slog.Error("Unsupported destination kind", slog.String("destKind", string(destKind)))
return true
if destKind != constants.S3 && dialect.NeedsEscaping(name) {
return dialect.QuoteIdentifier(name)
}

return false
return name
}

func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dialect {
Expand All @@ -68,5 +27,6 @@ func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dial
}

func EscapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string {
// TODO: This is only used in one place, remove once [Dialect] has beem added to [Store].
return dialectFor(destKind, uppercaseEscNames).QuoteIdentifier(name)
}
33 changes: 0 additions & 33 deletions lib/sql/escape_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,6 @@ import (
"github.com/stretchr/testify/assert"
)

func TestNeedsEscaping(t *testing.T) {
// BigQuery:
assert.True(t, NeedsEscaping("select", false, constants.BigQuery)) // name that is reserved
assert.True(t, NeedsEscaping("foo", false, constants.BigQuery)) // name that is not reserved
assert.True(t, NeedsEscaping("__artie_foo", false, constants.BigQuery)) // Artie prefix
assert.True(t, NeedsEscaping("__artie_foo:bar", false, constants.MSSQL)) // Artie prefix + symbol

// MS SQL:
assert.True(t, NeedsEscaping("select", false, constants.MSSQL)) // name that is reserved
assert.True(t, NeedsEscaping("foo", false, constants.MSSQL)) // name that is not reserved
assert.True(t, NeedsEscaping("__artie_foo", false, constants.MSSQL)) // Artie prefix
assert.True(t, NeedsEscaping("__artie_foo:bar", false, constants.MSSQL)) // Artie prefix + symbol

// Redshift:
assert.True(t, NeedsEscaping("select", false, constants.Redshift)) // name that is reserved
assert.True(t, NeedsEscaping("truncatecolumns", false, constants.Redshift)) // name that is reserved for Redshift
assert.True(t, NeedsEscaping("foo", false, constants.Redshift)) // name that is not reserved
assert.True(t, NeedsEscaping("__artie_foo", false, constants.Redshift)) // Artie prefix
assert.True(t, NeedsEscaping("__artie_foo:bar", false, constants.Redshift)) // Artie prefix + symbol

// Snowflake (uppercaseEscNames = false):
assert.True(t, NeedsEscaping("select", false, constants.Snowflake)) // name that is reserved
assert.False(t, NeedsEscaping("foo", false, constants.Snowflake)) // name that is not reserved
assert.False(t, NeedsEscaping("__artie_foo", false, constants.Snowflake)) // Artie prefix
assert.True(t, NeedsEscaping("__artie_foo:bar", false, constants.Snowflake)) // Artie prefix + symbol

// Snowflake (uppercaseEscNames = true):
assert.True(t, NeedsEscaping("select", true, constants.Snowflake)) // name that is reserved
assert.True(t, NeedsEscaping("foo", true, constants.Snowflake)) // name that is not reserved
assert.True(t, NeedsEscaping("__artie_foo", true, constants.Snowflake)) // Artie prefix
assert.True(t, NeedsEscaping("__artie_foo:bar", true, constants.Snowflake)) // Artie prefix + symbol
}

func TestEscapeNameIfNecessary(t *testing.T) {
type _testCase struct {
name string
Expand Down