Skip to content

Commit

Permalink
Supporting Databricks - Part Two (#939)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 authored Oct 3, 2024
1 parent 57af576 commit c6a2206
Show file tree
Hide file tree
Showing 7 changed files with 459 additions and 0 deletions.
114 changes: 114 additions & 0 deletions clients/databricks/dialect/dialect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package dialect

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

// DatabricksDialect is a field-less type whose methods build Databricks SQL
// statements and classify Databricks error messages.
type DatabricksDialect struct{}

// QuoteIdentifier wraps identifier in backticks so it can be used as a
// delimited identifier in Databricks SQL.
//
// Any backtick inside identifier is doubled, which is how Databricks escapes a
// backtick within a delimited identifier. Without this, an embedded backtick
// would terminate the quoted name early and corrupt the generated statement.
func (DatabricksDialect) QuoteIdentifier(identifier string) string {
	return "`" + strings.ReplaceAll(identifier, "`", "``") + "`"
}

// EscapeStruct is not implemented for Databricks; it always panics with
// "not implemented".
func (DatabricksDialect) EscapeStruct(value string) string {
panic("not implemented")
}

// IsColumnAlreadyExistsErr reports whether err is non-nil and its message
// contains the "[FIELDS_ALREADY_EXISTS]" marker.
func (DatabricksDialect) IsColumnAlreadyExistsErr(err error) bool {
	if err == nil {
		return false
	}

	return strings.Contains(err.Error(), "[FIELDS_ALREADY_EXISTS]")
}

// IsTableDoesNotExistErr reports whether err is non-nil and its message
// contains the "[TABLE_OR_VIEW_NOT_FOUND]" marker.
func (DatabricksDialect) IsTableDoesNotExistErr(err error) bool {
	if err == nil {
		return false
	}

	return strings.Contains(err.Error(), "[TABLE_OR_VIEW_NOT_FOUND]")
}

// BuildCreateTableQuery returns a CREATE TABLE IF NOT EXISTS statement for
// tableID whose column list is colSQLParts joined with ", ".
//
// The boolean parameter is ignored: Databricks doesn't have a concept of
// temporary tables, so the same statement is produced either way.
func (DatabricksDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, _ bool, colSQLParts []string) string {
	tableName := tableID.FullyQualifiedName()
	columnDefs := strings.Join(colSQLParts, ", ")
	return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableName, columnDefs)
}

// BuildAlterColumnQuery returns an ALTER TABLE statement for tableID of the
// form "ALTER TABLE <table> <columnOp> COLUMN <colSQLPart>".
func (DatabricksDialect) BuildAlterColumnQuery(tableID sql.TableIdentifier, columnOp constants.ColumnOperation, colSQLPart string) string {
	tableName := tableID.FullyQualifiedName()
	return fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", tableName, columnOp, colSQLPart)
}

// BuildIsNotToastValueExpression is not implemented for Databricks; it always
// panics with "not implemented".
func (DatabricksDialect) BuildIsNotToastValueExpression(tableAlias constants.TableAlias, column columns.Column) string {
panic("not implemented")
}

// BuildDedupeTableQuery is not implemented for Databricks; it always panics
// with "not implemented".
func (DatabricksDialect) BuildDedupeTableQuery(tableID sql.TableIdentifier, primaryKeys []string) string {
panic("not implemented")
}

// BuildDedupeQueries is not implemented for Databricks; it always panics with
// "not implemented". All of its arguments are ignored.
func (DatabricksDialect) BuildDedupeQueries(_, _ sql.TableIdentifier, _ []string, _ bool) []string {
panic("not implemented")
}

