Skip to content

Commit

Permalink
Add Dialect to MergeArgument
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed May 1, 2024
1 parent 78b2eb2 commit 9026cbe
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 5 deletions.
1 change: 1 addition & 0 deletions clients/shared/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
Columns: tableData.ReadOnlyInMemoryCols(),
SoftDelete: tableData.TopicConfig().SoftDelete,
DestKind: dwh.Label(),
Dialect: dwh.Dialect(),
UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()),
ContainsHardDeletes: ptr.ToBool(tableData.ContainsHardDeletes()),
}
Expand Down
4 changes: 2 additions & 2 deletions clients/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ func (s *Store) reestablishConnection() error {
func (s *Store) generateDedupeQueries(tableID, stagingTableID types.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string {
var primaryKeysEscaped []string
for _, pk := range primaryKeys {
primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessary(pk, s.ShouldUppercaseEscapedNames(), s.Label()))
primaryKeysEscaped = append(primaryKeysEscaped, sql.EscapeNameIfNecessaryUsingDialect(pk, s.Dialect()))
}

orderColsToIterate := primaryKeysEscaped
if topicConfig.IncludeArtieUpdatedAt {
orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessary(constants.UpdateColumnMarker, s.ShouldUppercaseEscapedNames(), s.Label()))
orderColsToIterate = append(orderColsToIterate, sql.EscapeNameIfNecessaryUsingDialect(constants.UpdateColumnMarker, s.Dialect()))
}

var orderByCols []string
Expand Down
11 changes: 8 additions & 3 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type MergeArgument struct {
// where we do not issue a DELETE statement if there are no hard deletes in the batch
ContainsHardDeletes *bool
UppercaseEscNames *bool
Dialect sql.Dialect
}

func (m *MergeArgument) Valid() error {
Expand Down Expand Up @@ -62,6 +63,10 @@ func (m *MergeArgument) Valid() error {
return fmt.Errorf("invalid destination: %s", m.DestKind)
}

if m.Dialect == nil {
return fmt.Errorf("dialect cannot be nil")
}

return nil
}

Expand Down Expand Up @@ -129,7 +134,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, *m.UppercaseEscNames, m.DestKind) {
if col == sql.EscapeNameIfNecessaryUsingDialect(constants.DeleteColumnMarker, m.Dialect) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down Expand Up @@ -252,7 +257,7 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`
// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, *m.UppercaseEscNames, m.DestKind) {
if col == sql.EscapeNameIfNecessaryUsingDialect(constants.DeleteColumnMarker, m.Dialect) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down Expand Up @@ -322,7 +327,7 @@ WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`,
// We also need to remove __artie flags since it does not exist in the destination table
var removed bool
for idx, col := range cols {
if col == sql.EscapeNameIfNecessary(constants.DeleteColumnMarker, *m.UppercaseEscNames, m.DestKind) {
if col == sql.EscapeNameIfNecessaryUsingDialect(constants.DeleteColumnMarker, m.Dialect) {
cols = append(cols[:idx], cols[idx+1:]...)
removed = true
break
Expand Down
3 changes: 3 additions & 0 deletions lib/destination/dml/merge_bigquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/stretchr/testify/assert"
Expand All @@ -22,6 +23,7 @@ func TestMergeStatement_TempTable(t *testing.T) {
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_id", typing.Invalid), false, constants.BigQuery)},
Columns: &cols,
DestKind: constants.BigQuery,
Dialect: sql.BigQueryDialect{},
SoftDelete: false,
UppercaseEscNames: ptr.ToBool(false),
}
Expand All @@ -44,6 +46,7 @@ func TestMergeStatement_JSONKey(t *testing.T) {
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("order_oid", typing.Invalid), false, constants.BigQuery)},
Columns: &cols,
DestKind: constants.BigQuery,
Dialect: sql.BigQueryDialect{},
SoftDelete: false,
UppercaseEscNames: ptr.ToBool(false),
}
Expand Down
2 changes: 2 additions & 0 deletions lib/destination/dml/merge_mssql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -47,6 +48,7 @@ func Test_GetMSSQLStatement(t *testing.T) {
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.MSSQL)},
Columns: &_cols,
DestKind: constants.MSSQL,
Dialect: sql.DefaultDialect{},
SoftDelete: false,
UppercaseEscNames: ptr.ToBool(false),
}
Expand Down
8 changes: 8 additions & 0 deletions lib/destination/dml/merge_parts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"

"github.com/artie-labs/transfer/lib/typing"

