From bf2dfe5f8416715e4a0668d5118c27f9208c0b64 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Fri, 3 May 2024 15:23:39 -0700 Subject: [PATCH] [dml] Pass `[]columns.Column` in `MergeArgument` (#558) --- clients/shared/merge.go | 2 +- lib/destination/dml/merge.go | 40 ++++++++++------------ lib/destination/dml/merge_bigquery_test.go | 4 +-- lib/destination/dml/merge_mssql_test.go | 2 +- lib/destination/dml/merge_parts_test.go | 14 ++++---- lib/destination/dml/merge_test.go | 10 +++--- lib/destination/dml/merge_valid_test.go | 33 ++++++++++++------ 7 files changed, 57 insertions(+), 48 deletions(-) diff --git a/clients/shared/merge.go b/clients/shared/merge.go index 4c462193e..55062a16a 100644 --- a/clients/shared/merge.go +++ b/clients/shared/merge.go @@ -135,7 +135,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg SubQuery: subQuery, IdempotentKey: tableData.TopicConfig().IdempotentKey, PrimaryKeys: primaryKeys, - Columns: cols, + Columns: cols.ValidColumns(), SoftDelete: tableData.TopicConfig().SoftDelete, DestKind: dwh.Label(), Dialect: dwh.Dialect(), diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index aa5e58038..271ee33e1 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -24,7 +24,7 @@ type MergeArgument struct { AdditionalEqualityStrings []string // Columns will need to be escaped - Columns *columns.Columns + Columns []columns.Column DestKind constants.DestinationKind SoftDelete bool @@ -43,9 +43,14 @@ func (m *MergeArgument) Valid() error { return fmt.Errorf("merge argument does not contain primary keys") } - if len(m.Columns.ValidColumns()) == 0 { + if len(m.Columns) == 0 { return fmt.Errorf("columns cannot be empty") } + for _, column := range m.Columns { + if column.ShouldSkip() { + return fmt.Errorf("column %q is invalid and should be skipped", column.Name()) + } + } if m.TableID == nil { return fmt.Errorf("tableID cannot be nil") @@ -123,16 +128,14 @@ func (m *MergeArgument) GetParts() ([]string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - columns := m.Columns.ValidColumns() - if m.SoftDelete { return []string{ // INSERT - m.buildInsertQuery(columns, equalitySQLParts), + m.buildInsertQuery(m.Columns, equalitySQLParts), // UPDATE fmt.Sprintf(`UPDATE %s as c SET %s FROM %s as cc WHERE %s%s;`, // UPDATE table set col1 = cc. col1 - m.TableID.FullyQualifiedName(), buildColumnsUpdateFragment(columns, m.Dialect), + m.TableID.FullyQualifiedName(), buildColumnsUpdateFragment(m.Columns, m.Dialect), // FROM table (temp) WHERE join on PK(s) m.SubQuery, strings.Join(equalitySQLParts, " and "), idempotentClause, ), @@ -140,8 +143,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { } // We also need to remove __artie flags since it does not exist in the destination table - var removed bool - columns, removed = removeDeleteColumnMarker(columns) + columns, removed := removeDeleteColumnMarker(m.Columns) if !removed { return nil, errors.New("artie delete flag doesn't exist") } @@ -222,8 +224,6 @@ func (m *MergeArgument) GetStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, m.AdditionalEqualityStrings...) } - columns := m.Columns.ValidColumns() - if m.SoftDelete { return fmt.Sprintf(` MERGE INTO %s c USING %s AS cc ON %s @@ -231,19 +231,18 @@ WHEN MATCHED %sTHEN UPDATE SET %s WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);`, m.TableID.FullyQualifiedName(), subQuery, strings.Join(equalitySQLParts, " and "), // Update + Soft Deletion - idempotentClause, buildColumnsUpdateFragment(columns, m.Dialect), + idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect), // Insert - m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(quoteColumns(columns, m.Dialect), ","), + m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(quoteColumns(m.Columns, m.Dialect), ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: quoteColumns(columns, m.Dialect), + Vals: quoteColumns(m.Columns, m.Dialect), Separator: ",", Prefix: "cc.", })), nil } // We also need to remove __artie flags since it does not exist in the destination table - var removed bool - columns, removed = removeDeleteColumnMarker(columns) + columns, removed := removeDeleteColumnMarker(m.Columns) if !removed { return "", errors.New("artie delete flag doesn't exist") } @@ -285,8 +284,6 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - columns := m.Columns.ValidColumns() - if m.SoftDelete { return fmt.Sprintf(` MERGE INTO %s c @@ -295,19 +292,18 @@ WHEN MATCHED %sTHEN UPDATE SET %s WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`, m.TableID.FullyQualifiedName(), m.SubQuery, strings.Join(equalitySQLParts, " and "), // Update + Soft Deletion - idempotentClause, buildColumnsUpdateFragment(columns, m.Dialect), + idempotentClause, buildColumnsUpdateFragment(m.Columns, m.Dialect), // Insert - m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(quoteColumns(columns, m.Dialect), ","), + m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(quoteColumns(m.Columns, m.Dialect), ","), array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: quoteColumns(columns, m.Dialect), + Vals: quoteColumns(m.Columns, m.Dialect), Separator: ",", Prefix: "cc.", })), nil } // We also need to remove __artie flags since it does not exist in the destination table - var removed bool - columns, removed = removeDeleteColumnMarker(columns) + columns, removed := removeDeleteColumnMarker(m.Columns) if !removed { return "", errors.New("artie delete flag doesn't exist") } diff --git a/lib/destination/dml/merge_bigquery_test.go b/lib/destination/dml/merge_bigquery_test.go index 7d031e6f3..1abeac2c0 100644 --- a/lib/destination/dml/merge_bigquery_test.go +++ b/lib/destination/dml/merge_bigquery_test.go @@ -20,7 +20,7 @@ func TestMergeStatement_TempTable(t *testing.T) { TableID: MockTableIdentifier{"customers.orders"}, SubQuery: "customers.orders_tmp", PrimaryKeys: []columns.Column{columns.NewColumn("order_id", typing.Invalid)}, - Columns: &cols, + Columns: cols.ValidColumns(), DestKind: constants.BigQuery, Dialect: sql.BigQueryDialect{}, SoftDelete: false, @@ -43,7 +43,7 @@ func TestMergeStatement_JSONKey(t *testing.T) { TableID: MockTableIdentifier{"customers.orders"}, SubQuery: "customers.orders_tmp", PrimaryKeys: []columns.Column{orderOIDCol}, - Columns: &cols, + Columns: cols.ValidColumns(), DestKind: constants.BigQuery, Dialect: sql.BigQueryDialect{}, SoftDelete: false, diff --git a/lib/destination/dml/merge_mssql_test.go b/lib/destination/dml/merge_mssql_test.go index 15613eb7f..67d5137e4 100644 --- a/lib/destination/dml/merge_mssql_test.go +++ b/lib/destination/dml/merge_mssql_test.go @@ -45,7 +45,7 @@ func Test_GetMSSQLStatement(t *testing.T) { SubQuery: subQuery, IdempotentKey: "", PrimaryKeys: []columns.Column{columns.NewColumn("id", typing.Invalid)}, - Columns: &_cols, + Columns: _cols.ValidColumns(), DestKind: constants.MSSQL, Dialect: sql.MSSQLDialect{}, SoftDelete: false, diff --git a/lib/destination/dml/merge_parts_test.go b/lib/destination/dml/merge_parts_test.go index de08ebbc3..22a35bbfd 100644 --- a/lib/destination/dml/merge_parts_test.go +++ b/lib/destination/dml/merge_parts_test.go @@ -71,7 +71,7 @@ func TestMergeStatementParts_SkipDelete(t *testing.T) { TableID: MockTableIdentifier{fqTableName}, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, - Columns: &res.ColumnsToTypes, + Columns: res.ColumnsToTypes.ValidColumns(), DestKind: constants.Redshift, Dialect: sql.RedshiftDialect{}, ContainsHardDeletes: ptr.ToBool(false), @@ -98,7 +98,7 @@ func TestMergeStatementPartsSoftDelete(t *testing.T) { TableID: MockTableIdentifier{fqTableName}, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, - Columns: &res.ColumnsToTypes, + Columns: res.ColumnsToTypes.ValidColumns(), DestKind: constants.Redshift, Dialect: sql.RedshiftDialect{}, SoftDelete: true, @@ -138,7 +138,7 @@ func TestMergeStatementPartsSoftDeleteComposite(t *testing.T) { TableID: MockTableIdentifier{fqTableName}, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, - Columns: &res.ColumnsToTypes, + Columns: res.ColumnsToTypes.ValidColumns(), DestKind: constants.Redshift, Dialect: sql.RedshiftDialect{}, SoftDelete: true, @@ -181,7 +181,7 @@ func TestMergeStatementParts(t *testing.T) { TableID: MockTableIdentifier{fqTableName}, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, - Columns: &res.ColumnsToTypes, + Columns: res.ColumnsToTypes.ValidColumns(), DestKind: constants.Redshift, Dialect: sql.RedshiftDialect{}, ContainsHardDeletes: ptr.ToBool(true), @@ -207,7 +207,7 @@ func TestMergeStatementParts(t *testing.T) { TableID: MockTableIdentifier{fqTableName}, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, - Columns: &res.ColumnsToTypes, + Columns: res.ColumnsToTypes.ValidColumns(), DestKind: constants.Redshift, Dialect: sql.RedshiftDialect{}, IdempotentKey: "created_at", @@ -239,7 +239,7 @@ func TestMergeStatementPartsCompositeKey(t *testing.T) { TableID: MockTableIdentifier{fqTableName}, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, - Columns: &res.ColumnsToTypes, + Columns: res.ColumnsToTypes.ValidColumns(), DestKind: constants.Redshift, Dialect: sql.RedshiftDialect{}, ContainsHardDeletes: ptr.ToBool(true), @@ -265,7 +265,7 @@ func TestMergeStatementPartsCompositeKey(t *testing.T) { TableID: MockTableIdentifier{fqTableName}, SubQuery: tempTableName, PrimaryKeys: res.PrimaryKeys, - Columns: &res.ColumnsToTypes, + Columns: res.ColumnsToTypes.ValidColumns(), DestKind: constants.Redshift, Dialect: sql.RedshiftDialect{}, ContainsHardDeletes: ptr.ToBool(true), diff --git a/lib/destination/dml/merge_test.go b/lib/destination/dml/merge_test.go index 5a76abbfe..3836e7282 100644 --- a/lib/destination/dml/merge_test.go +++ b/lib/destination/dml/merge_test.go @@ -100,7 +100,7 @@ func TestMergeStatementSoftDelete(t *testing.T) { SubQuery: subQuery, IdempotentKey: idempotentKey, PrimaryKeys: []columns.Column{columns.NewColumn("id", typing.Invalid)}, - Columns: &_cols, + Columns: _cols.ValidColumns(), DestKind: constants.Snowflake, Dialect: sql.SnowflakeDialect{}, SoftDelete: true, @@ -149,7 +149,7 @@ func TestMergeStatement(t *testing.T) { SubQuery: subQuery, IdempotentKey: "", PrimaryKeys: []columns.Column{columns.NewColumn("id", typing.Invalid)}, - Columns: &_cols, + Columns: _cols.ValidColumns(), DestKind: constants.Snowflake, Dialect: sql.SnowflakeDialect{}, SoftDelete: false, @@ -197,7 +197,7 @@ func TestMergeStatementIdempotentKey(t *testing.T) { SubQuery: subQuery, IdempotentKey: "updated_at", PrimaryKeys: []columns.Column{columns.NewColumn("id", typing.Invalid)}, - Columns: &_cols, + Columns: _cols.ValidColumns(), DestKind: constants.Snowflake, Dialect: sql.SnowflakeDialect{}, SoftDelete: false, @@ -242,7 +242,7 @@ func TestMergeStatementCompositeKey(t *testing.T) { columns.NewColumn("id", typing.Invalid), columns.NewColumn("another_id", typing.Invalid), }, - Columns: &_cols, + Columns: _cols.ValidColumns(), DestKind: constants.Snowflake, Dialect: sql.SnowflakeDialect{}, SoftDelete: false, @@ -291,7 +291,7 @@ func TestMergeStatementEscapePrimaryKeys(t *testing.T) { columns.NewColumn("id", typing.Invalid), columns.NewColumn("group", typing.Invalid), }, - Columns: &_cols, + Columns: _cols.ValidColumns(), DestKind: constants.Snowflake, Dialect: sql.SnowflakeDialect{}, SoftDelete: false, diff --git a/lib/destination/dml/merge_valid_test.go b/lib/destination/dml/merge_valid_test.go index fc7114762..358f2fc45 100644 --- a/lib/destination/dml/merge_valid_test.go +++ b/lib/destination/dml/merge_valid_test.go @@ -17,10 +17,11 @@ func TestMergeArgument_Valid(t *testing.T) { columns.NewColumn("id", typing.Integer), } - var cols columns.Columns - cols.AddColumn(columns.NewColumn("id", typing.Integer)) - cols.AddColumn(columns.NewColumn("firstName", typing.String)) - cols.AddColumn(columns.NewColumn("lastName", typing.String)) + cols := []columns.Column{ + columns.NewColumn("id", typing.Integer), + columns.NewColumn("firstName", typing.String), + columns.NewColumn("lastName", typing.String), + } testCases := []struct { name string @@ -47,7 +48,7 @@ func TestMergeArgument_Valid(t *testing.T) { name: "pks, cols, colsTpTypes exists but no subquery or tableID", mergeArg: &MergeArgument{ PrimaryKeys: primaryKeys, - Columns: &cols, + Columns: cols, }, expectedErr: "tableID cannot be nil", }, @@ -55,7 +56,7 @@ func TestMergeArgument_Valid(t *testing.T) { name: "pks, cols, colsTpTypes, subquery exists but no tableID", mergeArg: &MergeArgument{ PrimaryKeys: primaryKeys, - Columns: &cols, + Columns: cols, SubQuery: "schema.tableName", }, expectedErr: "tableID cannot be nil", @@ -64,7 +65,7 @@ func TestMergeArgument_Valid(t *testing.T) { name: "pks, cols, colsTpTypes, tableID exists but no subquery", mergeArg: &MergeArgument{ PrimaryKeys: primaryKeys, - Columns: &cols, + Columns: cols, TableID: MockTableIdentifier{"schema.tableName"}, }, expectedErr: "subQuery cannot be empty", @@ -73,7 +74,7 @@ func TestMergeArgument_Valid(t *testing.T) { name: "missing dest kind", mergeArg: &MergeArgument{ PrimaryKeys: primaryKeys, - Columns: &cols, + Columns: cols, SubQuery: "schema.tableName", TableID: MockTableIdentifier{"schema.tableName"}, }, @@ -83,7 +84,7 @@ func TestMergeArgument_Valid(t *testing.T) { name: "missing dialect kind", mergeArg: &MergeArgument{ PrimaryKeys: primaryKeys, - Columns: &cols, + Columns: cols, SubQuery: "schema.tableName", TableID: MockTableIdentifier{"schema.tableName"}, DestKind: constants.BigQuery, @@ -94,12 +95,24 @@ func TestMergeArgument_Valid(t *testing.T) { name: "everything exists", mergeArg: &MergeArgument{ PrimaryKeys: primaryKeys, - Columns: &cols, + Columns: cols, + SubQuery: "schema.tableName", + TableID: MockTableIdentifier{"schema.tableName"}, + DestKind: constants.BigQuery, + Dialect: sql.BigQueryDialect{}, + }, + }, + { + name: "invalid column", + mergeArg: &MergeArgument{ + PrimaryKeys: primaryKeys, + Columns: []columns.Column{columns.NewColumn("id", typing.Invalid)}, SubQuery: "schema.tableName", TableID: MockTableIdentifier{"schema.tableName"}, DestKind: constants.BigQuery, Dialect: sql.BigQueryDialect{}, }, + expectedErr: `column "id" is invalid and should be skipped`, }, }