// BuildMergeQueries returns a single-element slice holding one MERGE INTO
// statement that merges subQuery (aliased as the staging table) into tableID
// (aliased as the target table).
//
// The ON clause is the primary-key equality comparisons followed by
// additionalEqualityStrings, all joined with " AND ".
//
// When softDelete is true, matched rows are always updated: either every
// column, or only the delete-marker column, depending on the staging row's
// "only set delete" marker. Unmatched rows are inserted.
//
// When softDelete is false, matched rows flagged by the staging delete marker
// are deleted, other matched rows are updated, and unmatched rows are inserted
// only when they are not flagged as deleted.
//
// The final boolean parameter is ignored. An error is returned when either
// marker column cannot be removed from cols.
func (d DatabricksDialect) BuildMergeQueries(
	tableID sql.TableIdentifier,
	subQuery string,
	primaryKeys []columns.Column,
	additionalEqualityStrings []string,
	cols []columns.Column,
	softDelete bool,
	_ bool,
) ([]string, error) {
	// TODO: Add tests.

	// ON clause: primary-key equality plus any caller-supplied conditions.
	// Appending an empty slice is a no-op, so no length check is needed.
	joinConditions := sql.BuildColumnComparisons(primaryKeys, constants.TargetAlias, constants.StagingAlias, sql.Equal, d)
	joinConditions = append(joinConditions, additionalEqualityStrings...)

	// Shared "MERGE INTO ... USING ... ON ..." prefix for both delete modes.
	mergePrefix := fmt.Sprintf(`MERGE INTO %s %s USING %s %s ON %s`, tableID.FullyQualifiedName(), constants.TargetAlias, subQuery, constants.StagingAlias, strings.Join(joinConditions, " AND "))

	// The "only set delete" marker is referenced explicitly in the WHEN
	// clauses, so it must not appear in the regular column lists.
	cols, err := columns.RemoveOnlySetDeleteColumnMarker(cols)
	if err != nil {
		return nil, err
	}

	if softDelete {
		onlySetDelete := sql.GetQuotedOnlySetDeleteColumnMarker(constants.StagingAlias, d)
		updateAllCols := sql.BuildColumnsUpdateFragment(cols, constants.StagingAlias, constants.TargetAlias, d)
		deleteMarkerCol := columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)
		updateDeleteMarker := sql.BuildColumnsUpdateFragment([]columns.Column{deleteMarkerCol}, constants.StagingAlias, constants.TargetAlias, d)
		insertColumns := strings.Join(sql.QuoteColumns(cols, d), ",")
		insertValues := strings.Join(sql.QuoteTableAliasColumns(constants.StagingAlias, cols, d), ",")

		softDeleteQuery := mergePrefix + fmt.Sprintf(`
WHEN MATCHED AND IFNULL(%s, false) = false THEN UPDATE SET %s
WHEN MATCHED AND IFNULL(%s, false) = true THEN UPDATE SET %s
WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s);`,
			onlySetDelete,
			updateAllCols,
			onlySetDelete,
			updateDeleteMarker,
			insertColumns,
			insertValues,
		)
		return []string{softDeleteQuery}, nil
	}

	// Hard deletes: the delete marker only drives the WHEN clauses and is not
	// written to the target table.
	cols, err = columns.RemoveDeleteColumnMarker(cols)
	if err != nil {
		return nil, err
	}

	deleteMarker := sql.QuotedDeleteColumnMarker(constants.StagingAlias, d)
	updateCols := sql.BuildColumnsUpdateFragment(cols, constants.StagingAlias, constants.TargetAlias, d)
	insertColumns := strings.Join(sql.QuoteColumns(cols, d), ",")
	insertValues := strings.Join(sql.QuoteTableAliasColumns(constants.StagingAlias, cols, d), ",")

	hardDeleteQuery := mergePrefix + fmt.Sprintf(`
WHEN MATCHED AND %s THEN DELETE
WHEN MATCHED AND IFNULL(%s, false) = false THEN UPDATE SET %s
WHEN NOT MATCHED AND IFNULL(%s, false) = false THEN INSERT (%s) VALUES (%s);`,
		deleteMarker,
		deleteMarker,
		updateCols,
		deleteMarker,
		insertColumns,
		insertValues,
	)
	return []string{hardDeleteQuery}, nil
}

