From 57af5766a5c0ecd4ed0749dca5449356461927c7 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Thu, 3 Oct 2024 10:27:11 -0700 Subject: [PATCH 1/3] [Typing] MSSQL - Use `VARCHAR(MAX)` instead of `TEXT` for decimals that have exceeded precision (#943) --- lib/typing/decimal/details.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/typing/decimal/details.go b/lib/typing/decimal/details.go index f48ee485b..f9678822e 100644 --- a/lib/typing/decimal/details.go +++ b/lib/typing/decimal/details.go @@ -55,7 +55,7 @@ func (d Details) SnowflakeKind() string { // MsSQLKind - Has the same limitation as Redshift // Spec: https://learn.microsoft.com/en-us/sql/t-sql/data-types/decimal-and-numeric-transact-sql?view=sql-server-ver16#arguments func (d Details) MsSQLKind() string { - return d.toKind(MaxPrecisionBeforeString, "TEXT") + return d.toKind(MaxPrecisionBeforeString, "VARCHAR(MAX)") } // RedshiftKind - is used to determine whether a NUMERIC data type should be a TEXT or NUMERIC(p, s). 
From c6a220672ee6947cba117e7f8d4d748259655ca8 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Thu, 3 Oct 2024 11:47:59 -0700 Subject: [PATCH 2/3] Supporting Databricks - Part Two (#939) --- clients/databricks/dialect/dialect.go | 114 ++++++++++++++ clients/databricks/dialect/dialect_test.go | 80 ++++++++++ clients/databricks/dialect/typing.go | 82 ++++++++++ clients/databricks/dialect/typing_test.go | 168 +++++++++++++++++++++ lib/typing/decimal/base.go | 8 + lib/typing/decimal/details.go | 6 + lib/typing/numeric.go | 1 + 7 files changed, 459 insertions(+) create mode 100644 clients/databricks/dialect/dialect.go create mode 100644 clients/databricks/dialect/dialect_test.go create mode 100644 clients/databricks/dialect/typing.go create mode 100644 clients/databricks/dialect/typing_test.go diff --git a/clients/databricks/dialect/dialect.go b/clients/databricks/dialect/dialect.go new file mode 100644 index 000000000..d895aff7a --- /dev/null +++ b/clients/databricks/dialect/dialect.go @@ -0,0 +1,114 @@ +package dialect + +import ( + "fmt" + "strings" + + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/columns" +) + +type DatabricksDialect struct{} + +func (DatabricksDialect) QuoteIdentifier(identifier string) string { + return fmt.Sprintf("`%s`", identifier) +} + +func (DatabricksDialect) EscapeStruct(value string) string { + panic("not implemented") +} + +func (DatabricksDialect) IsColumnAlreadyExistsErr(err error) bool { + return err != nil && strings.Contains(err.Error(), "[FIELDS_ALREADY_EXISTS]") +} + +func (DatabricksDialect) IsTableDoesNotExistErr(err error) bool { + return err != nil && strings.Contains(err.Error(), "[TABLE_OR_VIEW_NOT_FOUND]") +} + +func (DatabricksDialect) BuildCreateTableQuery(tableID sql.TableIdentifier, _ bool, colSQLParts []string) string { + // Databricks doesn't have a concept of temporary tables. 
+ return fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s)", tableID.FullyQualifiedName(), strings.Join(colSQLParts, ", ")) +} + +func (DatabricksDialect) BuildAlterColumnQuery(tableID sql.TableIdentifier, columnOp constants.ColumnOperation, colSQLPart string) string { + return fmt.Sprintf("ALTER TABLE %s %s COLUMN %s", tableID.FullyQualifiedName(), columnOp, colSQLPart) +} + +func (DatabricksDialect) BuildIsNotToastValueExpression(tableAlias constants.TableAlias, column columns.Column) string { + panic("not implemented") +} + +func (DatabricksDialect) BuildDedupeTableQuery(tableID sql.TableIdentifier, primaryKeys []string) string { + panic("not implemented") +} + +func (DatabricksDialect) BuildDedupeQueries(_, _ sql.TableIdentifier, _ []string, _ bool) []string { + panic("not implemented") +} + +func (d DatabricksDialect) BuildMergeQueries( + tableID sql.TableIdentifier, + subQuery string, + primaryKeys []columns.Column, + additionalEqualityStrings []string, + cols []columns.Column, + softDelete bool, + _ bool, +) ([]string, error) { + // TODO: Add tests. + + // Build the base equality condition for the MERGE query + equalitySQLParts := sql.BuildColumnComparisons(primaryKeys, constants.TargetAlias, constants.StagingAlias, sql.Equal, d) + if len(additionalEqualityStrings) > 0 { + equalitySQLParts = append(equalitySQLParts, additionalEqualityStrings...) 
+ } + + // Construct the base MERGE query + baseQuery := fmt.Sprintf(`MERGE INTO %s %s USING %s %s ON %s`, tableID.FullyQualifiedName(), constants.TargetAlias, subQuery, constants.StagingAlias, strings.Join(equalitySQLParts, " AND ")) + // Remove columns with only the delete marker, as they are handled separately + cols, err := columns.RemoveOnlySetDeleteColumnMarker(cols) + if err != nil { + return nil, err + } + + if softDelete { + // If softDelete is enabled, handle both update and soft-delete logic + return []string{baseQuery + fmt.Sprintf(` +WHEN MATCHED AND IFNULL(%s, false) = false THEN UPDATE SET %s +WHEN MATCHED AND IFNULL(%s, false) = true THEN UPDATE SET %s +WHEN NOT MATCHED THEN INSERT (%s) VALUES (%s);`, + sql.GetQuotedOnlySetDeleteColumnMarker(constants.StagingAlias, d), + sql.BuildColumnsUpdateFragment(cols, constants.StagingAlias, constants.TargetAlias, d), + sql.GetQuotedOnlySetDeleteColumnMarker(constants.StagingAlias, d), + sql.BuildColumnsUpdateFragment([]columns.Column{columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean)}, constants.StagingAlias, constants.TargetAlias, d), + strings.Join(sql.QuoteColumns(cols, d), ","), + strings.Join(sql.QuoteTableAliasColumns(constants.StagingAlias, cols, d), ","), + )}, nil + } + + // Remove the delete marker for hard-delete logic + cols, err = columns.RemoveDeleteColumnMarker(cols) + if err != nil { + return nil, err + } + + // Handle the case where hard-deletes are included + return []string{baseQuery + fmt.Sprintf(` +WHEN MATCHED AND %s THEN DELETE +WHEN MATCHED AND IFNULL(%s, false) = false THEN UPDATE SET %s +WHEN NOT MATCHED AND IFNULL(%s, false) = false THEN INSERT (%s) VALUES (%s);`, + sql.QuotedDeleteColumnMarker(constants.StagingAlias, d), + sql.QuotedDeleteColumnMarker(constants.StagingAlias, d), + sql.BuildColumnsUpdateFragment(cols, constants.StagingAlias, constants.TargetAlias, d), + sql.QuotedDeleteColumnMarker(constants.StagingAlias, d), + strings.Join(sql.QuoteColumns(cols, d), 
","), + strings.Join(sql.QuoteTableAliasColumns(constants.StagingAlias, cols, d), ","), + )}, nil +} + +func (d DatabricksDialect) GetDefaultValueStrategy() sql.DefaultValueStrategy { + return sql.Native +} diff --git a/clients/databricks/dialect/dialect_test.go b/clients/databricks/dialect/dialect_test.go new file mode 100644 index 000000000..6b2ea23ee --- /dev/null +++ b/clients/databricks/dialect/dialect_test.go @@ -0,0 +1,80 @@ +package dialect + +import ( + "fmt" + "testing" + + "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/mocks" + "github.com/stretchr/testify/assert" +) + +func TestDatabricksDialect_QuoteIdentifier(t *testing.T) { + dialect := DatabricksDialect{} + assert.Equal(t, "`foo`", dialect.QuoteIdentifier("foo")) + assert.Equal(t, "`FOO`", dialect.QuoteIdentifier("FOO")) +} + +func TestDatabricksDialect_IsColumnAlreadyExistsErr(t *testing.T) { + { + // No error + assert.False(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(nil)) + } + { + // Random error + assert.False(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(fmt.Errorf("random error"))) + } + { + // Valid + assert.True(t, DatabricksDialect{}.IsColumnAlreadyExistsErr(fmt.Errorf("[FIELDS_ALREADY_EXISTS] Cannot add column, because `first_name` already exists]"))) + } +} + +func TestDatabricksDialect_IsTableDoesNotExistErr(t *testing.T) { + { + // No error + assert.False(t, DatabricksDialect{}.IsTableDoesNotExistErr(nil)) + } + { + // Random error + assert.False(t, DatabricksDialect{}.IsTableDoesNotExistErr(fmt.Errorf("random error"))) + } + { + // Valid + assert.True(t, DatabricksDialect{}.IsTableDoesNotExistErr(fmt.Errorf("[TABLE_OR_VIEW_NOT_FOUND] Table or view not found: `foo`]"))) + } +} + +func TestDatabricksDialect_BuildCreateTableQuery(t *testing.T) { + fakeTableID := &mocks.FakeTableIdentifier{} + fakeTableID.FullyQualifiedNameReturns("{TABLE}") + + { + // Temporary + assert.Equal(t, + `CREATE TABLE IF NOT EXISTS {TABLE} ({PART_1}, 
{PART_2})`, + DatabricksDialect{}.BuildCreateTableQuery(fakeTableID, true, []string{"{PART_1}", "{PART_2}"}), + ) + } + { + // Not temporary + assert.Equal(t, + `CREATE TABLE IF NOT EXISTS {TABLE} ({PART_1}, {PART_2})`, + DatabricksDialect{}.BuildCreateTableQuery(fakeTableID, false, []string{"{PART_1}", "{PART_2}"}), + ) + } +} + +func TestDatabricksDialect_BuildAlterColumnQuery(t *testing.T) { + fakeTableID := &mocks.FakeTableIdentifier{} + fakeTableID.FullyQualifiedNameReturns("{TABLE}") + + { + // DROP + assert.Equal(t, "ALTER TABLE {TABLE} drop COLUMN {SQL_PART}", DatabricksDialect{}.BuildAlterColumnQuery(fakeTableID, constants.Delete, "{SQL_PART}")) + } + { + // Add + assert.Equal(t, "ALTER TABLE {TABLE} add COLUMN {SQL_PART} {DATA_TYPE}", DatabricksDialect{}.BuildAlterColumnQuery(fakeTableID, constants.Add, "{SQL_PART} {DATA_TYPE}")) + } +} diff --git a/clients/databricks/dialect/typing.go b/clients/databricks/dialect/typing.go new file mode 100644 index 000000000..f70a5c99e --- /dev/null +++ b/clients/databricks/dialect/typing.go @@ -0,0 +1,82 @@ +package dialect + +import ( + "fmt" + "strings" + + "github.com/artie-labs/transfer/lib/sql" + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/ext" +) + +func (DatabricksDialect) DataTypeForKind(kindDetails typing.KindDetails, _ bool) string { + switch kindDetails.Kind { + case typing.Float.Kind: + return "DOUBLE" + case typing.Integer.Kind: + return "BIGINT" + case typing.Struct.Kind: + return "VARIANT" + case typing.Array.Kind: + // Databricks requires arrays to be typed. As such, we're going to use an array of strings. 
+ return "ARRAY" + case typing.String.Kind: + return "STRING" + case typing.Boolean.Kind: + return "BOOLEAN" + case typing.ETime.Kind: + switch kindDetails.ExtendedTimeDetails.Type { + case ext.TimestampTzKindType: + // TIMESTAMP is the timezone-aware timestamp type in Databricks (TIMESTAMP_NTZ is the variant without a time zone). + return "TIMESTAMP" + case ext.DateKindType: + return "DATE" + case ext.TimeKindType: + return "STRING" + } + case typing.EDecimal.Kind: + return kindDetails.ExtendedDecimalDetails.DatabricksKind() + } + + return kindDetails.Kind +} + +func (DatabricksDialect) KindForDataType(rawType string, _ string) (typing.KindDetails, error) { + rawType = strings.ToLower(rawType) + if strings.HasPrefix(rawType, "decimal") { + _, parameters, err := sql.ParseDataTypeDefinition(rawType) + if err != nil { + return typing.Invalid, err + } + return typing.ParseNumeric(parameters), nil + } + + if strings.HasPrefix(rawType, "array") { + return typing.Array, nil + } + + switch rawType { + case "string", "binary": + return typing.String, nil + case "bigint": + return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.BigIntegerKind)}, nil + case "boolean": + return typing.Boolean, nil + case "date": + return typing.NewKindDetailsFromTemplate(typing.ETime, ext.DateKindType), nil + case "double", "float": + return typing.Float, nil + case "int": + return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.IntegerKind)}, nil + case "smallint", "tinyint": + return typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.SmallIntegerKind)}, nil + case "timestamp": + return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil + case "timestamp_ntz": + return typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), nil + case "variant", "object": + return typing.Struct, nil + } + + return typing.Invalid, 
fmt.Errorf("unsupported data type: %q", rawType) +} diff --git a/clients/databricks/dialect/typing_test.go b/clients/databricks/dialect/typing_test.go new file mode 100644 index 000000000..51b0aedf3 --- /dev/null +++ b/clients/databricks/dialect/typing_test.go @@ -0,0 +1,168 @@ +package dialect + +import ( + "testing" + + "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/decimal" + "github.com/artie-labs/transfer/lib/typing/ext" + "github.com/stretchr/testify/assert" +) + +func TestDatabricksDialect_DataTypeForKind(t *testing.T) { + { + // Float + assert.Equal(t, "DOUBLE", DatabricksDialect{}.DataTypeForKind(typing.Float, false)) + } + { + // Integer + assert.Equal(t, "BIGINT", DatabricksDialect{}.DataTypeForKind(typing.Integer, false)) + } + { + // Variant + assert.Equal(t, "VARIANT", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.Struct.Kind}, false)) + } + { + // Array + assert.Equal(t, "ARRAY", DatabricksDialect{}.DataTypeForKind(typing.Array, false)) + } + { + // String + assert.Equal(t, "STRING", DatabricksDialect{}.DataTypeForKind(typing.String, false)) + } + { + // Boolean + assert.Equal(t, "BOOLEAN", DatabricksDialect{}.DataTypeForKind(typing.Boolean, false)) + } + { + // Times + { + // Date + assert.Equal(t, "DATE", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.DateKindType}}, false)) + } + { + // Timestamp + assert.Equal(t, "TIMESTAMP", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.TimestampTzKindType}}, false)) + } + { + // Timestamp (w/o timezone) + assert.Equal(t, "TIMESTAMP", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.TimestampTzKindType}}, false)) + } + { + // Time + assert.Equal(t, "STRING", 
DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.TimeKindType}}, false)) + } + } + { + // Decimals + { + // Below 38 precision + assert.Equal(t, "DECIMAL(10, 2)", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.EDecimal.Kind, ExtendedDecimalDetails: typing.ToPtr(decimal.NewDetails(10, 2))}, false)) + } + { + // Above 38 precision + assert.Equal(t, "STRING", DatabricksDialect{}.DataTypeForKind(typing.KindDetails{Kind: typing.EDecimal.Kind, ExtendedDecimalDetails: typing.ToPtr(decimal.NewDetails(40, 2))}, false)) + } + } +} + +func TestDatabricksDialect_KindForDataType(t *testing.T) { + { + // Decimal + { + // Invalid + _, err := DatabricksDialect{}.KindForDataType("DECIMAL(9", "") + assert.ErrorContains(t, err, "missing closing parenthesis") + } + { + // Valid + kd, err := DatabricksDialect{}.KindForDataType("DECIMAL(10, 2)", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.EDecimal.Kind, ExtendedDecimalDetails: typing.ToPtr(decimal.NewDetails(10, 2))}, kd) + } + } + { + // Array + kd, err := DatabricksDialect{}.KindForDataType("ARRAY", "") + assert.NoError(t, err) + assert.Equal(t, typing.Array, kd) + } + { + // String + kd, err := DatabricksDialect{}.KindForDataType("STRING", "") + assert.NoError(t, err) + assert.Equal(t, typing.String, kd) + } + { + // Binary + kd, err := DatabricksDialect{}.KindForDataType("BINARY", "") + assert.NoError(t, err) + assert.Equal(t, typing.String, kd) + } + { + // BigInt + kd, err := DatabricksDialect{}.KindForDataType("BIGINT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.BigIntegerKind)}, kd) + } + { + // Boolean + kd, err := DatabricksDialect{}.KindForDataType("BOOLEAN", "") + assert.NoError(t, err) + assert.Equal(t, typing.Boolean, kd) + } + { + // Date + kd, err := DatabricksDialect{}.KindForDataType("DATE", "") + 
assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.ETime.Kind, ExtendedTimeDetails: &ext.NestedKind{Type: ext.DateKindType}}, kd) + } + { + // Double + kd, err := DatabricksDialect{}.KindForDataType("DOUBLE", "") + assert.NoError(t, err) + assert.Equal(t, typing.Float, kd) + } + { + // Float + kd, err := DatabricksDialect{}.KindForDataType("FLOAT", "") + assert.NoError(t, err) + assert.Equal(t, typing.Float, kd) + } + { + // Integer + kd, err := DatabricksDialect{}.KindForDataType("INT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.IntegerKind)}, kd) + } + { + // Small Int + kd, err := DatabricksDialect{}.KindForDataType("SMALLINT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Integer.Kind, OptionalIntegerKind: typing.ToPtr(typing.SmallIntegerKind)}, kd) + } + { + // Timestamp + kd, err := DatabricksDialect{}.KindForDataType("TIMESTAMP", "") + assert.NoError(t, err) + assert.Equal(t, typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), kd) + } + { + // Timestamp NTZ + kd, err := DatabricksDialect{}.KindForDataType("TIMESTAMP_NTZ", "") + assert.NoError(t, err) + assert.Equal(t, typing.NewKindDetailsFromTemplate(typing.ETime, ext.TimestampTzKindType), kd) + } + { + // Variant + kd, err := DatabricksDialect{}.KindForDataType("VARIANT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Struct.Kind}, kd) + } + { + // Object + kd, err := DatabricksDialect{}.KindForDataType("OBJECT", "") + assert.NoError(t, err) + assert.Equal(t, typing.KindDetails{Kind: typing.Struct.Kind}, kd) + } +} diff --git a/lib/typing/decimal/base.go b/lib/typing/decimal/base.go index a6ac4f3eb..060eda28a 100644 --- a/lib/typing/decimal/base.go +++ b/lib/typing/decimal/base.go @@ -41,3 +41,11 @@ func (d Details) toKind(maxPrecision int32, exceededKind string) string { return fmt.Sprintf("NUMERIC(%v, %v)", 
d.precision, d.scale) } + +func (d Details) toDecimalKind(maxPrecision int32, exceededKind string) string { + if d.precision > maxPrecision || d.precision == PrecisionNotSpecified { + return exceededKind + } + + return fmt.Sprintf("DECIMAL(%d, %d)", d.precision, d.scale) +} diff --git a/lib/typing/decimal/details.go b/lib/typing/decimal/details.go index f9678822e..dd5018b9b 100644 --- a/lib/typing/decimal/details.go +++ b/lib/typing/decimal/details.go @@ -52,6 +52,12 @@ func (d Details) SnowflakeKind() string { return d.toKind(MaxPrecisionBeforeString, "STRING") } +// DatabricksKind - is used to determine whether a NUMERIC data type should be a STRING or DECIMAL(p, s). +// Ref: https://docs.databricks.com/en/sql/language-manual/data-types/decimal-type.html +func (d Details) DatabricksKind() string { + return d.toDecimalKind(MaxPrecisionBeforeString, "STRING") +} + // MsSQLKind - Has the same limitation as Redshift // Spec: https://learn.microsoft.com/en-us/sql/t-sql/data-types/decimal-and-numeric-transact-sql?view=sql-server-ver16#arguments func (d Details) MsSQLKind() string { diff --git a/lib/typing/numeric.go b/lib/typing/numeric.go index af86164c5..cc2ba7087 100644 --- a/lib/typing/numeric.go +++ b/lib/typing/numeric.go @@ -7,6 +7,7 @@ import ( "github.com/artie-labs/transfer/lib/typing/decimal" ) +// TODO: This function should return an error func ParseNumeric(parts []string) KindDetails { if len(parts) == 0 || len(parts) > 2 { return Invalid From f498888e9cd5b0c12d9b54e8e6800c3d8e6a36b0 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Thu, 3 Oct 2024 11:50:51 -0700 Subject: [PATCH 3/3] Supporting Databricks - Part Three (#940) --- clients/databricks/tableid.go | 48 ++++++++++++++++++++++++++++++ clients/databricks/tableid_test.go | 33 ++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 clients/databricks/tableid.go create mode 100644 clients/databricks/tableid_test.go diff --git a/clients/databricks/tableid.go 
b/clients/databricks/tableid.go new file mode 100644 index 000000000..df1313063 --- /dev/null +++ b/clients/databricks/tableid.go @@ -0,0 +1,48 @@ +package databricks + +import ( + "fmt" + + "github.com/artie-labs/transfer/clients/databricks/dialect" + "github.com/artie-labs/transfer/lib/sql" +) + +var _dialect = dialect.DatabricksDialect{} + +type TableIdentifier struct { + database string + schema string + table string +} + +func NewTableIdentifier(database, schema, table string) TableIdentifier { + return TableIdentifier{ + database: database, + schema: schema, + table: table, + } +} + +func (ti TableIdentifier) Database() string { + return ti.database +} + +func (ti TableIdentifier) Schema() string { + return ti.schema +} + +func (ti TableIdentifier) EscapedTable() string { + return _dialect.QuoteIdentifier(ti.table) +} + +func (ti TableIdentifier) Table() string { + return ti.table +} + +func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { + return NewTableIdentifier(ti.database, ti.schema, table) +} + +func (ti TableIdentifier) FullyQualifiedName() string { + return fmt.Sprintf("%s.%s.%s", _dialect.QuoteIdentifier(ti.database), _dialect.QuoteIdentifier(ti.schema), ti.EscapedTable()) +} diff --git a/clients/databricks/tableid_test.go b/clients/databricks/tableid_test.go new file mode 100644 index 000000000..13d756c2d --- /dev/null +++ b/clients/databricks/tableid_test.go @@ -0,0 +1,33 @@ +package databricks + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTableIdentifier_WithTable(t *testing.T) { + tableID := NewTableIdentifier("database", "schema", "foo") + tableID2 := tableID.WithTable("bar") + typedTableID2, ok := tableID2.(TableIdentifier) + assert.True(t, ok) + assert.Equal(t, "database", typedTableID2.Database()) + assert.Equal(t, "schema", typedTableID2.Schema()) + assert.Equal(t, "bar", tableID2.Table()) +} + +func TestTableIdentifier_FullyQualifiedName(t *testing.T) { + // Table name that is not a 
reserved word: + assert.Equal(t, "`database`.`schema`.`foo`", NewTableIdentifier("database", "schema", "foo").FullyQualifiedName()) + + // Table name that is a reserved word: + assert.Equal(t, "`database`.`schema`.`table`", NewTableIdentifier("database", "schema", "table").FullyQualifiedName()) +} + +func TestTableIdentifier_EscapedTable(t *testing.T) { + // Table name that is not a reserved word: + assert.Equal(t, "`foo`", NewTableIdentifier("database", "schema", "foo").EscapedTable()) + + // Table name that is a reserved word: + assert.Equal(t, "`table`", NewTableIdentifier("database", "schema", "table").EscapedTable()) +}