Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting Databricks - Part Three #940

Merged
merged 16 commits into from
Oct 3, 2024
114 changes: 114 additions & 0 deletions clients/databricks/dialect/dialect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package dialect

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

type DatabricksDialect struct{}

func (DatabricksDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf("`%s`", identifier)
}

func (DatabricksDialect) EscapeStruct(value string) string {
panic("not implemented")
}

func (DatabricksDialect) IsColumnAlreadyExistsErr(err error) bool {
return err != nil && strings.Contains(err.Error(), "[FIELDS_ALREADY_EXISTS]")
}

func (DatabricksDialect) IsTableDoesNotExistErr(err error) bool {
return err != nil && strings.Contains(err.Error(), "[TABLE_OR_VIEW_NOT_FOUND]")
}

func (DatabricksDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, _ bool, colSQLParts []string) string {
// Databricks doesn't have a concept of temporary tables.
return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ", "))
}

func (DatabricksDialect) BuildAlterColumnQuery(tableID sql.TableIdentifier, columnOp constants.ColumnOperation, colSQLPart string) string {
return fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", tableID.FullyQualifiedName(), columnOp, colSQLPart)
}

func (DatabricksDialect) BuildIsNotToastValueExpression(tableAlias constants.TableAlias, column columns.Column) string {
panic("not implemented")
}

func (DatabricksDialect) BuildDedupeTableQuery(tableID sql.TableIdentifier, primaryKeys []string) string {
panic("not implemented")
}

func (DatabricksDialect) BuildDedupeQueries(_, _ sql.TableIdentifier, _ []string, _ bool) []string {
panic("not implemented")
}

func (d DatabricksDialect) BuildMergeQueries(
tableID sql.TableIdentifier,
subQuery string,
primaryKeys []columns.Column,
additionalEqualityStrings []string,
cols []columns.Column,
softDelete bool,
_ bool,
) ([]string, error) {
// TODO: Add tests.

// Build the base equality condition for the MERGE query
equalitySQLParts := sql.BuildColumnComparisons(primaryKeys, constants.TargetAlias, constants.StagingAlias, sql.Equal, d)
if len(additionalEqualityStrings) > 0 {
equalitySQLParts = append(equalitySQLParts, additionalEqualityStrings...)
}

// Construct the base MERGE query
baseQuery := fmt.Sprintf(`MERGE INTO %s %s USING %s %s ON %s`, tableID.FullyQualifiedName(), constants.TargetAlias, subQuery, constants.StagingAlias, strings.Join(equalitySQLParts, " AND "))
// Remove columns with only the delete marker, as they are handled separately
cols, err := columns.RemoveOnlySetDeleteColumnMarker(cols)
if err != nil {
return nil, err
}

if softDelete {
// If softDelete is enabled, handle both update and soft-delete logic
return []string{baseQuery + fmt.Sprintf(`
WHEN MATCHED AND IFNULL(%s, false) = false THEN UPDATE SET %s
WHEN MATCHED AND IFNULL(%s, false) = true THEN UPDATE SET %s
WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s);`,
sql.GetQuotedOnlySetDeleteColumnMarker(constants.StagingAlias, d),
sql.BuildColumnsUpdateFragment(cols, constants.StagingAlias, constants.TargetAlias, d),
sql.GetQuotedOnlySetDeleteColumnMarker(constants.StagingAlias, d),
sql.BuildColumnsUpdateFragment([]columns.Column{columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)}, constants.StagingAlias, constants.TargetAlias, d),
strings.Join(sql.QuoteColumns(cols, d), ","),
strings.Join(sql.QuoteTableAliasColumns(constants.StagingAlias, cols, d), ","),
)}, nil
}

// Remove the delete marker for hard-delete logic
cols, err = columns.RemoveDeleteColumnMarker(cols)
if err != nil {
return nil, err
}

// Handle the case where hard-deletes are included
return []string{baseQuery + fmt.Sprintf(`
WHEN MATCHED AND %s THEN DELETE
WHEN MATCHED AND IFNULL(%s, false) = false THEN UPDATE SET %s
WHEN NOT MATCHED AND IFNULL(%s, false) = false THEN INSERT (%s) VALUES (%s);`,
sql.QuotedDeleteColumnMarker(constants.StagingAlias, d),
sql.QuotedDeleteColumnMarker(constants.StagingAlias, d),
sql.BuildColumnsUpdateFragment(cols, constants.StagingAlias, constants.TargetAlias, d),
sql.QuotedDeleteColumnMarker(constants.StagingAlias, d),
strings.Join(sql.QuoteColumns(cols, d), ","),
strings.Join(sql.QuoteTableAliasColumns(constants.StagingAlias, cols, d), ","),
)}, nil
}

func (d DatabricksDialect) GetDefaultValueStrategy() sql.DefaultValueStrategy {
return sql.Native
}
80 changes: 80 additions & 0 deletions clients/databricks/dialect/dialect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package dialect

import (
"fmt"
"testing"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/mocks"
"github.com/stretchr/testify/assert"
)

