Skip to content

Commit

Permalink
Merge branch 'master' into guardrails-for-invalid-cols
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Apr 29, 2024
2 parents d4e4471 + 5b9d6c7 commit d3edbce
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 37 deletions.
12 changes: 6 additions & 6 deletions clients/bigquery/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,20 @@ func (b *BigQueryTestSuite) TestBackfillColumn() {
{
name: "col that has default value that needs to be backfilled (boolean)",
col: needsBackfillCol,
backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo = true WHERE foo IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo SET OPTIONS (description=`{\"backfilled\": true}`);",
backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo` = true WHERE `foo` IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo` SET OPTIONS (description=`{\"backfilled\": true}`);",
},
{
name: "col that has default value that needs to be backfilled (string)",
col: needsBackfillColStr,
backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo2 = 'hello there' WHERE foo2 IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo2 SET OPTIONS (description=`{\"backfilled\": true}`);",
backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo2` = 'hello there' WHERE `foo2` IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo2` SET OPTIONS (description=`{\"backfilled\": true}`);",
},
{
name: "col that has default value that needs to be backfilled (number)",
col: needsBackfillColNum,
backfillSQL: "UPDATE `db`.`public`.`tableName` SET foo3 = 3.5 WHERE foo3 IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN foo3 SET OPTIONS (description=`{\"backfilled\": true}`);",
backfillSQL: "UPDATE `db`.`public`.`tableName` SET `foo3` = 3.5 WHERE `foo3` IS NULL;",
commentSQL: "ALTER TABLE `db`.`public`.`tableName` ALTER COLUMN `foo3` SET OPTIONS (description=`{\"backfilled\": true}`);",
},
}

