Skip to content

Commit

Permalink
Merge branch 'master' into snowflake-ddl
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Nov 20, 2024
2 parents 63fb096 + 01338f7 commit 7536bf8
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 21 deletions.
2 changes: 1 addition & 1 deletion clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (s *Store) Append(ctx context.Context, tableData *optimization.TableData, u

func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, dwh *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, opts types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
if err := shared.CreateTable(ctx, s, tableData, dwh, opts.ColumnSettings, tempTableID, true); err != nil {
if err := shared.CreateTempTable(ctx, s, tableData, dwh, opts.ColumnSettings, tempTableID); err != nil {
return err
}
}
Expand Down
13 changes: 4 additions & 9 deletions clients/databricks/dialect/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
)

Expand All @@ -13,18 +12,14 @@ func (DatabricksDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, _ bo
return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ", "))
}

func (d DatabricksDialect) BuildAddColumnQuery(tableID sql.TableIdentifier, sqlPart string) string {
return d.buildAlterColumnQuery(tableID, constants.Add, sqlPart)
func (DatabricksDialect) BuildAddColumnQuery(tableID sql.TableIdentifier, sqlPart string) string {
return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s", tableID.FullyQualifiedName(), sqlPart)
}

func (d DatabricksDialect) BuildDropColumnQuery(tableID sql.TableIdentifier, colName string) string {
return d.buildAlterColumnQuery(tableID, constants.Delete, colName)
func (DatabricksDialect) BuildDropColumnQuery(tableID sql.TableIdentifier, colName string) string {
return fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", tableID.FullyQualifiedName(), colName)
}

func (DatabricksDialect) BuildDescribeTableQuery(tableID sql.TableIdentifier) (string, []any, error) {
return fmt.Sprintf("DESCRIBE TABLE %s", tableID.FullyQualifiedName()), nil, nil
}

func (DatabricksDialect) buildAlterColumnQuery(tableID sql.TableIdentifier, columnOp constants.ColumnOperation, colSQLPart string) string {
return fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", tableID.FullyQualifiedName(), columnOp, colSQLPart)
}
4 changes: 2 additions & 2 deletions clients/databricks/dialect/dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ func TestDatabricksDialect_BuildCreateTableQuery(t *testing.T) {
func TestDatabricksDialect_BuildAddColumnQuery(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns("{TABLE}")
assert.Equal(t, "ALTER TABLE {TABLE} add COLUMN {SQL_PART} {DATA_TYPE}", DatabricksDialect{}.BuildAddColumnQuery(fakeTableID, "{SQL_PART} {DATA_TYPE}"))
assert.Equal(t, "ALTER TABLE {TABLE} ADD COLUMN {SQL_PART} {DATA_TYPE}", DatabricksDialect{}.BuildAddColumnQuery(fakeTableID, "{SQL_PART} {DATA_TYPE}"))
}

func TestDatabricksDialect_BuildDropColumnQuery(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns("{TABLE}")
assert.Equal(t, "ALTER TABLE {TABLE} drop COLUMN {SQL_PART}", DatabricksDialect{}.BuildDropColumnQuery(fakeTableID, "{SQL_PART}"))
assert.Equal(t, "ALTER TABLE {TABLE} DROP COLUMN {SQL_PART}", DatabricksDialect{}.BuildDropColumnQuery(fakeTableID, "{SQL_PART}"))
}

func TestDatabricksDialect_BuildDedupeQueries(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion clients/databricks/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (s Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTabl

func (s Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, dwh *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, opts types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
if err := shared.CreateTable(ctx, s, tableData, dwh, opts.ColumnSettings, tempTableID, true); err != nil {
if err := shared.CreateTempTable(ctx, s, tableData, dwh, opts.ColumnSettings, tempTableID); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion clients/mssql/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, dwh *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, opts types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
if err := shared.CreateTable(ctx, s, tableData, dwh, opts.ColumnSettings, tempTableID, true); err != nil {
if err := shared.CreateTempTable(ctx, s, tableData, dwh, opts.ColumnSettings, tempTableID); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion clients/redshift/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimizati
}

if createTempTable {
if err = shared.CreateTable(ctx, s, tableData, tableConfig, opts.ColumnSettings, tempTableID, true); err != nil {
if err = shared.CreateTempTable(ctx, s, tableData, tableConfig, opts.ColumnSettings, tempTableID); err != nil {
return err
}
}
Expand Down
2 changes: 1 addition & 1 deletion clients/shared/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func Append(ctx context.Context, dwh destination.DataWarehouse, tableData *optim

tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name())
if tableConfig.CreateTable() {
if err = CreateTable(ctx, dwh, tableData, tableConfig, opts.ColumnSettings, tableID, false); err != nil {
if err = CreateTable(ctx, dwh, tableData, tableConfig, opts.ColumnSettings, tableID, false, targetKeysMissing); err != nil {
return fmt.Errorf("failed to create table: %w", err)
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion clients/shared/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func Merge(ctx context.Context, dwh destination.DataWarehouse, tableData *optimi

tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name())
if tableConfig.CreateTable() {
if err = CreateTable(ctx, dwh, tableData, tableConfig, opts.ColumnSettings, tableID, false); err != nil {
if err = CreateTable(ctx, dwh, tableData, tableConfig, opts.ColumnSettings, tableID, false, targetKeysMissing); err != nil {
return fmt.Errorf("failed to create table: %w", err)
}
} else {
Expand Down
10 changes: 7 additions & 3 deletions clients/shared/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@ import (
"github.com/artie-labs/transfer/lib/typing/columns"
)

func CreateTable(ctx context.Context, dwh destination.DataWarehouse, tableData *optimization.TableData, tc *types.DwhTableConfig, settings config.SharedDestinationColumnSettings, tableID sql.TableIdentifier, tempTable bool) error {
query, err := ddl.BuildCreateTableSQL(settings, dwh.Dialect(), tableID, tempTable, tableData.Mode(), tableData.ReadOnlyInMemoryCols().GetColumns())
func CreateTempTable(ctx context.Context, dwh destination.DataWarehouse, tableData *optimization.TableData, tc *types.DwhTableConfig, settings config.SharedDestinationColumnSettings, tableID sql.TableIdentifier) error {
return CreateTable(ctx, dwh, tableData, tc, settings, tableID, true, tableData.ReadOnlyInMemoryCols().GetColumns())
}

func CreateTable(ctx context.Context, dwh destination.DataWarehouse, tableData *optimization.TableData, tc *types.DwhTableConfig, settings config.SharedDestinationColumnSettings, tableID sql.TableIdentifier, tempTable bool, cols []columns.Column) error {
query, err := ddl.BuildCreateTableSQL(settings, dwh.Dialect(), tableID, tempTable, tableData.Mode(), cols)
if err != nil {
return fmt.Errorf("failed to build create table sql: %w", err)
}
Expand All @@ -28,7 +32,7 @@ func CreateTable(ctx context.Context, dwh destination.DataWarehouse, tableData *
}

// Update cache with the new columns that we've added.
tc.MutateInMemoryColumns(constants.Add, tableData.ReadOnlyInMemoryCols().GetColumns()...)
tc.MutateInMemoryColumns(constants.Add, cols...)
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion clients/snowflake/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func castColValStaging(colVal any, colKind typing.KindDetails) (string, error) {

func (s *Store) PrepareTemporaryTable(ctx context.Context, tableData *optimization.TableData, dwh *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
if err := shared.CreateTable(ctx, s, tableData, dwh, additionalSettings.ColumnSettings, tempTableID, true); err != nil {
if err := shared.CreateTempTable(ctx, s, tableData, dwh, additionalSettings.ColumnSettings, tempTableID); err != nil {
return err
}
}
Expand Down

0 comments on commit 7536bf8

Please sign in to comment.