Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Nov 15, 2024
1 parent ac064ce commit 1dde15c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 16 deletions.
4 changes: 2 additions & 2 deletions clients/shared/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func AlterTableAddColumns(ctx context.Context, dwh destination.DataWarehouse, tc
return nil
}

sqlParts := ddl.BuildAlterTableAddColumns(dwh.Dialect(), tableID, columns)
sqlParts, addedCols := ddl.BuildAlterTableAddColumns(dwh.Dialect(), tableID, columns)
for _, sqlPart := range sqlParts {
slog.Info("[DDL] Executing query", slog.String("query", sqlPart))
if _, err := dwh.ExecContext(ctx, sqlPart); err != nil {
Expand All @@ -45,6 +45,6 @@ func AlterTableAddColumns(ctx context.Context, dwh destination.DataWarehouse, tc
}
}

tc.MutateInMemoryColumns(constants.Add, columns...)
tc.MutateInMemoryColumns(constants.Add, addedCols...)
return nil
}
8 changes: 5 additions & 3 deletions lib/destination/ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,20 @@ func DropTemporaryTable(dwh destination.DataWarehouse, tableIdentifier sql.Table
return nil
}

func BuildAlterTableAddColumns(dialect sql.Dialect, tableID sql.TableIdentifier, columns []columns.Column) []string {
func BuildAlterTableAddColumns(dialect sql.Dialect, tableID sql.TableIdentifier, cols []columns.Column) ([]string, []columns.Column) {
var parts []string
for _, col := range columns {
var addedCols []columns.Column
for _, col := range cols {
if col.ShouldSkip() {
continue
}

sqlPart := fmt.Sprintf("%s %s", dialect.QuoteIdentifier(col.Name()), dialect.DataTypeForKind(col.KindDetails, col.PrimaryKey()))
parts = append(parts, dialect.BuildAlterColumnQuery(tableID, constants.Add, sqlPart))
addedCols = append(addedCols, col)
}

return parts
return parts, addedCols
}

type AlterTableArgs struct {
Expand Down
33 changes: 22 additions & 11 deletions lib/destination/ddl/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,40 +79,51 @@ func TestBuildCreateTableSQL(t *testing.T) {
func TestBuildAlterTableAddColumns(t *testing.T) {
{
// No columns
sqlParts := BuildAlterTableAddColumns(nil, nil, []columns.Column{})
sqlParts, addedCols := BuildAlterTableAddColumns(nil, nil, []columns.Column{})
assert.Empty(t, sqlParts)
assert.Empty(t, addedCols)
}
{
// One column to add
sqlParts := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{columns.NewColumn("dusty", typing.String)})
col := columns.NewColumn("dusty", typing.String)
sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col})
assert.Len(t, sqlParts, 1)
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "dusty" VARCHAR(MAX)`, sqlParts[0])

assert.Len(t, addedCols, 1)
assert.Equal(t, col, addedCols[0])
}
{
// Two columns, it skips the invalid column
sqlParts := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"),
col := columns.NewColumn("dusty", typing.String)
sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"),
[]columns.Column{
columns.NewColumn("dusty", typing.String),
col,
columns.NewColumn("invalid", typing.Invalid),
},
)
assert.Len(t, sqlParts, 1)
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "dusty" VARCHAR(MAX)`, sqlParts[0])

assert.Len(t, addedCols, 1)
assert.Equal(t, col, addedCols[0])
}
{
// Three columns to add
sqlParts := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"),
[]columns.Column{
columns.NewColumn("aussie", typing.String),
columns.NewColumn("doge", typing.String),
columns.NewColumn("age", typing.Integer),
},
)
col1 := columns.NewColumn("aussie", typing.String)
col2 := columns.NewColumn("doge", typing.String)
col3 := columns.NewColumn("age", typing.Integer)

sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col1, col2, col3})
assert.Len(t, sqlParts, 3)
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "aussie" VARCHAR(MAX)`, sqlParts[0])
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "doge" VARCHAR(MAX)`, sqlParts[1])
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "age" INT8`, sqlParts[2])

assert.Len(t, addedCols, 3)
assert.Equal(t, col1, addedCols[0])
assert.Equal(t, col2, addedCols[1])
assert.Equal(t, col3, addedCols[2])
}
}

Expand Down

0 comments on commit 1dde15c

Please sign in to comment.