Skip to content

Commit

Permalink
WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Sep 30, 2024
1 parent a584128 commit 2f4b09c
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 15 deletions.
8 changes: 3 additions & 5 deletions clients/databricks/dialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ func (d DatabricksDialect) KindForDataType(_type string, _ string) (typing.KindD
return typing.Boolean, nil
case "VARIANT":
return typing.Struct, nil
case "TIMESTAMP":
return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil
}

return typing.KindDetails{}, fmt.Errorf("unsupported data type: %q", _type)
Expand All @@ -81,11 +83,7 @@ func (DatabricksDialect) IsTableDoesNotExistErr(err error) bool {
}

// BuildCreateTableQuery generates the DDL used to create [tableID] with the
// provided column definitions. IF NOT EXISTS makes the statement idempotent.
//
// NOTE(review): the `temporary` flag is ignored — temp tables are created as
// regular tables. Confirm this is intentional and that "temporary" tables are
// cleaned up elsewhere, since nothing here scopes them to the session.
func (d DatabricksDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, temporary bool, colSQLParts []string) string {
	return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ", "))
}

func (d DatabricksDialect) BuildAlterColumnQuery(tableID sql.TableIdentifier, columnOp constants.ColumnOperation, colSQLPart string) string {
Expand Down
123 changes: 117 additions & 6 deletions clients/databricks/store.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package databricks

import (
"context"
"encoding/csv"
"fmt"
"log/slog"
"os"
"path/filepath"

_ "github.com/databricks/databricks-sql-go"
"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/destination/ddl"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/values"

"github.com/artie-labs/transfer/clients/databricks/dialect"
"github.com/artie-labs/transfer/clients/shared"
Expand All @@ -13,6 +21,8 @@ import (
"github.com/artie-labs/transfer/lib/kafkalib"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/sql"
_ "github.com/databricks/databricks-sql-go"
driverctx "github.com/databricks/databricks-sql-go/driverctx"
)

type Store struct {
Expand All @@ -22,7 +32,12 @@ type Store struct {
}

// describeTableQuery returns the DESCRIBE TABLE statement (and a nil args
// slice) for [tableID]. Each identifier part is quoted through the Databricks
// dialect so names with special characters or reserved words survive intact.
func describeTableQuery(tableID TableIdentifier) (string, []any) {
	// Named `d` rather than shadowing the `dialect` package.
	d := dialect.DatabricksDialect{}
	return fmt.Sprintf("DESCRIBE TABLE %s.%s.%s",
		d.QuoteIdentifier(tableID.Database()),
		d.QuoteIdentifier(tableID.Schema()),
		d.QuoteIdentifier(tableID.Table()),
	), nil
}

func (s Store) Merge(tableData *optimization.TableData) error {
Expand Down Expand Up @@ -54,15 +69,111 @@ func (s Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTabl
ConfigMap: s.configMap,
Query: query,
Args: args,
ColumnNameForName: "column_name",
ColumnNameForName: "col_name",
ColumnNameForDataType: "data_type",
ColumnNameForComment: "description",
ColumnNameForComment: "comment",
DropDeletedColumns: tableData.TopicConfig().DropDeletedColumns,
}.GetTableConfig()
}

func (s Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, parentTableID sql.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error {
panic("not implemented")
func (s Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, _ sql.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dialect: s.Dialect(),
Tc: tableConfig,
TableID: tempTableID,
CreateTable: true,
TemporaryTable: true,
ColumnOp: constants.Add,
Mode: tableData.Mode(),
}

if err := tempAlterTableArgs.AlterTable(s, tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
return fmt.Errorf("failed to create temp table: %w", err)
}
}

// Write data into a temporary file
fp, err := s.writeTemporaryTableFile(tableData, tempTableID)
if err != nil {
return fmt.Errorf("failed to load temporary table: %w", err)
}

defer func() {
// In the case where PUT or COPY fails, we'll at least delete the temporary file.
if deleteErr := os.RemoveAll(fp); deleteErr != nil {
slog.Warn("Failed to delete temp file", slog.Any("err", deleteErr), slog.String("filePath", fp))
}
}()

castedTempTableID, isOk := tempTableID.(TableIdentifier)
if !isOk {
return fmt.Errorf("failed to cast tempTableID to TableIdentifier")
}

dbfsFilePath := fmt.Sprintf("dbfs:/Volumes/%s/%s/vol_test/%s.csv", castedTempTableID.Database(), castedTempTableID.Schema(), tempTableID.Table())

ctx := driverctx.NewContextWithStagingInfo(context.Background(), []string{"/var"})

// Use the PUT INTO command to upload the file to Databricks
putCommand := fmt.Sprintf("PUT '%s' INTO '%s' OVERWRITE", fp, dbfsFilePath)
if _, err = s.ExecContext(ctx, putCommand); err != nil {
return fmt.Errorf("failed to run PUT INTO for temporary table: %w", err)
}

// Use the COPY INTO command to load the data into the temporary table
copyCommand := fmt.Sprintf("COPY INTO %s BY POSITION FROM '%s' FILEFORMAT = CSV FORMAT_OPTIONS ('delimiter' = '\t', 'header' = 'false')", tempTableID.FullyQualifiedName(), dbfsFilePath)
if _, err = s.ExecContext(ctx, copyCommand); err != nil {
return fmt.Errorf("failed to run COPY INTO for temporary table: %w", err)
}

return nil
}

// castColValStaging converts one column value into its CSV staging
// representation: nil becomes the NULL marker, everything else is stringified
// according to the column's kind via values.ToString.
func castColValStaging(colVal any, colKind typing.KindDetails) (string, error) {
	if colVal != nil {
		return values.ToString(colVal, colKind)
	}

	// \\N needs to match NULL_IF(...) from ddl.go
	return `\\N`, nil
}

// writeTemporaryTableFile dumps every row of [tableData] into a tab-delimited
// CSV file under os.TempDir() and returns the file path. The caller owns the
// file and is responsible for deleting it. NULLs are encoded per
// castColValStaging.
//
// NOTE(review): the file name embeds newTableID.FullyQualifiedName(), which can
// include quote characters once identifiers are quoted — confirm the resulting
// path is always valid on the host filesystem.
func (s Store) writeTemporaryTableFile(tableData *optimization.TableData, newTableID sql.TableIdentifier) (string, error) {
	fp := filepath.Join(os.TempDir(), fmt.Sprintf("%s.csv", newTableID.FullyQualifiedName()))
	file, err := os.Create(fp)
	if err != nil {
		return "", err
	}

	defer file.Close()
	writer := csv.NewWriter(file)
	writer.Comma = '\t'

	columns := tableData.ReadOnlyInMemoryCols().ValidColumns()
	// Reuse one pre-sized row buffer instead of growing a fresh slice per row.
	row := make([]string, len(columns))
	for _, value := range tableData.Rows() {
		for idx, col := range columns {
			castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails)
			if castErr != nil {
				return "", castErr
			}

			row[idx] = castedValue
		}

		if err = writer.Write(row); err != nil {
			return "", fmt.Errorf("failed to write to csv: %w", err)
		}
	}

	// Flush buffered output, then surface any deferred write error.
	writer.Flush()
	return fp, writer.Error()
}

func LoadStore(cfg config.Config) (Store, error) {
Expand Down
2 changes: 1 addition & 1 deletion clients/databricks/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier {
}

// FullyQualifiedName returns database.schema.table with the database and
// schema quoted by the package-level dialect; the table portion comes from
// EscapedTable(), which applies its own escaping.
func (ti TableIdentifier) FullyQualifiedName() string {
	return fmt.Sprintf("%s.%s.%s", _dialect.QuoteIdentifier(ti.database), _dialect.QuoteIdentifier(ti.schema), ti.EscapedTable())
}
1 change: 0 additions & 1 deletion lib/config/destination_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,4 @@ type Databricks struct {
Database string `yaml:"database"`
Protocol string `yaml:"protocol"`
CatalogName string `yaml:"catalogName"`
SchemaName string `yaml:"schemaName"`
}
1 change: 0 additions & 1 deletion lib/config/destinations.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ func (s Snowflake) ToConfig() (*gosnowflake.Config, error) {
func (d Databricks) DSN() string {
query := url.Values{}
query.Add("catalog", d.CatalogName)
query.Add("schema", d.SchemaName)
u := &url.URL{
Path: "/sql/1.0/warehouses/cab738c29ff77d72",
User: url.UserPassword("token", d.PersonalAccessToken),
Expand Down
26 changes: 25 additions & 1 deletion lib/db/db.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"context"
"database/sql"
"fmt"
"log/slog"
Expand All @@ -15,8 +16,9 @@ const (
)

// Store is the minimal database handle the rest of the codebase depends on.
// It mirrors a subset of *sql.DB and adds retry awareness; storeWrapper is the
// standard implementation.
type Store interface {
	Query(query string, args ...any) (*sql.Rows, error)
	Exec(query string, args ...any) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	Begin() (*sql.Tx, error)
	IsRetryableError(err error) bool
}
Expand All @@ -25,6 +27,28 @@ type storeWrapper struct {
*sql.DB
}

// ExecContext executes the query, retrying retryable failures up to
// maxAttempts times with jittered backoff. Success or a non-retryable error
// ends the loop immediately.
func (s *storeWrapper) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	var res sql.Result
	var execErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			// Back off before every retry, logging the previous failure.
			sleepDuration := jitter.Jitter(sleepBaseMs, jitter.DefaultMaxMs, attempt-1)
			slog.Warn("Failed to execute the query, retrying...",
				slog.Any("err", execErr),
				slog.Duration("sleep", sleepDuration),
				slog.Int("attempts", attempt),
			)
			time.Sleep(sleepDuration)
		}

		res, execErr = s.DB.ExecContext(ctx, query, args...)
		if execErr == nil {
			return res, nil
		}

		if !s.IsRetryableError(execErr) {
			return res, execErr
		}
	}

	return res, execErr
}

func (s *storeWrapper) Exec(query string, args ...any) (sql.Result, error) {
var result sql.Result
var err error
Expand Down

0 comments on commit 2f4b09c

Please sign in to comment.