From c6a220672ee6947cba117e7f8d4d748259655ca8 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Thu, 3 Oct 2024 11:47:59 -0700 Subject: [PATCH] Supporting Databricks - Part Two (#939) --- clients/databricks/dialect/dialect.go | 114 ++++++++++++++ clients/databricks/dialect/dialect_test.go | 80 ++++++++++ clients/databricks/dialect/typing.go | 82 ++++++++++ clients/databricks/dialect/typing_test.go | 168 +++++++++++++++++++++ lib/typing/decimal/base.go | 8 + lib/typing/decimal/details.go | 6 + lib/typing/numeric.go | 1 + 7 files changed, 459 insertions(+) create mode 100644 clients/databricks/dialect/dialect.go create mode 100644 clients/databricks/dialect/dialect_test.go create mode 100644 clients/databricks/dialect/typing.go create mode 100644 clients/databricks/dialect/typing_test.go diff --git a/clients/databricks/dialect/dialect.go b/clients/databricks/dialect/dialect.go new file mode 100644 index 000000000..d895aff7a --- /dev/null +++ b/clients/databricks/dialect/dialect.go @@ -0,0 +1,114 @@ +package dialect + +import ( + "fmt" + "strings" + + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/columns" +) + +type DatabricksDialect struct{} + +func (DatabricksDialect) QuoteIdentifier(identifier string) string { + return fmt.Sprintf("`%s`", identifier) +} + +func (DatabricksDialect) EscapeStruct(value string) string { + panic("not implemented") +} + +func (DatabricksDialect) IsColumnAlreadyExistsErr(err error) bool { + return err != nil && strings.Contains(err.Error(), "[FIELDS_ALREADY_EXISTS]") +} + +func (DatabricksDialect) IsTableDoesNotExistErr(err error) bool { + return err != nil && strings.Contains(err.Error(), "[TABLE_OR_VIEW_NOT_FOUND]") +} + +func (DatabricksDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, _ bool, colSQLParts []string) string { + // Databricks doesn't have a concept of temporary tables. + return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ", ")) +} + +func (DatabricksDialect) BuildAlterColumnQuery(tableID sql.TableIdentifier, columnOp constants.ColumnOperation, colSQLPart string) string { + return fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", tableID.FullyQualifiedName(), columnOp, colSQLPart) +} + +func (DatabricksDialect) BuildIsNotToastValueExpression(tableAlias constants.TableAlias, column columns.Column) string { + panic("not implemented") +} + +func (DatabricksDialect) BuildDedupeTableQuery(tableID sql.TableIdentifier, primaryKeys []string) string { + panic("not implemented") +} + +func (DatabricksDialect) BuildDedupeQueries(_, _ sql.TableIdentifier, _ []string, _ bool) []string { + panic("not implemented") +} + +func (d DatabricksDialect) BuildMergeQueries( + tableID sql.TableIdentifier, + subQuery string, + primaryKeys []columns.Column, + additionalEqualityStrings []string, + cols []columns.Column, + softDelete bool, + _ bool, +) ([]string, error) { + // TODO: Add tests. + + // Build the base equality condition for the MERGE query + equalitySQLParts := sql.BuildColumnComparisons(primaryKeys, constants.TargetAlias, constants.StagingAlias, sql.Equal, d) + if len(additionalEqualityStrings) > 0 { + equalitySQLParts = append(equalitySQLParts, additionalEqualityStrings...) + } + + // Construct the base MERGE query + baseQuery := fmt.Sprintf(`MERGE INTO %s %s USING %s %s ON %s`, tableID.FullyQualifiedName(), constants.TargetAlias, subQuery, constants.StagingAlias, strings.Join(equalitySQLParts, " AND ")) + // Remove columns with only the delete marker, as they are handled separately + cols, err := columns.RemoveOnlySetDeleteColumnMarker(cols) + if err != nil { + return nil, err + } + + if softDelete { + // If softDelete is enabled, handle both update and soft-delete logic + return []string{baseQuery + fmt.Sprintf(` +WHEN MATCHED AND IFNULL(%s, false) = false THEN UPDATE SET %s +WHEN MATCHED AND IFNULL(%s, false) = true THEN UPDATE SET %s +WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s);`, + sql.GetQuotedOnlySetDeleteColumnMarker(constants.StagingAlias, d), + sql.BuildColumnsUpdateFragment(cols, constants.StagingAlias, constants.TargetAlias, d), + sql.GetQuotedOnlySetDeleteColumnMarker(constants.StagingAlias, d), + sql.BuildColumnsUpdateFragment([]columns.Column{columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)}, constants.StagingAlias, constants.TargetAlias, d), + strings.Join(sql.QuoteColumns(cols, d), ","), + strings.Join(sql.QuoteTableAliasColumns(constants.StagingAlias, cols, d), ","), + )}, nil + } + + // Remove the delete marker for hard-delete logic + cols, err = columns.RemoveDeleteColumnMarker(cols) + if err != nil { + return nil, err + } + + // Handle the case where hard-deletes are included + return []string{baseQuery + fmt.Sprintf(` +WHEN MATCHED AND %s THEN DELETE +WHEN MATCHED AND IFNULL(%s, false) = false THEN UPDATE SET %s +WHEN NOT MATCHED AND IFNULL(%s, false) = false THEN INSERT (%s) VALUES (%s);`, + sql.QuotedDeleteColumnMarker(constants.StagingAlias, d), + sql.QuotedDeleteColumnMarker(constants.StagingAlias, d), + sql.BuildColumnsUpdateFragment(cols, constants.StagingAlias, constants.TargetAlias, d), + sql.QuotedDeleteColumnMarker(constants.StagingAlias, d), + strings.Join(sql.QuoteColumns(cols, d), ","), + strings.Join(sql.QuoteTableAliasColumns(constants.StagingAlias, cols, d), ","), + )}, nil +} + +func (d DatabricksDialect) GetDefaultValueStrategy() sql.DefaultValueStrategy { + return sql.Native +} diff --git a/clients/databricks/dialect/dialect_test.go b/clients/databricks/dialect/dialect_test.go new file mode 100644 index 000000000..6b2ea23ee --- /dev/null +++ b/clients/databricks/dialect/dialect_test.go @@ -0,0 +1,80 @@ +package dialect + +import ( + "fmt" + "testing" + + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/mocks" + "github.com/stretchr/testify/assert" +) + +func TestDatabricksDialect_QuoteIdentifier(t *testing.T) { + dialect := DatabricksDialect{} + assert.Equal(t, "`foo`", dialect.QuoteIdentifier("foo")) + assert.Equal(t, "`FOO`", dialect.QuoteIdentifier("FOO")) +} + +func TestDatabricksDialect_IsColumnAlreadyExistsErr(t *testing.T) { + { + // No error + assert.False(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(nil)) + } + { + // Random error + assert.False(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(fmt.Errorf("random error"))) + } + { + // Valid + assert.True(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(fmt.Errorf("[FIELDS_ALREADY_EXISTS] Cannot add column, because `first_name` already exists]"))) + } +} + +func TestDatabricksDialect_IsTableDoesNotExistErr(t *testing.T) { + { + // No error + assert.False(t, DatabricksDialect{}.IsTableDoesNotExistErr(nil)) + } + { + // Random error + assert.False(t, DatabricksDialect{}.IsTableDoesNotExistErr(fmt.Errorf("random error"))) + } + { + // Valid + assert.True(t, DatabricksDialect{}.IsTableDoesNotExistErr(fmt.Errorf("[TABLE_OR_VIEW_NOT_FOUND] Table or view not found: `foo`]"))) + } +} + +func TestDatabricksDialect_BuildCreateTableQuery(t *testing.T) { + fakeTableID := &mocks.FakeTableIdentifier{} + fakeTableID.FullyQualifiedNameReturns("{TABLE}") + + { + // Temporary + assert.Equal(t, + `CREATE TABLE IF NOT EXISTS {TABLE} ({PART_1}, {PART_2})`, + DatabricksDialect{}.BuildCreateTableQuery(fakeTableID, true, []string{"{PART_1}", "{PART_2}"}), + ) + } + { + // Not temporary + assert.Equal(t, + `CREATE TABLE IF NOT EXISTS {TABLE} ({PART_1}, {PART_2})`, + DatabricksDialect{}.BuildCreateTableQuery(fakeTableID, false, []string{"{PART_1}", "{PART_2}"}), + ) + } +} + +func TestDatabricksDialect_BuildAlterColumnQuery(t *testing.T) { + fakeTableID := &mocks.FakeTableIdentifier{} + fakeTableID.FullyQualifiedNameReturns("{TABLE}") + + { + // DROP + assert.Equal(t, "ALTER TABLE {TABLE} drop COLUMN {SQL_PART}", DatabricksDialect{}.BuildAlterColumnQuery(fakeTableID, constants.Delete, "{SQL_PART}")) + } + { + // Add + assert.Equal(t, "ALTER TABLE {TABLE} add COLUMN {SQL_PART} {DATA_TYPE}", DatabricksDialect{}.BuildAlterColumnQuery(fakeTableID, constants.Add, "{SQL_PART} {DATA_TYPE}")) + } +} diff --git a/clients/databricks/dialect/typing.go b/clients/databricks/dialect/typing.go new file mode 100644 index 000000000..f70a5c99e --- /dev/null +++ b/clients/databricks/dialect/typing.go @@ -0,0 +1,82 @@ +package dialect + +import ( + "fmt" + "strings" + + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/ext" +) + +func (DatabricksDialect) DataTypeForKind(kindDetails typing.KindDetails, _ bool) string { + switch kindDetails.Kind { + case typing.Float.Kind: + return "DOUBLE" + case typing.Integer.Kind: + return "BIGINT" + case typing.Struct.Kind: + return "VARIANT" + case typing.Array.Kind: + // Databricks requires arrays to be typed. As such, we're going to use an array of strings. + return "ARRAY" + case typing.String.Kind: + return "STRING" + case typing.Boolean.Kind: + return "BOOLEAN" + case typing.ETime.Kind: + switch kindDetails.ExtendedTimeDetails.Type { + case ext.TimestampTzKindType: + // Using datetime2 because it's the recommendation, and it provides more precision: https://stackoverflow.com/a/1884088 + return "TIMESTAMP" + case ext.DateKindType: + return "DATE" + case ext.TimeKindType: + return "STRING" + } + case typing.EDecimal.Kind: + return kindDetails.ExtendedDecimalDetails.DatabricksKind() + } + + return kindDetails.Kind +} + +func (DatabricksDialect) KindForDataType(rawType string, _ string) (typing.KindDetails, error) { + rawType = strings.ToLower(rawType) + if strings.HasPrefix(rawType, "decimal") { + _, parameters, err := sql.ParseDataTypeDefinition(rawType) + if err != nil { + return typing.Invalid, err + } + return typing.ParseNumeric(parameters), nil + } + + if strings.HasPrefix(rawType, "array") { + return typing.Array, nil + } + + switch rawType { + case "string", "binary": + return typing.String, nil + case "bigint": + return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.BigIntegerKind)}, nil + case "boolean": + return typing.Boolean, nil + case "date": + return typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType), nil + case "double", "float": + return typing.Float, nil + case "int": + return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.IntegerKind)}, nil + case "smallint", "tinyint": + return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.SmallIntegerKind)}, nil + case "timestamp": + return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil + case "timestamp_ntz": + return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil + case "variant", "object": + return typing.Struct, nil + } + + return typing.Invalid, fmt.Errorf("unsupported data type: %q", rawType) +} diff --git a/clients/databricks/dialect/typing_test.go b/clients/databricks/dialect/typing_test.go new file mode 100644 index 000000000..51b0aedf3 --- /dev/null +++ b/clients/databricks/dialect/typing_test.go @@ -0,0 +1,168 @@ +package dialect + +import ( + "testing" + + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/decimal" + "github.com/artie-labs/transfer/lib/typing/ext" + "github.com/stretchr/testify/assert" +) + +func TestDatabricksDialect_DataTypeForKind(t *testing.T) { + { + // Float + assert.Equal(t, "DOUBLE", DatabricksDialect{}.DataTypeForKind(typing.Float, false)) + } + { + // Integer + assert.Equal(t, "BIGINT", DatabricksDialect{}.DataTypeForKind(typing.Integer, false)) + } + { + // Variant + assert.Equal(t, "VARIANT", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.Struct.Kind}, false)) + } + { + // Array + assert.Equal(t, "ARRAY", DatabricksDialect{}.DataTypeForKind(typing.Array, false)) + } + { + // String + assert.Equal(t, "STRING", DatabricksDialect{}.DataTypeForKind(typing.String, false)) + } + { + // Boolean + assert.Equal(t, "BOOLEAN", DatabricksDialect{}.DataTypeForKind(typing.Boolean, false)) + } + { + // Times + { + // Date + assert.Equal(t, "DATE", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.DateKindType}}, false)) + } + { + // Timestamp + assert.Equal(t, "TIMESTAMP", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.TimestampTzKindType}}, false)) + } + { + // Timestamp (w/o timezone) + assert.Equal(t, "TIMESTAMP", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.TimestampTzKindType}}, false)) + } + { + // Time + assert.Equal(t, "STRING", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.TimeKindType}}, false)) + } + } + { + // Decimals + { + // Below 38 precision + assert.Equal(t, "DECIMAL(10, 2)", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.EDecimal.Kind, ExtendedDecimalDetails: typing.ToPtr(decimal.NewDetails(10, 2))}, false)) + } + { + // Above 38 precision + assert.Equal(t, "STRING", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.EDecimal.Kind, ExtendedDecimalDetails: typing.ToPtr(decimal.NewDetails(40, 2))}, false)) + } + } +} + +func TestDatabricksDialect_KindForDataType(t *testing.T) { + { + // Decimal + { + // Invalid + _, err := DatabricksDialect{}.KindForDataType("DECIMAL(9", "") + assert.ErrorContains(t, err, "missing closing parenthesis") + } + { + // Valid + kd, err := DatabricksDialect{}.KindForDataType("DECIMAL(10, 2)", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.EDecimal.Kind, ExtendedDecimalDetails: typing.ToPtr(decimal.NewDetails(10, 2))}, kd) + } + } + { + // Array + kd, err := DatabricksDialect{}.KindForDataType("ARRAY", "") + assert.NoError(t, err) + assert.Equal(t, typing.Array, kd) + } + { + // String + kd, err := DatabricksDialect{}.KindForDataType("STRING", "") + assert.NoError(t, err) + assert.Equal(t, typing.String, kd) + } + { + // Binary + kd, err := DatabricksDialect{}.KindForDataType("BINARY", "") + assert.NoError(t, err) + assert.Equal(t, typing.String, kd) + } + { + // BigInt + kd, err := DatabricksDialect{}.KindForDataType("BIGINT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.BigIntegerKind)}, kd) + } + { + // Boolean + kd, err := DatabricksDialect{}.KindForDataType("BOOLEAN", "") + assert.NoError(t, err) + assert.Equal(t, typing.Boolean, kd) + } + { + // Date + kd, err := DatabricksDialect{}.KindForDataType("DATE", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.DateKindType}}, kd) + } + { + // Double + kd, err := DatabricksDialect{}.KindForDataType("DOUBLE", "") + assert.NoError(t, err) + assert.Equal(t, typing.Float, kd) + } + { + // Float + kd, err := DatabricksDialect{}.KindForDataType("FLOAT", "") + assert.NoError(t, err) + assert.Equal(t, typing.Float, kd) + } + { + // Integer + kd, err := DatabricksDialect{}.KindForDataType("INT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.IntegerKind)}, kd) + } + { + // Small Int + kd, err := DatabricksDialect{}.KindForDataType("SMALLINT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.SmallIntegerKind)}, kd) + } + { + // Timestamp + kd, err := DatabricksDialect{}.KindForDataType("TIMESTAMP", "") + assert.NoError(t, err) + assert.Equal(t, typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), kd) + } + { + // Timestamp NTZ + kd, err := DatabricksDialect{}.KindForDataType("TIMESTAMP_NTZ", "") + assert.NoError(t, err) + assert.Equal(t, typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), kd) + } + { + // Variant + kd, err := DatabricksDialect{}.KindForDataType("VARIANT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Struct.Kind}, kd) + } + { + // Object + kd, err := DatabricksDialect{}.KindForDataType("OBJECT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Struct.Kind}, kd) + } +} diff --git a/lib/typing/decimal/base.go b/lib/typing/decimal/base.go index a6ac4f3eb..060eda28a 100644 --- a/lib/typing/decimal/base.go +++ b/lib/typing/decimal/base.go @@ -41,3 +41,11 @@ func (d Details) toKind(maxPrecision int32, exceededKind string) string { return fmt.Sprintf("NUMERIC(%v, %v)", d.precision, d.scale) } + +func (d Details) toDecimalKind(maxPrecision int32, exceededKind string) string { + if d.precision > maxPrecision || d.precision == PrecisionNotSpecified { + return exceededKind + } + + return fmt.Sprintf("DECIMAL(%d, %d)", d.precision, d.scale) +} diff --git a/lib/typing/decimal/details.go b/lib/typing/decimal/details.go index f9678822e..dd5018b9b 100644 --- a/lib/typing/decimal/details.go +++ b/lib/typing/decimal/details.go @@ -52,6 +52,12 @@ func (d Details) SnowflakeKind() string { return d.toKind(MaxPrecisionBeforeString, "STRING") } +// DatabricksKind - is used to determine whether a NUMERIC data type should be a STRING or NUMERIC(p, s). +// Ref: https://docs.databricks.com/en/sql/language-manual/data-types/decimal-type.html +func (d Details) DatabricksKind() string { + return d.toDecimalKind(MaxPrecisionBeforeString, "STRING") +} + // MsSQLKind - Has the same limitation as Redshift // Spec: https://learn.microsoft.com/en-us/sql/t-sql/data-types/decimal-and-numeric-transact-sql?view=sql-server-ver16#arguments func (d Details) MsSQLKind() string { diff --git a/lib/typing/numeric.go b/lib/typing/numeric.go index af86164c5..cc2ba7087 100644 --- a/lib/typing/numeric.go +++ b/lib/typing/numeric.go @@ -7,6 +7,7 @@ import ( "github.com/artie-labs/transfer/lib/typing/decimal" ) +// TODO: This function should return an error func ParseNumeric(parts []string) KindDetails { if len(parts) == 0 || len(parts) > 2 { return Invalid