Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Oct 3, 2024
1 parent 632a0c1 commit f8d7862
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
29 changes: 13 additions & 16 deletions clients/databricks/dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,26 @@ func (d DatabricksDialect) BuildDedupeQueries(tableID, stagingTableID sql.TableI
orderByCols = append(orderByCols, fmt.Sprintf("%s ASC", pk))
}

var parts []string
parts = append(parts, fmt.Sprintf("CREATE OR REPLACE TEMP VIEW %s AS SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) as row_num FROM %s) WHERE row_num = 2",
stagingTableID.FullyQualifiedName(),
strings.Join(primaryKeysEscaped, ", "),
strings.Join(orderByCols, ", "),
tableID.FullyQualifiedName(),
))
tempViewQuery := fmt.Sprintf(`
CREATE TABLE %s AS
SELECT *
FROM %s
QUALIFY ROW_NUMBER() OVER (PARTITION BY %s ORDER BY %s) = 2
`, stagingTableID.FullyQualifiedName(), tableID.FullyQualifiedName(), strings.Join(primaryKeysEscaped, ", "), strings.Join(orderByCols, ", "))

var whereClauses []string
for _, primaryKeyEscaped := range primaryKeysEscaped {
whereClauses = append(whereClauses, fmt.Sprintf("t1.%s = t2.%s", primaryKeyEscaped, primaryKeyEscaped))
}

parts = append(parts,
fmt.Sprintf("DELETE FROM %s t1 USING %s t2 WHERE %s",
tableID.FullyQualifiedName(),
stagingTableID.FullyQualifiedName(),
strings.Join(whereClauses, " AND "),
),
deleteQuery := fmt.Sprintf("DELETE FROM %s t1 WHERE EXISTS (SELECT * FROM %s t2 WHERE %s)",
tableID.FullyQualifiedName(),
stagingTableID.FullyQualifiedName(),
strings.Join(whereClauses, " AND "),
)

parts = append(parts, fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), stagingTableID.FullyQualifiedName()))
return parts
// Insert deduplicated rows back into the original table
insertQuery := fmt.Sprintf("INSERT INTO %s SELECT * FROM %s", tableID.FullyQualifiedName(), stagingTableID.FullyQualifiedName())
return []string{tempViewQuery, deleteQuery, insertQuery}
}

func (d DatabricksDialect) BuildMergeQueries(
Expand Down
18 changes: 14 additions & 4 deletions clients/databricks/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"os"
"path/filepath"

"github.com/artie-labs/transfer/lib/destination"

_ "github.com/databricks/databricks-sql-go"
"github.com/databricks/databricks-sql-go/driverctx"

Expand Down Expand Up @@ -56,8 +54,19 @@ func (s Store) Dialect() sql.Dialect {

func (s Store) Dedupe(tableID sql.TableIdentifier, primaryKeys []string, includeArtieUpdatedAt bool) error {
stagingTableID := shared.TempTableID(tableID)
dedupeQueries := s.Dialect().BuildDedupeQueries(tableID, stagingTableID, primaryKeys, includeArtieUpdatedAt)
return destination.ExecStatements(s, dedupeQueries)
defer func() {
// Drop the temporary table once we're done with the dedupe.
_ = ddl.DropTemporaryTable(s, stagingTableID, false)
}()

for _, query := range s.Dialect().BuildDedupeQueries(tableID, stagingTableID, primaryKeys, includeArtieUpdatedAt) {
// Databricks doesn't support transactions, so we can't wrap this in a transaction.
if _, err := s.Exec(query); err != nil {
return fmt.Errorf("failed to execute query: %w", err)
}
}

return nil
}

func (s Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTableConfig, error) {
Expand Down Expand Up @@ -180,6 +189,7 @@ func LoadStore(cfg config.Config) (Store, error) {
if err != nil {
return Store{}, err
}

return Store{
Store: store,
cfg: cfg,
Expand Down

0 comments on commit f8d7862

Please sign in to comment.