// GetDefaultValueStrategy returns sql.Native, the default-value strategy used
// for Databricks.
func (DatabricksDialect) GetDefaultValueStrategy() sql.DefaultValueStrategy {
	return sql.Native
}
80 changes: 80 additions & 0 deletions clients/databricks/dialect/dialect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package dialect

import (
"fmt"
"testing"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/mocks"
"github.com/stretchr/testify/assert"
)

// TestDatabricksDialect_QuoteIdentifier checks that identifiers are wrapped in
// backticks and that their case is left untouched.
func TestDatabricksDialect_QuoteIdentifier(t *testing.T) {
	testCases := []struct {
		identifier string
		expected   string
	}{
		{identifier: "foo", expected: "`foo`"},
		{identifier: "FOO", expected: "`FOO`"},
	}

	var d DatabricksDialect
	for _, tc := range testCases {
		assert.Equal(t, tc.expected, d.QuoteIdentifier(tc.identifier))
	}
}

// TestDatabricksDialect_IsColumnAlreadyExistsErr covers a nil error, an
// unrelated error, and an error carrying the [FIELDS_ALREADY_EXISTS] marker.
func TestDatabricksDialect_IsColumnAlreadyExistsErr(t *testing.T) {
	testCases := []struct {
		name     string
		err      error
		expected bool
	}{
		{name: "no error", err: nil, expected: false},
		{name: "random error", err: fmt.Errorf("random error"), expected: false},
		{name: "valid", err: fmt.Errorf("[FIELDS_ALREADY_EXISTS] Cannot add column, because `first_name` already exists]"), expected: true},
	}

	for _, tc := range testCases {
		assert.Equal(t, tc.expected, DatabricksDialect{}.IsColumnAlreadyExistsErr(tc.err), tc.name)
	}
}

// TestDatabricksDialect_IsTableDoesNotExistErr covers a nil error, an
// unrelated error, and an error carrying the [TABLE_OR_VIEW_NOT_FOUND] marker.
func TestDatabricksDialect_IsTableDoesNotExistErr(t *testing.T) {
	testCases := []struct {
		name     string
		err      error
		expected bool
	}{
		{name: "no error", err: nil, expected: false},
		{name: "random error", err: fmt.Errorf("random error"), expected: false},
		{name: "valid", err: fmt.Errorf("[TABLE_OR_VIEW_NOT_FOUND] Table or view not found: `foo`]"), expected: true},
	}

	for _, tc := range testCases {
		assert.Equal(t, tc.expected, DatabricksDialect{}.IsTableDoesNotExistErr(tc.err), tc.name)
	}
}

// TestDatabricksDialect_BuildCreateTableQuery checks that the generated
// statement is the same whether or not a temporary table is requested.
func TestDatabricksDialect_BuildCreateTableQuery(t *testing.T) {
	tableID := &mocks.FakeTableIdentifier{}
	tableID.FullyQualifiedNameReturns("{TABLE}")

	const expected = `CREATE TABLE IF NOT EXISTS {TABLE} ({PART_1}, {PART_2})`

	// First temporary, then not temporary.
	for _, temporary := range []bool{true, false} {
		actual := DatabricksDialect{}.BuildCreateTableQuery(tableID, temporary, []string{"{PART_1}", "{PART_2}"})
		assert.Equal(t, expected, actual)
	}
}

