diff --git a/clients/shared/table.go b/clients/shared/table.go index 704caf802..a3bde8c41 100644 --- a/clients/shared/table.go +++ b/clients/shared/table.go @@ -35,7 +35,7 @@ func AlterTableAddColumns(ctx context.Context, dwh destination.DataWarehouse, tc return nil } - sqlParts := ddl.BuildAlterTableAddColumns(dwh.Dialect(), tableID, columns) + sqlParts, addedCols := ddl.BuildAlterTableAddColumns(dwh.Dialect(), tableID, columns) for _, sqlPart := range sqlParts { slog.Info("[DDL] Executing query", slog.String("query", sqlPart)) if _, err := dwh.ExecContext(ctx, sqlPart); err != nil { @@ -45,6 +45,6 @@ func AlterTableAddColumns(ctx context.Context, dwh destination.DataWarehouse, tc } } - tc.MutateInMemoryColumns(constants.Add, columns...) + tc.MutateInMemoryColumns(constants.Add, addedCols...) return nil } diff --git a/lib/destination/ddl/ddl.go b/lib/destination/ddl/ddl.go index 72e7ca6d8..caf20a477 100644 --- a/lib/destination/ddl/ddl.go +++ b/lib/destination/ddl/ddl.go @@ -69,18 +69,20 @@ func DropTemporaryTable(dwh destination.DataWarehouse, tableIdentifier sql.Table return nil } -func BuildAlterTableAddColumns(dialect sql.Dialect, tableID sql.TableIdentifier, columns []columns.Column) []string { +func BuildAlterTableAddColumns(dialect sql.Dialect, tableID sql.TableIdentifier, cols []columns.Column) ([]string, []columns.Column) { var parts []string - for _, col := range columns { + var addedCols []columns.Column + for _, col := range cols { if col.ShouldSkip() { continue } sqlPart := fmt.Sprintf("%s %s", dialect.QuoteIdentifier(col.Name()), dialect.DataTypeForKind(col.KindDetails, col.PrimaryKey())) parts = append(parts, dialect.BuildAlterColumnQuery(tableID, constants.Add, sqlPart)) + addedCols = append(addedCols, col) } - return parts + return parts, addedCols } type AlterTableArgs struct { diff --git a/lib/destination/ddl/ddl_test.go b/lib/destination/ddl/ddl_test.go index 4d7e74a92..55e16cf67 100644 --- a/lib/destination/ddl/ddl_test.go +++ b/lib/destination/ddl/ddl_test.go @@ -79,40 +79,51 @@ func TestBuildCreateTableSQL(t *testing.T) { func TestBuildAlterTableAddColumns(t *testing.T) { { // No columns - sqlParts := BuildAlterTableAddColumns(nil, nil, []columns.Column{}) + sqlParts, addedCols := BuildAlterTableAddColumns(nil, nil, []columns.Column{}) assert.Empty(t, sqlParts) + assert.Empty(t, addedCols) } { // One column to add - sqlParts := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{columns.NewColumn("dusty", typing.String)}) + col := columns.NewColumn("dusty", typing.String) + sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col}) assert.Len(t, sqlParts, 1) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "dusty" VARCHAR(MAX)`, sqlParts[0]) + + assert.Len(t, addedCols, 1) + assert.Equal(t, col, addedCols[0]) } { // Two columns, it skips the invalid column - sqlParts := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), + col := columns.NewColumn("dusty", typing.String) + sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{ - columns.NewColumn("dusty", typing.String), + col, columns.NewColumn("invalid", typing.Invalid), }, ) assert.Len(t, sqlParts, 1) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "dusty" VARCHAR(MAX)`, sqlParts[0]) + + assert.Len(t, addedCols, 1) + assert.Equal(t, col, addedCols[0]) } { // Three columns to add - sqlParts := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), - []columns.Column{ - columns.NewColumn("aussie", typing.String), - columns.NewColumn("doge", typing.String), - columns.NewColumn("age", typing.Integer), - }, - ) + col1 := columns.NewColumn("aussie", typing.String) + col2 := columns.NewColumn("doge", typing.String) + col3 := columns.NewColumn("age", typing.Integer) + sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col1, col2, col3}) assert.Len(t, sqlParts, 3) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "aussie" VARCHAR(MAX)`, sqlParts[0]) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "doge" VARCHAR(MAX)`, sqlParts[1]) assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "age" INT8`, sqlParts[2]) + + assert.Len(t, addedCols, 3) + assert.Equal(t, col1, addedCols[0]) + assert.Equal(t, col2, addedCols[1]) + assert.Equal(t, col3, addedCols[2]) } }