func TestDatabricksDialect_QuoteIdentifier(t *testing.T) {
dialect := DatabricksDialect{}
assert.Equal(t, "`foo`", dialect.QuoteIdentifier("foo"))
assert.Equal(t, "`FOO`", dialect.QuoteIdentifier("FOO"))
}

func TestDatabricksDialect_IsColumnAlreadyExistsErr(t *testing.T) {
{
// No error
assert.False(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(nil))
}
{
// Random error
assert.False(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(fmt.Errorf("random error")))
}
{
// Valid
assert.True(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(fmt.Errorf("[FIELDS_ALREADY_EXISTS] Cannot add column, because `first_name` already exists]")))
}
}

func TestDatabricksDialect_IsTableDoesNotExistErr(t *testing.T) {
{
// No error
assert.False(t, DatabricksDialect{}.IsTableDoesNotExistErr(nil))
}
{
// Random error
assert.False(t, DatabricksDialect{}.IsTableDoesNotExistErr(fmt.Errorf("random error")))
}
{
// Valid
assert.True(t, DatabricksDialect{}.IsTableDoesNotExistErr(fmt.Errorf("[TABLE_OR_VIEW_NOT_FOUND] Table or view not found: `foo`]")))
}
}

func TestDatabricksDialect_BuildCreateTableQuery(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns("{TABLE}")

{
// Temporary
assert.Equal(t,
`CREATE TABLE IF NOT EXISTS {TABLE} ({PART_1}, {PART_2})`,
DatabricksDialect{}.BuildCreateTableQuery(fakeTableID, true, []string{"{PART_1}", "{PART_2}"}),
)
}
{
// Not temporary
assert.Equal(t,
`CREATE TABLE IF NOT EXISTS {TABLE} ({PART_1}, {PART_2})`,
DatabricksDialect{}.BuildCreateTableQuery(fakeTableID, false, []string{"{PART_1}", "{PART_2}"}),
)
}
}

func TestDatabricksDialect_BuildAlterColumnQuery(t *testing.T) {
fakeTableID := &mocks.FakeTableIdentifier{}
fakeTableID.FullyQualifiedNameReturns("{TABLE}")

{
// DROP
assert.Equal(t, "ALTER TABLE {TABLE} drop COLUMN {SQL_PART}", DatabricksDialect{}.BuildAlterColumnQuery(fakeTableID, constants.Delete, "{SQL_PART}"))
}
{
// Add
assert.Equal(t, "ALTER TABLE {TABLE} add COLUMN {SQL_PART} {DATA_TYPE}", DatabricksDialect{}.BuildAlterColumnQuery(fakeTableID, constants.Add, "{SQL_PART} {DATA_TYPE}"))
}
}
82 changes: 82 additions & 0 deletions clients/databricks/dialect/typing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package dialect

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/sql"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/ext"
)

func (DatabricksDialect) DataTypeForKind(kindDetails typing.KindDetails, _ bool) string {
switch kindDetails.Kind {
case typing.Float.Kind:
return "DOUBLE"
case typing.Integer.Kind:
return "BIGINT"
case typing.Struct.Kind:
return "VARIANT"
case typing.Array.Kind:
// Databricks requires arrays to be typed. As such, we're going to use an array of strings.
return "ARRAY<string>"
case typing.String.Kind:
return "STRING"
case typing.Boolean.Kind:
return "BOOLEAN"
case typing.ETime.Kind:
switch kindDetails.ExtendedTimeDetails.Type {
case ext.TimestampTzKindType:
// Using datetime2 because it's the recommendation, and it provides more precision: https://stackoverflow.com/a/1884088
return "TIMESTAMP"
case ext.DateKindType:
return "DATE"
case ext.TimeKindType:
return "STRING"
}
case typing.EDecimal.Kind:
return kindDetails.ExtendedDecimalDetails.DatabricksKind()
}

return kindDetails.Kind
}

func (DatabricksDialect) KindForDataType(rawType string, _ string) (typing.KindDetails, error) {
rawType = strings.ToLower(rawType)
if strings.HasPrefix(rawType, "decimal") {
_, parameters, err := sql.ParseDataTypeDefinition(rawType)
if err != nil {
return typing.Invalid, err
}
return typing.ParseNumeric(parameters), nil
}

if strings.HasPrefix(rawType, "array") {
return typing.Array, nil
}

switch rawType {
case "string", "binary":
return typing.String, nil
case "bigint":
return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.BigIntegerKind)}, nil
case "boolean":
return typing.Boolean, nil
case "date":
return typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType), nil
case "double", "float":
return typing.Float, nil
case "int":
return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.IntegerKind)}, nil
case "smallint", "tinyint":
return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.SmallIntegerKind)}, nil
case "timestamp":
return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil
case "timestamp_ntz":
return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil
case "variant", "object":
return typing.Struct, nil
}

return typing.Invalid, fmt.Errorf("unsupported data type: %q", rawType)
}
Loading
Loading