Expand Down
2 changes: 1 addition & 1 deletion lib/destination/ddl/ddl_alter_delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func (d *DDLTestSuite) TestAlterDelete_Complete() {
execQuery, _ := d.fakeBigQueryStore.ExecArgsForCall(0)
var found bool
for key := range allColsMap {
if execQuery == fmt.Sprintf("ALTER TABLE %s drop COLUMN %s", bqName, key) {
if execQuery == fmt.Sprintf("ALTER TABLE %s drop COLUMN `%s`", bqName, key) {
found = true
}
}
Expand Down
31 changes: 17 additions & 14 deletions lib/destination/ddl/ddl_create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,30 @@ func (d *DDLTestSuite) Test_CreateTable() {
d.snowflakeStagesStore.GetConfigMap().AddTableToConfig(snowflakeTableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))

type dwhToTableConfig struct {
_tableID types.TableIdentifier
_dwh destination.DataWarehouse
_tableConfig *types.DwhTableConfig
_fakeStore *mocks.FakeStore
_tableID types.TableIdentifier
_dwh destination.DataWarehouse
_tableConfig *types.DwhTableConfig
_fakeStore *mocks.FakeStore
_expectedQuery string
}

bigQueryTc := d.bigQueryStore.GetConfigMap().TableConfig(bqTableID)
snowflakeStagesTc := d.snowflakeStagesStore.GetConfigMap().TableConfig(snowflakeTableID)

for _, dwhTc := range []dwhToTableConfig{
{
_tableID: bqTableID,
_dwh: d.bigQueryStore,
_tableConfig: bigQueryTc,
_fakeStore: d.fakeBigQueryStore,
_tableID: bqTableID,
_dwh: d.bigQueryStore,
_tableConfig: bigQueryTc,
_fakeStore: d.fakeBigQueryStore,
_expectedQuery: fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (`name` string)", bqTableID.FullyQualifiedName()),
},
{
_tableID: snowflakeTableID,
_dwh: d.snowflakeStagesStore,
_tableConfig: snowflakeStagesTc,
_fakeStore: d.fakeSnowflakeStagesStore,
_tableID: snowflakeTableID,
_dwh: d.snowflakeStagesStore,
_tableConfig: snowflakeStagesTc,
_fakeStore: d.fakeSnowflakeStagesStore,
_expectedQuery: fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (name string)", snowflakeTableID.FullyQualifiedName()),
},
} {
alterTableArgs := ddl.AlterTableArgs{
Expand All @@ -66,8 +69,8 @@ func (d *DDLTestSuite) Test_CreateTable() {
assert.Equal(d.T(), 1, dwhTc._fakeStore.ExecCallCount())

query, _ := dwhTc._fakeStore.ExecArgsForCall(0)
assert.Equal(d.T(), query, fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (name string)", dwhTc._tableID.FullyQualifiedName()), query)
assert.Equal(d.T(), false, dwhTc._tableConfig.CreateTable())
assert.Equal(d.T(), dwhTc._expectedQuery, query)
assert.False(d.T(), dwhTc._tableConfig.CreateTable())
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/destination/ddl/ddl_temp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,6 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {
assert.Equal(d.T(), 1, d.fakeBigQueryStore.ExecCallCount())
bqQuery, _ := d.fakeBigQueryStore.ExecArgsForCall(0)
// Cutting off the expiration_timestamp since it's time based.
assert.Contains(d.T(), bqQuery, "CREATE TABLE IF NOT EXISTS `db`.`schema`.`tempTableName` (foo string,bar float64,`select` string) OPTIONS (expiration_timestamp =")
assert.Contains(d.T(), bqQuery, "CREATE TABLE IF NOT EXISTS `db`.`schema`.`tempTableName` (`foo` string,`bar` float64,`select` string) OPTIONS (expiration_timestamp =")
}
}
4 changes: 2 additions & 2 deletions lib/destination/dml/merge_bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestMergeStatement_TempTable(t *testing.T) {
mergeSQL, err := mergeArg.GetStatement()
assert.NoError(t, err)

assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON c.order_id = cc.order_id", mergeSQL)
assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON c.`order_id` = cc.`order_id`", mergeSQL)
}

func TestMergeStatement_JSONKey(t *testing.T) {
Expand All @@ -50,5 +50,5 @@ func TestMergeStatement_JSONKey(t *testing.T) {

mergeSQL, err := mergeArg.GetStatement()
assert.NoError(t, err)
assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON TO_JSON_STRING(c.order_oid) = TO_JSON_STRING(cc.order_oid)", mergeSQL)
assert.Contains(t, mergeSQL, "MERGE INTO customers.orders c USING customers.orders_tmp AS cc ON TO_JSON_STRING(c.`order_oid`) = TO_JSON_STRING(cc.`order_oid`)", mergeSQL)
}
7 changes: 5 additions & 2 deletions lib/sql/escape.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ func NeedsEscaping(name string, destKind constants.DestinationKind) bool {
var reservedKeywords []string
if destKind == constants.Redshift {
reservedKeywords = constants.RedshiftReservedKeywords
} else if destKind == constants.MSSQL {
return !strings.HasPrefix(name, constants.ArtiePrefix)
} else if destKind == constants.MSSQL || destKind == constants.BigQuery {
// TODO: Escape names that start with [constants.ArtiePrefix].
if !strings.HasPrefix(name, constants.ArtiePrefix) {
return true
}
} else {
reservedKeywords = constants.ReservedKeywords
}
Expand Down
29 changes: 27 additions & 2 deletions lib/sql/escape_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,31 @@ import (
"github.com/stretchr/testify/assert"
)

func TestNeedsEscaping(t *testing.T) {
// BigQuery:
assert.True(t, NeedsEscaping("select", constants.BigQuery)) // name that is reserved
assert.True(t, NeedsEscaping("foo", constants.BigQuery)) // name that is not reserved
assert.False(t, NeedsEscaping("__artie_foo", constants.BigQuery)) // Artie prefix
assert.True(t, NeedsEscaping("__artie_foo:bar", constants.MSSQL)) // Artie prefix + symbol

// MS SQL:
assert.True(t, NeedsEscaping("select", constants.MSSQL)) // name that is reserved
assert.True(t, NeedsEscaping("foo", constants.MSSQL)) // name that is not reserved
assert.False(t, NeedsEscaping("__artie_foo", constants.MSSQL)) // Artie prefix
assert.True(t, NeedsEscaping("__artie_foo:bar", constants.MSSQL)) // Artie prefix + symbol

// Redshift:
assert.True(t, NeedsEscaping("select", constants.Redshift)) // name that is reserved
assert.True(t, NeedsEscaping("truncatecolumns", constants.Redshift)) // name that is reserved for Redshift
assert.False(t, NeedsEscaping("foo", constants.Redshift)) // name that is not reserved
assert.False(t, NeedsEscaping("__artie_foo", constants.Redshift)) // Artie prefix

// Snowflake:
assert.True(t, NeedsEscaping("select", constants.Snowflake)) // name that is reserved
assert.False(t, NeedsEscaping("foo", constants.Snowflake)) // name that is not reserved
assert.False(t, NeedsEscaping("__artie_foo", constants.Snowflake)) // Artie prefix
}

func TestEscapeNameIfNecessary(t *testing.T) {
type _testCase struct {
name string
Expand Down Expand Up @@ -56,8 +81,8 @@ func TestEscapeNameIfNecessary(t *testing.T) {
name: "bigquery, #2",
destKind: constants.BigQuery,
nameToEscape: "hello",
expectedName: "hello",
expectedNameWhenUpperCfg: "hello",
expectedName: "`hello`",
expectedNameWhenUpperCfg: "`HELLO`",
},
{
name: "redshift, #1 (delta)",
Expand Down
14 changes: 7 additions & 7 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,13 @@ func TestColumn_Name(t *testing.T) {
colName: "foo",
expectedName: "foo",
expectedNameEsc: "foo",
expectedNameEscBq: "foo",
expectedNameEscBq: "`foo`",
},
{
colName: "bar",
expectedName: "bar",
expectedNameEsc: "bar",
expectedNameEscBq: "bar",
expectedNameEscBq: "`bar`",
},
}

Expand Down Expand Up @@ -266,13 +266,13 @@ func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) {
name: "happy path",
cols: happyPathCols,
expectedColsEsc: []string{"hi", "bye", `"START"`},
expectedColsEscBq: []string{"hi", "bye", "`start`"},
expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"},
},
{
name: "happy path + extra col",
cols: extraCols,
expectedColsEsc: []string{"hi", "bye", `"START"`},
expectedColsEscBq: []string{"hi", "bye", "`start`"},
expectedColsEscBq: []string{"`hi`", "`bye`", "`start`"},
},
}

Expand Down Expand Up @@ -498,21 +498,21 @@ func TestColumnsUpdateQuery(t *testing.T) {
name: "struct, string and toast string (bigquery)",
columns: lastCaseColTypes,
destKind: constants.BigQuery,
expectedString: `a1= CASE WHEN COALESCE(TO_JSON_STRING(cc.a1) != '{"key":"__debezium_unavailable_value"}', true) THEN cc.a1 ELSE c.a1 END,b2= CASE WHEN COALESCE(cc.b2 != '__debezium_unavailable_value', true) THEN cc.b2 ELSE c.b2 END,c3=cc.c3`,
expectedString: "`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '{\"key\":\"__debezium_unavailable_value\"}', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`",
},
{
name: "struct, string and toast string (bigquery) w/ reserved keywords",
columns: lastCaseEscapeTypes,
destKind: constants.BigQuery,
expectedString: fmt.Sprintf(`a1= CASE WHEN COALESCE(TO_JSON_STRING(cc.a1) != '%s', true) THEN cc.a1 ELSE c.a1 END,b2= CASE WHEN COALESCE(cc.b2 != '__debezium_unavailable_value', true) THEN cc.b2 ELSE c.b2 END,c3=cc.c3,%s,%s`,
expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s",
key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`"),
skipDeleteCol: true,
},
{
name: "struct, string and toast string (bigquery) w/ reserved keywords",
columns: lastCaseEscapeTypes,
destKind: constants.BigQuery,
expectedString: fmt.Sprintf(`a1= CASE WHEN COALESCE(TO_JSON_STRING(cc.a1) != '%s', true) THEN cc.a1 ELSE c.a1 END,b2= CASE WHEN COALESCE(cc.b2 != '__debezium_unavailable_value', true) THEN cc.b2 ELSE c.b2 END,c3=cc.c3,%s,%s`,
expectedString: fmt.Sprintf("`a1`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`a1`) != '%s', true) THEN cc.`a1` ELSE c.`a1` END,`b2`= CASE WHEN COALESCE(cc.`b2` != '__debezium_unavailable_value', true) THEN cc.`b2` ELSE c.`b2` END,`c3`=cc.`c3`,%s,%s",
key, fmt.Sprintf("`start`= CASE WHEN COALESCE(TO_JSON_STRING(cc.`start`) != '%s', true) THEN cc.`start` ELSE c.`start` END", key), "`select`=cc.`select`,__artie_delete=cc.__artie_delete"),
skipDeleteCol: false,
},
Expand Down
4 changes: 2 additions & 2 deletions lib/typing/columns/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ func TestWrapper_Complete(t *testing.T) {
name: "happy",
expectedRawName: "happy",
expectedEscapedName: "happy",
expectedEscapedNameBQ: "happy",
expectedEscapedNameBQ: "`happy`",
},
{
name: "user_id",
expectedRawName: "user_id",
expectedEscapedName: "user_id",
expectedEscapedNameBQ: "user_id",
expectedEscapedNameBQ: "`user_id`",
},
{
name: "group",
Expand Down

0 comments on commit d3edbce

Please sign in to comment.