From f498888e9cd5b0c12d9b54e8e6800c3d8e6a36b0 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Thu, 3 Oct 2024 11:50:51 -0700 Subject: [PATCH] Supporting Databricks - Part Three (#940) --- clients/databricks/tableid.go | 48 ++++++++++++++++++++++++++++++ clients/databricks/tableid_test.go | 33 ++++++++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 clients/databricks/tableid.go create mode 100644 clients/databricks/tableid_test.go diff --git a/clients/databricks/tableid.go b/clients/databricks/tableid.go new file mode 100644 index 000000000..df1313063 --- /dev/null +++ b/clients/databricks/tableid.go @@ -0,0 +1,48 @@ +package databricks + +import ( + "fmt" + + "github.com/artie-labs/transfer/clients/databricks/dialect" + "github.com/artie-labs/transfer/lib/sql" +) + +var _dialect = dialect.DatabricksDialect{} + +type TableIdentifier struct { + database string + schema string + table string +} + +func NewTableIdentifier(database, schema, table string) TableIdentifier { + return TableIdentifier{ + database: database, + schema: schema, + table: table, + } +} + +func (ti TableIdentifier) Database() string { + return ti.database +} + +func (ti TableIdentifier) Schema() string { + return ti.schema +} + +func (ti TableIdentifier) EscapedTable() string { + return _dialect.QuoteIdentifier(ti.table) +} + +func (ti TableIdentifier) Table() string { + return ti.table +} + +func (ti TableIdentifier) WithTable(table string) sql.TableIdentifier { + return NewTableIdentifier(ti.database, ti.schema, table) +} + +func (ti TableIdentifier) FullyQualifiedName() string { + return fmt.Sprintf("%s.%s.%s", _dialect.QuoteIdentifier(ti.database), _dialect.QuoteIdentifier(ti.schema), ti.EscapedTable()) +} diff --git a/clients/databricks/tableid_test.go b/clients/databricks/tableid_test.go new file mode 100644 index 000000000..13d756c2d --- /dev/null +++ b/clients/databricks/tableid_test.go @@ -0,0 +1,33 @@ +package databricks + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTableIdentifier_WithTable(t *testing.T) { + tableID := NewTableIdentifier("database", "schema", "foo") + tableID2 := tableID.WithTable("bar") + typedTableID2, ok := tableID2.(TableIdentifier) + assert.True(t, ok) + assert.Equal(t, "database", typedTableID2.Database()) + assert.Equal(t, "schema", typedTableID2.Schema()) + assert.Equal(t, "bar", tableID2.Table()) +} + +func TestTableIdentifier_FullyQualifiedName(t *testing.T) { + // Table name that is not a reserved word: + assert.Equal(t, "`database`.`schema`.`foo`", NewTableIdentifier("database", "schema", "foo").FullyQualifiedName()) + + // Table name that is a reserved word: + assert.Equal(t, "`database`.`schema`.`table`", NewTableIdentifier("database", "schema", "table").FullyQualifiedName()) +} + +func TestTableIdentifier_EscapedTable(t *testing.T) { + // Table name that is not a reserved word: + assert.Equal(t, "`foo`", NewTableIdentifier("database", "schema", "foo").EscapedTable()) + + // Table name that is a reserved word: + assert.Equal(t, "`table`", NewTableIdentifier("database", "schema", "table").EscapedTable()) +}