// TestDatabricksDialect_BuildAlterColumnQuery checks the ALTER TABLE statement
// produced for dropping and for adding a column.
func TestDatabricksDialect_BuildAlterColumnQuery(t *testing.T) {
	tableID := &mocks.FakeTableIdentifier{}
	tableID.FullyQualifiedNameReturns("{TABLE}")

	testCases := []struct {
		name       string
		op         constants.ColumnOperation
		colSQLPart string
		expected   string
	}{
		{
			name:       "DROP",
			op:         constants.Delete,
			colSQLPart: "{SQL_PART}",
			expected:   "ALTER TABLE {TABLE} drop COLUMN {SQL_PART}",
		},
		{
			name:       "Add",
			op:         constants.Add,
			colSQLPart: "{SQL_PART} {DATA_TYPE}",
			expected:   "ALTER TABLE {TABLE} add COLUMN {SQL_PART} {DATA_TYPE}",
		},
	}

	for _, tc := range testCases {
		actual := DatabricksDialect{}.BuildAlterColumnQuery(tableID, tc.op, tc.colSQLPart)
		assert.Equal(t, tc.expected, actual, tc.name)
	}
}
82 changes: 82 additions & 0 deletions clients/databricks/dialect/typing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package dialect

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/ext"
)

// DataTypeForKind returns the Databricks column type used for kindDetails.
//
// The boolean parameter is ignored. Kinds without an explicit mapping,
// including extended-time kinds whose type is not timestamp-tz, date, or
// time, fall back to the raw kindDetails.Kind string.
func (DatabricksDialect) DataTypeForKind(kindDetails typing.KindDetails, _ bool) string {
	kind := kindDetails.Kind
	switch kind {
	case typing.Float.Kind:
		return "DOUBLE"
	case typing.Integer.Kind:
		return "BIGINT"
	case typing.Struct.Kind:
		return "VARIANT"
	case typing.Array.Kind:
		// Databricks requires arrays to be typed. As such, we're going to use an array of strings.
		return "ARRAY<string>"
	case typing.String.Kind:
		return "STRING"
	case typing.Boolean.Kind:
		return "BOOLEAN"
	case typing.ETime.Kind:
		timeType := kindDetails.ExtendedTimeDetails.Type
		if timeType == ext.TimestampTzKindType {
			return "TIMESTAMP"
		}
		if timeType == ext.DateKindType {
			return "DATE"
		}
		if timeType == ext.TimeKindType {
			// Time-of-day values are mapped to STRING.
			return "STRING"
		}
	case typing.EDecimal.Kind:
		return kindDetails.ExtendedDecimalDetails.DatabricksKind()
	}

	return kind
}

// KindForDataType maps a Databricks column type name to the matching
// typing.KindDetails. Matching is case-insensitive, and the second parameter
// is ignored.
//
// Any type starting with "decimal" has its parameters parsed into a numeric
// kind, and any type starting with "array" maps to typing.Array. All other
// types must match one of the names below exactly; otherwise typing.Invalid
// is returned together with an "unsupported data type" error.
func (DatabricksDialect) KindForDataType(rawType string, _ string) (typing.KindDetails, error) {
	normalized := strings.ToLower(rawType)

	// Parameterized types are recognized by prefix only.
	switch {
	case strings.HasPrefix(normalized, "decimal"):
		_, parameters, err := sql.ParseDataTypeDefinition(normalized)
		if err != nil {
			return typing.Invalid, err
		}

		return typing.ParseNumeric(parameters), nil
	case strings.HasPrefix(normalized, "array"):
		return typing.Array, nil
	}

	switch normalized {
	case "string", "binary":
		return typing.String, nil
	case "boolean":
		return typing.Boolean, nil
	case "double", "float":
		return typing.Float, nil
	case "smallint", "tinyint":
		return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.SmallIntegerKind)}, nil
	case "int":
		return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.IntegerKind)}, nil
	case "bigint":
		return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.BigIntegerKind)}, nil
	case "date":
		return typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType), nil
	case "timestamp", "timestamp_ntz":
		// Both timestamp flavors map to the same timestamp-tz kind.
		return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil
	case "variant", "object":
		return typing.Struct, nil
	}

	return typing.Invalid, fmt.Errorf("unsupported data type: %q", normalized)
}
Loading

0 comments on commit c6a2206

Please sign in to comment.