Expand Down Expand Up @@ -72,6 +73,7 @@ func TestMergeStatementParts_SkipDelete(t *testing.T) {
PrimaryKeys: res.PrimaryKeys,
Columns: &res.ColumnsToTypes,
DestKind: constants.Redshift,
Dialect: sql.RedshiftDialect{},
ContainsHardDeletes: ptr.ToBool(false),
UppercaseEscNames: ptr.ToBool(false),
}
Expand Down Expand Up @@ -99,6 +101,7 @@ func TestMergeStatementPartsSoftDelete(t *testing.T) {
PrimaryKeys: res.PrimaryKeys,
Columns: &res.ColumnsToTypes,
DestKind: constants.Redshift,
Dialect: sql.RedshiftDialect{},
SoftDelete: true,
UppercaseEscNames: ptr.ToBool(false),
ContainsHardDeletes: ptr.ToBool(false),
Expand Down Expand Up @@ -139,6 +142,7 @@ func TestMergeStatementPartsSoftDeleteComposite(t *testing.T) {
PrimaryKeys: res.PrimaryKeys,
Columns: &res.ColumnsToTypes,
DestKind: constants.Redshift,
Dialect: sql.RedshiftDialect{},
SoftDelete: true,
UppercaseEscNames: ptr.ToBool(false),
ContainsHardDeletes: ptr.ToBool(false),
Expand Down Expand Up @@ -182,6 +186,7 @@ func TestMergeStatementParts(t *testing.T) {
PrimaryKeys: res.PrimaryKeys,
Columns: &res.ColumnsToTypes,
DestKind: constants.Redshift,
Dialect: sql.RedshiftDialect{},
ContainsHardDeletes: ptr.ToBool(true),
UppercaseEscNames: ptr.ToBool(false),
}
Expand All @@ -208,6 +213,7 @@ func TestMergeStatementParts(t *testing.T) {
PrimaryKeys: res.PrimaryKeys,
Columns: &res.ColumnsToTypes,
DestKind: constants.Redshift,
Dialect: sql.RedshiftDialect{},
IdempotentKey: "created_at",
ContainsHardDeletes: ptr.ToBool(true),
UppercaseEscNames: ptr.ToBool(false),
Expand Down Expand Up @@ -240,6 +246,7 @@ func TestMergeStatementPartsCompositeKey(t *testing.T) {
PrimaryKeys: res.PrimaryKeys,
Columns: &res.ColumnsToTypes,
DestKind: constants.Redshift,
Dialect: sql.RedshiftDialect{},
ContainsHardDeletes: ptr.ToBool(true),
UppercaseEscNames: ptr.ToBool(false),
}
Expand All @@ -266,6 +273,7 @@ func TestMergeStatementPartsCompositeKey(t *testing.T) {
PrimaryKeys: res.PrimaryKeys,
Columns: &res.ColumnsToTypes,
DestKind: constants.Redshift,
Dialect: sql.RedshiftDialect{},
ContainsHardDeletes: ptr.ToBool(true),
IdempotentKey: "created_at",
UppercaseEscNames: ptr.ToBool(false),
Expand Down
6 changes: 6 additions & 0 deletions lib/destination/dml/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)
Expand Down Expand Up @@ -64,6 +65,7 @@ func TestMergeStatementSoftDelete(t *testing.T) {
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)},
Columns: &_cols,
DestKind: constants.Snowflake,
Dialect: sql.SnowflakeDialect{UppercaseEscNames: false},
SoftDelete: true,
UppercaseEscNames: ptr.ToBool(false),
}
Expand Down Expand Up @@ -113,6 +115,7 @@ func TestMergeStatement(t *testing.T) {
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), true, constants.Snowflake)},
Columns: &_cols,
DestKind: constants.Snowflake,
Dialect: sql.SnowflakeDialect{UppercaseEscNames: true},
SoftDelete: false,
UppercaseEscNames: ptr.ToBool(true),
}
Expand Down Expand Up @@ -161,6 +164,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) {
PrimaryKeys: []columns.Wrapper{columns.NewWrapper(columns.NewColumn("id", typing.Invalid), false, constants.Snowflake)},
Columns: &_cols,
DestKind: constants.Snowflake,
Dialect: sql.SnowflakeDialect{UppercaseEscNames: false},
SoftDelete: false,
UppercaseEscNames: ptr.ToBool(false),
}
Expand Down Expand Up @@ -206,6 +210,7 @@ func TestMergeStatementCompositeKey(t *testing.T) {
},
Columns: &_cols,
DestKind: constants.Snowflake,
Dialect: sql.SnowflakeDialect{UppercaseEscNames: false},
SoftDelete: false,
UppercaseEscNames: ptr.ToBool(false),
}
Expand Down Expand Up @@ -255,6 +260,7 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) {
},
Columns: &_cols,
DestKind: constants.Snowflake,
Dialect: sql.SnowflakeDialect{UppercaseEscNames: true},
SoftDelete: false,
UppercaseEscNames: ptr.ToBool(true),
}
Expand Down
14 changes: 14 additions & 0 deletions lib/destination/dml/merge_valid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"

"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/typing"
Expand Down Expand Up @@ -90,6 +91,18 @@ func TestMergeArgument_Valid(t *testing.T) {
},
expectedErr: "invalid destination",
},
{
name: "missing dialect kind",
mergeArg: &MergeArgument{
PrimaryKeys: primaryKeys,
Columns: &cols,
SubQuery: "schema.tableName",
TableID: MockTableIdentifier{"schema.tableName"},
UppercaseEscNames: ptr.ToBool(false),
DestKind: constants.BigQuery,
},
expectedErr: "dialect cannot be nil",
},
{
name: "everything exists",
mergeArg: &MergeArgument{
Expand All @@ -99,6 +112,7 @@ func TestMergeArgument_Valid(t *testing.T) {
TableID: MockTableIdentifier{"schema.tableName"},
UppercaseEscNames: ptr.ToBool(false),
DestKind: constants.BigQuery,
Dialect: sql.BigQueryDialect{},
},
},
}
Expand Down
8 changes: 8 additions & 0 deletions lib/sql/escape.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
)

func EscapeNameIfNecessary(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string {
// TODO: Switch all calls of [EscapeNameIfNecessary] to [EscapeNameIfNecessaryUsingDialect] and kill this.
var dialect = dialectFor(destKind, uppercaseEscNames)

if destKind != constants.S3 && dialect.NeedsEscaping(name) {
Expand All @@ -13,6 +14,13 @@ func EscapeNameIfNecessary(name string, uppercaseEscNames bool, destKind constan
return name
}

func EscapeNameIfNecessaryUsingDialect(name string, dialect Dialect) string {
if dialect.NeedsEscaping(name) {
return dialect.QuoteIdentifier(name)
}
return name
}

func dialectFor(destKind constants.DestinationKind, uppercaseEscNames bool) Dialect {
switch destKind {
case constants.BigQuery:
Expand Down

0 comments on commit 9026cbe

Please sign in to comment.