Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Nov 15, 2024
1 parent 1dde15c commit 01d70db
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 28 deletions.
23 changes: 18 additions & 5 deletions clients/shared/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,34 @@ func CreateTable(ctx context.Context, dwh destination.DataWarehouse, tableData *
return nil
}

func AlterTableAddColumns(ctx context.Context, dwh destination.DataWarehouse, tc *types.DwhTableConfig, tableID sql.TableIdentifier, columns []columns.Column) error {
if len(columns) == 0 {
func AlterTableAddColumns(ctx context.Context, dwh destination.DataWarehouse, tc *types.DwhTableConfig, tableID sql.TableIdentifier, cols []columns.Column) error {
if len(cols) == 0 {
return nil
}

sqlParts, addedCols := ddl.BuildAlterTableAddColumns(dwh.Dialect(), tableID, columns)
var colsToAdd []columns.Column
for _, col := range cols {
if col.ShouldSkip() {
continue
}

colsToAdd = append(colsToAdd, col)
}

sqlParts, err := ddl.BuildAlterTableAddColumns(dwh.Dialect(), tableID, colsToAdd)
if err != nil {
return fmt.Errorf("failed to build alter table add columns: %w", err)
}

for _, sqlPart := range sqlParts {
slog.Info("[DDL] Executing query", slog.String("query", sqlPart))
if _, err := dwh.ExecContext(ctx, sqlPart); err != nil {
if _, err = dwh.ExecContext(ctx, sqlPart); err != nil {
if !dwh.Dialect().IsColumnAlreadyExistsErr(err) {
return fmt.Errorf("failed to alter table: %w", err)
}
}
}

tc.MutateInMemoryColumns(constants.Add, addedCols...)
tc.MutateInMemoryColumns(constants.Add, colsToAdd...)
return nil
}
8 changes: 3 additions & 5 deletions lib/destination/ddl/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,18 @@ func DropTemporaryTable(dwh destination.DataWarehouse, tableIdentifier sql.Table
return nil
}

func BuildAlterTableAddColumns(dialect sql.Dialect, tableID sql.TableIdentifier, cols []columns.Column) ([]string, []columns.Column) {
func BuildAlterTableAddColumns(dialect sql.Dialect, tableID sql.TableIdentifier, cols []columns.Column) ([]string, error) {
var parts []string
var addedCols []columns.Column
for _, col := range cols {
if col.ShouldSkip() {
continue
return nil, fmt.Errorf("received an invalid column %q", col.Name())
}

sqlPart := fmt.Sprintf("%s %s", dialect.QuoteIdentifier(col.Name()), dialect.DataTypeForKind(col.KindDetails, col.PrimaryKey()))
parts = append(parts, dialect.BuildAlterColumnQuery(tableID, constants.Add, sqlPart))
addedCols = append(addedCols, col)
}

return parts, addedCols
return parts, nil
}

type AlterTableArgs struct {
Expand Down
27 changes: 9 additions & 18 deletions lib/destination/ddl/ddl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,51 +79,42 @@ func TestBuildCreateTableSQL(t *testing.T) {
func TestBuildAlterTableAddColumns(t *testing.T) {
{
// No columns
sqlParts, addedCols := BuildAlterTableAddColumns(nil, nil, []columns.Column{})
sqlParts, err := BuildAlterTableAddColumns(nil, nil, []columns.Column{})
assert.NoError(t, err)
assert.Empty(t, sqlParts)
assert.Empty(t, addedCols)
}
{
// One column to add
col := columns.NewColumn("dusty", typing.String)
sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col})
sqlParts, err := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col})
assert.NoError(t, err)
assert.Len(t, sqlParts, 1)
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "dusty" VARCHAR(MAX)`, sqlParts[0])

assert.Len(t, addedCols, 1)
assert.Equal(t, col, addedCols[0])
}
{
// Two columns, it skips the invalid column
// Two columns, one invalid, it will error.
col := columns.NewColumn("dusty", typing.String)
sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"),
_, err := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"),
[]columns.Column{
col,
columns.NewColumn("invalid", typing.Invalid),
},
)
assert.Len(t, sqlParts, 1)
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "dusty" VARCHAR(MAX)`, sqlParts[0])

assert.Len(t, addedCols, 1)
assert.Equal(t, col, addedCols[0])
assert.ErrorContains(t, err, `received an invalid column "invalid"`)
}
{
// Three columns to add
col1 := columns.NewColumn("aussie", typing.String)
col2 := columns.NewColumn("doge", typing.String)
col3 := columns.NewColumn("age", typing.Integer)

sqlParts, addedCols := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col1, col2, col3})
sqlParts, err := BuildAlterTableAddColumns(dialect.RedshiftDialect{}, dialect.NewTableIdentifier("schema", "table"), []columns.Column{col1, col2, col3})
assert.NoError(t, err)
assert.Len(t, sqlParts, 3)
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "aussie" VARCHAR(MAX)`, sqlParts[0])
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "doge" VARCHAR(MAX)`, sqlParts[1])
assert.Equal(t, `ALTER TABLE schema."table" add COLUMN "age" INT8`, sqlParts[2])

assert.Len(t, addedCols, 3)
assert.Equal(t, col1, addedCols[0])
assert.Equal(t, col2, addedCols[1])
assert.Equal(t, col3, addedCols[2])
}
}

Expand Down

0 comments on commit 01d70db

Please sign in to comment.