Skip to content

Commit

Permalink
Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed May 2, 2024
1 parent 95dd0da commit 3ca6917
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 93 deletions.
16 changes: 8 additions & 8 deletions clients/snowflake/snowflake_dedupe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC) = 2)`, stagingTableID.FullyQualifiedName()),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY "ID" ORDER BY "ID" ASC) = 2)`, stagingTableID.FullyQualifiedName()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1."ID" = t2."ID"`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
}
{
Expand All @@ -35,10 +35,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY id ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."CUSTOMERS" QUALIFY ROW_NUMBER() OVER (PARTITION BY "ID" ORDER BY "ID" ASC, "__ARTIE_UPDATED_AT" ASC) = 2)`, stagingTableID.FullyQualifiedName()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1.id = t2.id`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."CUSTOMERS" t1 USING %s t2 WHERE t1."ID" = t2."ID"`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."CUSTOMERS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
}
{
Expand All @@ -50,10 +50,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC) = 2)`, stagingTableID.FullyQualifiedName()),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY "USER_ID", "SETTINGS" ORDER BY "USER_ID" ASC, "SETTINGS" ASC) = 2)`, stagingTableID.FullyQualifiedName()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1."USER_ID" = t2."USER_ID" AND t1."SETTINGS" = t2."SETTINGS"`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
}
{
Expand All @@ -65,10 +65,10 @@ func (s *SnowflakeTestSuite) TestGenerateDedupeQueries() {
assert.Len(s.T(), parts, 3)
assert.Equal(
s.T(),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, settings ORDER BY user_id ASC, settings ASC, __artie_updated_at ASC) = 2)`, stagingTableID.FullyQualifiedName()),
fmt.Sprintf(`CREATE OR REPLACE TRANSIENT TABLE %s AS (SELECT * FROM db.public."USER_SETTINGS" QUALIFY ROW_NUMBER() OVER (PARTITION BY "USER_ID", "SETTINGS" ORDER BY "USER_ID" ASC, "SETTINGS" ASC, "__ARTIE_UPDATED_AT" ASC) = 2)`, stagingTableID.FullyQualifiedName()),
parts[0],
)
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1.user_id = t2.user_id AND t1.settings = t2.settings`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`DELETE FROM db.public."USER_SETTINGS" t1 USING %s t2 WHERE t1."USER_ID" = t2."USER_ID" AND t1."SETTINGS" = t2."SETTINGS"`, stagingTableID.FullyQualifiedName()), parts[1])
assert.Equal(s.T(), fmt.Sprintf(`INSERT INTO db.public."USER_SETTINGS" SELECT * FROM %s`, stagingTableID.FullyQualifiedName()), parts[2])
}
}
12 changes: 6 additions & 6 deletions clients/snowflake/staging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ func (s *SnowflakeTestSuite) TestBackfillColumn() {
{
name: "col that has default value that needs to be backfilled",
col: needsBackfillCol,
backfillSQL: `UPDATE db.public."TABLENAME" SET foo = true WHERE foo IS NULL;`,
commentSQL: `COMMENT ON COLUMN db.public."TABLENAME".foo IS '{"backfilled": true}';`,
backfillSQL: `UPDATE db.public."TABLENAME" SET "FOO" = true WHERE "FOO" IS NULL;`,
commentSQL: `COMMENT ON COLUMN db.public."TABLENAME"."FOO" IS '{"backfilled": true}';`,
},
{
name: "default col that has default value that needs to be backfilled",
col: needsBackfillColDefault,
backfillSQL: `UPDATE db.public."TABLENAME" SET default = true WHERE "DEFAULT" IS NULL;`,
commentSQL: `COMMENT ON COLUMN db.public."TABLENAME".default IS '{"backfilled": true}';`,
backfillSQL: `UPDATE db.public."TABLENAME" SET "DEFAULT" = true WHERE "DEFAULT" IS NULL;`,
commentSQL: `COMMENT ON COLUMN db.public."TABLENAME"."DEFAULT" IS '{"backfilled": true}';`,
},
}

Expand Down Expand Up @@ -147,7 +147,7 @@ func (s *SnowflakeTestSuite) TestPrepareTempTable() {
createQuery, _ := s.fakeStageStore.ExecArgsForCall(0)

prefixQuery := fmt.Sprintf(
`CREATE TABLE IF NOT EXISTS %s (user_id string,first_name string,last_name string,dusty string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`, tempTableName)
`CREATE TABLE IF NOT EXISTS %s ("USER_ID" string,"FIRST_NAME" string,"LAST_NAME" string,"DUSTY" string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`, tempTableName)
containsPrefix := strings.HasPrefix(createQuery, prefixQuery)
assert.True(s.T(), containsPrefix, fmt.Sprintf("createQuery:%v, prefixQuery:%s", createQuery, prefixQuery))
resourceName := addPrefixToTableName(tempTableID, "%")
Expand All @@ -157,7 +157,7 @@ func (s *SnowflakeTestSuite) TestPrepareTempTable() {
assert.Contains(s.T(), putQuery, fmt.Sprintf("@%s AUTO_COMPRESS=TRUE", resourceName))
// Third call is a COPY INTO
copyQuery, _ := s.fakeStageStore.ExecArgsForCall(2)
assert.Equal(s.T(), fmt.Sprintf(`COPY INTO %s (user_id,first_name,last_name,dusty) FROM (SELECT $1,$2,$3,$4 FROM @%s)`,
assert.Equal(s.T(), fmt.Sprintf(`COPY INTO %s ("USER_ID","FIRST_NAME","LAST_NAME","DUSTY") FROM (SELECT $1,$2,$3,$4 FROM @%s)`,
tempTableName, resourceName), copyQuery)
}
{
Expand Down
2 changes: 1 addition & 1 deletion clients/snowflake/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/artie-labs/transfer/lib/sql"
)

var dialect = sql.SnowflakeDialect{UppercaseEscNames: true}
var dialect = sql.SnowflakeDialect{}

type TableIdentifier struct {
database string
Expand Down
3 changes: 0 additions & 3 deletions lib/destination/ddl/ddl_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ func (d *DDLTestSuite) SetupTest() {
snowflakeStagesStore := db.Store(d.fakeSnowflakeStagesStore)
snowflakeCfg := config.Config{
Snowflake: &config.Snowflake{},
SharedDestinationConfig: config.SharedDestinationConfig{
UppercaseEscapedNames: true,
},
}
d.snowflakeStagesStore, err = snowflake.LoadSnowflake(snowflakeCfg, &snowflakeStagesStore)
assert.NoError(d.T(), err)
Expand Down
10 changes: 5 additions & 5 deletions lib/destination/dml/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestMergeStatementSoftDelete(t *testing.T) {
_cols.AddColumn(columns.NewColumn("id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
for _, idempotentKey := range []string{"", "updated_at"} {
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
Expand Down Expand Up @@ -107,7 +107,7 @@ func TestMergeStatement(t *testing.T) {
subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)",
strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ","))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) {
_cols.AddColumn(columns.NewColumn("id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestMergeStatementCompositeKey(t *testing.T) {
_cols.AddColumn(columns.NewColumn("another_id", typing.String))
_cols.AddColumn(columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down Expand Up @@ -249,7 +249,7 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) {
subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)",
strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ","))

dialect := sql.SnowflakeDialect{UppercaseEscNames: true}
dialect := sql.SnowflakeDialect{}
mergeArg := MergeArgument{
TableID: MockTableIdentifier{fqTable},
SubQuery: subQuery,
Expand Down
2 changes: 1 addition & 1 deletion lib/destination/dml/merge_valid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func TestMergeArgument_Valid(t *testing.T) {
primaryKeys := []columns.Wrapper{
columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{UppercaseEscNames: true}),
columns.NewWrapper(columns.NewColumn("id", typing.Integer), sql.SnowflakeDialect{}),
}

var cols columns.Columns
Expand Down
35 changes: 3 additions & 32 deletions lib/sql/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@ package sql

import (
"fmt"
"log/slog"
"slices"
"strconv"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/stringutil"
)

Expand Down Expand Up @@ -55,39 +51,14 @@ func (RedshiftDialect) EscapeStruct(value any) string {
return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false))
}

type SnowflakeDialect struct {
UppercaseEscNames bool
}
type SnowflakeDialect struct{}

func (sd SnowflakeDialect) NeedsEscaping(name string) bool {
if sd.UppercaseEscNames {
// If uppercaseEscNames is true then we will escape all identifiers that do not start with the Artie priefix.
// Since they will be uppercased afer they are escaped then they will result in the same value as if we
// we were to use them in a query without any escaping at all.
return true
} else {
if slices.Contains(constants.ReservedKeywords, name) || strings.Contains(name, ":") {
return true
}
// If it still doesn't need to be escaped, we should check if it's a number.
if _, err := strconv.Atoi(name); err == nil {
return true
}
return false
}
return true
}

func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {
if sd.UppercaseEscNames {
identifier = strings.ToUpper(identifier)
} else {
slog.Warn("Escaped Snowflake identifier is not being uppercased",
slog.String("name", identifier),
slog.Bool("uppercaseEscapedNames", sd.UppercaseEscNames),
)
}

return fmt.Sprintf(`"%s"`, identifier)
return fmt.Sprintf(`"%s"`, strings.ToUpper(identifier))
}

func (SnowflakeDialect) EscapeStruct(value any) string {
Expand Down
42 changes: 11 additions & 31 deletions lib/sql/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,18 @@ func TestRedshiftDialect_QuoteIdentifier(t *testing.T) {
}

func TestSnowflakeDialect_NeedsEscaping(t *testing.T) {
{
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{UppercaseEscNames: true}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}

{
// UppercaseEscNames disabled:
dialect := SnowflakeDialect{UppercaseEscNames: false}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.False(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.False(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{}

assert.True(t, dialect.NeedsEscaping("select")) // name that is reserved
assert.True(t, dialect.NeedsEscaping("foo")) // name that is not reserved
assert.True(t, dialect.NeedsEscaping("__artie_foo")) // Artie prefix
assert.True(t, dialect.NeedsEscaping("__artie_foo:bar")) // Artie prefix + symbol
}

func TestSnowflakeDialect_QuoteIdentifier(t *testing.T) {
{
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{UppercaseEscNames: true}
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
}
{
// UppercaseEscNames disabled:
dialect := SnowflakeDialect{UppercaseEscNames: false}
assert.Equal(t, `"foo"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
}
// UppercaseEscNames enabled:
dialect := SnowflakeDialect{}
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("foo"))
assert.Equal(t, `"FOO"`, dialect.QuoteIdentifier("FOO"))
}
6 changes: 3 additions & 3 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func TestColumn_Name(t *testing.T) {

assert.Equal(t, testCase.expectedName, col.RawName(), testCase.colName)

assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.colName)
assert.Equal(t, testCase.expectedNameEsc, col.Name(sql.SnowflakeDialect{}), testCase.colName)
assert.Equal(t, testCase.expectedNameEscBq, col.Name(sql.BigQueryDialect{}), testCase.colName)
}
}
Expand Down Expand Up @@ -282,7 +282,7 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) {
columns: testCase.cols,
}

assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{UppercaseEscNames: true}), testCase.name)
assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(sql.SnowflakeDialect{}), testCase.name)
assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(sql.BigQueryDialect{}), testCase.name)
}
}
Expand Down Expand Up @@ -486,7 +486,7 @@ func TestColumnsUpdateQuery(t *testing.T) {
{
name: "string and toast",
columns: stringAndToastCols,
dialect: sql.SnowflakeDialect{UppercaseEscNames: true},
dialect: sql.SnowflakeDialect{},
expectedString: `"FOO"= CASE WHEN COALESCE(cc."FOO" != '__debezium_unavailable_value', true) THEN cc."FOO" ELSE c."FOO" END,"BAR"=cc."BAR"`,
},
{
Expand Down
2 changes: 1 addition & 1 deletion lib/typing/columns/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
var dialects = []sql.Dialect{
sql.BigQueryDialect{},
sql.RedshiftDialect{},
sql.SnowflakeDialect{UppercaseEscNames: true},
sql.SnowflakeDialect{},
}

func TestColumn_DefaultValue(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions lib/typing/columns/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestWrapper_Complete(t *testing.T) {

for _, testCase := range testCases {
// Snowflake escape
w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true})
w := NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{})

assert.Equal(t, testCase.expectedEscapedName, w.EscapedName(), testCase.name)
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)
Expand All @@ -53,7 +53,7 @@ func TestWrapper_Complete(t *testing.T) {
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)

{
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{UppercaseEscNames: true})
w = NewWrapper(NewColumn(testCase.name, typing.Invalid), sql.SnowflakeDialect{})
assert.Equal(t, testCase.expectedRawName, w.RawName(), testCase.name)
}
{
Expand Down

0 comments on commit 3ca6917

Please sign in to comment.