Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed Apr 22, 2024
1 parent ef9814c commit 3825947
Show file tree
Hide file tree
Showing 15 changed files with 57 additions and 59 deletions.
7 changes: 5 additions & 2 deletions clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
}

func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier {
return NewTableIdentifier(s.config.BigQuery.ProjectID, topicConfig.Database, table, s.ShouldUppercaseEscapedNames())
return NewTableIdentifier(s.config.BigQuery.ProjectID, topicConfig.Database, table)
}

func (s *Store) GetTableConfig(tableData *optimization.TableData) (*types.DwhTableConfig, error) {
Expand Down Expand Up @@ -112,7 +112,10 @@ func (s *Store) Label() constants.DestinationKind {
}

func (s *Store) ShouldUppercaseEscapedNames() bool {
return s.config.SharedDestinationConfig.UppercaseEscapedNames
if s.config.SharedDestinationConfig.UppercaseEscapedNames {
panic("UppercaseEscapedNames is not supported for BigQuery")
}
return false
}

func (s *Store) GetClient(ctx context.Context) *bigquery.Client {
Expand Down
2 changes: 1 addition & 1 deletion clients/bigquery/merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func (b *BigQueryTestSuite) TestBackfillColumn() {
tableID := NewTableIdentifier("db", "public", "tableName", false)
tableID := NewTableIdentifier("db", "public", "tableName")
type _testCase struct {
name string
col columns.Column
Expand Down
20 changes: 9 additions & 11 deletions clients/bigquery/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,16 @@ import (
)

type TableIdentifier struct {
projectID string
dataset string
table string
uppercaseEscapedNames bool
projectID string
dataset string
table string
}

func NewTableIdentifier(projectID, dataset, table string, uppercaseEscapedNames bool) TableIdentifier {
func NewTableIdentifier(projectID, dataset, table string) TableIdentifier {
return TableIdentifier{
projectID: projectID,
dataset: dataset,
table: table,
uppercaseEscapedNames: uppercaseEscapedNames,
projectID: projectID,
dataset: dataset,
table: table,
}
}

Expand All @@ -37,7 +35,7 @@ func (ti TableIdentifier) Table() string {
}

func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
return NewTableIdentifier(ti.projectID, ti.dataset, table, ti.uppercaseEscapedNames)
return NewTableIdentifier(ti.projectID, ti.dataset, table)
}

func (ti TableIdentifier) FullyQualifiedName() string {
Expand All @@ -47,6 +45,6 @@ func (ti TableIdentifier) FullyQualifiedName() string {
"`%s`.`%s`.%s",
ti.projectID,
ti.dataset,
sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}),
sql.EscapeNameIfNecessary(ti.table, false, &sql.NameArgs{Escape: true, DestKind: constants.BigQuery}),
)
}
9 changes: 3 additions & 6 deletions clients/bigquery/tableid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,19 @@ import (
)

func TestTableIdentifier_WithTable(t *testing.T) {
tableID := NewTableIdentifier("project", "dataset", "foo", true)
tableID := NewTableIdentifier("project", "dataset", "foo")
tableID2 := tableID.WithTable("bar")
typedTableID2, ok := tableID2.(TableIdentifier)
assert.True(t, ok)
assert.Equal(t, "project", typedTableID2.ProjectID())
assert.Equal(t, "dataset", typedTableID2.Dataset())
assert.Equal(t, "bar", tableID2.Table())
assert.True(t, typedTableID2.uppercaseEscapedNames)
}

func TestTableIdentifier_FullyQualifiedName(t *testing.T) {
// Table name that does not need escaping:
assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo", false).FullyQualifiedName())
assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo", true).FullyQualifiedName())
assert.Equal(t, "`project`.`dataset`.foo", NewTableIdentifier("project", "dataset", "foo").FullyQualifiedName())

// Table name that needs escaping:
assert.Equal(t, "`project`.`dataset`.`table`", NewTableIdentifier("project", "dataset", "table", false).FullyQualifiedName())
assert.Equal(t, "`project`.`dataset`.`TABLE`", NewTableIdentifier("project", "dataset", "table", true).FullyQualifiedName())
assert.Equal(t, "`project`.`dataset`.`table`", NewTableIdentifier("project", "dataset", "table").FullyQualifiedName())
}
7 changes: 5 additions & 2 deletions clients/mssql/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ func (s *Store) Label() constants.DestinationKind {
}

func (s *Store) ShouldUppercaseEscapedNames() bool {
return s.config.SharedDestinationConfig.UppercaseEscapedNames
if s.config.SharedDestinationConfig.UppercaseEscapedNames {
panic("UppercaseEscapedNames is not supported for MSSQL")
}
return false
}

func (s *Store) Merge(tableData *optimization.TableData) error {
Expand All @@ -49,7 +52,7 @@ func (s *Store) Append(tableData *optimization.TableData) error {

// specificIdentifierFor returns a MS SQL [TableIdentifier] for a [TopicConfig] + table name.
func (s *Store) specificIdentifierFor(topicConfig kafkalib.TopicConfig, table string) TableIdentifier {
return NewTableIdentifier(getSchema(topicConfig.Schema), table, s.ShouldUppercaseEscapedNames())
return NewTableIdentifier(getSchema(topicConfig.Schema), table)
}

// IdentifierFor returns a generic [types.TableIdentifier] interface for a [TopicConfig] + table name.
Expand Down
13 changes: 6 additions & 7 deletions clients/mssql/tableid.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ import (
)

type TableIdentifier struct {
schema string
table string
uppercaseEscapedNames bool
schema string
table string
}

func NewTableIdentifier(schema, table string, uppercaseEscapedNames bool) TableIdentifier {
return TableIdentifier{schema: schema, table: table, uppercaseEscapedNames: uppercaseEscapedNames}
func NewTableIdentifier(schema, table string) TableIdentifier {
return TableIdentifier{schema: schema, table: table}
}

func (ti TableIdentifier) Schema() string {
Expand All @@ -27,13 +26,13 @@ func (ti TableIdentifier) Table() string {
}

func (ti TableIdentifier) WithTable(table string) types.TableIdentifier {
return NewTableIdentifier(ti.schema, table, ti.uppercaseEscapedNames)
return NewTableIdentifier(ti.schema, table)
}

func (ti TableIdentifier) FullyQualifiedName() string {
return fmt.Sprintf(
"%s.%s",
ti.schema,
sql.EscapeNameIfNecessary(ti.table, ti.uppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}),
sql.EscapeNameIfNecessary(ti.table, false, &sql.NameArgs{Escape: true, DestKind: constants.MSSQL}),
)
}
9 changes: 3 additions & 6 deletions clients/mssql/tableid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,18 @@ import (
)

func TestTableIdentifier_WithTable(t *testing.T) {
tableID := NewTableIdentifier("schema", "foo", true)
tableID := NewTableIdentifier("schema", "foo")
tableID2 := tableID.WithTable("bar")
typedTableID2, ok := tableID2.(TableIdentifier)
assert.True(t, ok)
assert.Equal(t, "schema", typedTableID2.Schema())
assert.Equal(t, "bar", tableID2.Table())
assert.True(t, typedTableID2.uppercaseEscapedNames)
}

func TestTableIdentifier_FullyQualifiedName(t *testing.T) {
// Table name that does not need escaping:
assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", false).FullyQualifiedName())
assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo", true).FullyQualifiedName())
assert.Equal(t, "schema.foo", NewTableIdentifier("schema", "foo").FullyQualifiedName())

// Table name that needs escaping:
assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table", false).FullyQualifiedName())
assert.Equal(t, `schema."TABLE"`, NewTableIdentifier("schema", "table", true).FullyQualifiedName())
assert.Equal(t, `schema."table"`, NewTableIdentifier("schema", "table").FullyQualifiedName())
}
3 changes: 2 additions & 1 deletion clients/shared/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/artie-labs/transfer/lib/destination/ddl"
"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/optimization"
"github.com/artie-labs/transfer/lib/ptr"
"github.com/artie-labs/transfer/lib/typing/columns"
)

Expand All @@ -35,7 +36,7 @@ func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, cf
CreateTable: tableConfig.CreateTable(),
ColumnOp: constants.Add,
CdcTime: tableData.LatestCDCTs,
UppercaseEscNames: &cfg.SharedDestinationConfig.UppercaseEscapedNames,
UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()),
Mode: tableData.Mode(),
}

Expand Down
8 changes: 4 additions & 4 deletions clients/shared/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
CreateTable: tableConfig.CreateTable(),
ColumnOp: constants.Add,
CdcTime: tableData.LatestCDCTs,
UppercaseEscNames: &cfg.SharedDestinationConfig.UppercaseEscapedNames,
UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()),
Mode: tableData.Mode(),
}

Expand All @@ -61,7 +61,7 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
ColumnOp: constants.Delete,
ContainOtherOperations: tableData.ContainOtherOperations(),
CdcTime: tableData.LatestCDCTs,
UppercaseEscNames: &cfg.SharedDestinationConfig.UppercaseEscapedNames,
UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()),
Mode: tableData.Mode(),
}

Expand Down Expand Up @@ -123,11 +123,11 @@ func Merge(dwh destination.DataWarehouse, tableData *optimization.TableData, cfg
TableID: tableID,
SubQuery: subQuery,
IdempotentKey: tableData.TopicConfig().IdempotentKey,
PrimaryKeys: tableData.PrimaryKeys(cfg.SharedDestinationConfig.UppercaseEscapedNames, &sql.NameArgs{Escape: true, DestKind: dwh.Label()}),
PrimaryKeys: tableData.PrimaryKeys(dwh.ShouldUppercaseEscapedNames(), &sql.NameArgs{Escape: true, DestKind: dwh.Label()}),
Columns: tableData.ReadOnlyInMemoryCols(),
SoftDelete: tableData.TopicConfig().SoftDelete,
DestKind: dwh.Label(),
UppercaseEscNames: &cfg.SharedDestinationConfig.UppercaseEscapedNames,
UppercaseEscNames: ptr.ToBool(dwh.ShouldUppercaseEscapedNames()),
ContainsHardDeletes: ptr.ToBool(tableData.ContainsHardDeletes()),
}

Expand Down
2 changes: 1 addition & 1 deletion clients/shared/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func BackfillColumn(cfg config.Config, dwh destination.DataWarehouse, column col
return fmt.Errorf("failed to escape default value: %w", err)
}

uppercaseEscNames := cfg.SharedDestinationConfig.UppercaseEscapedNames
uppercaseEscNames := dwh.ShouldUppercaseEscapedNames()
escapedCol := column.Name(uppercaseEscNames, &sql.NameArgs{Escape: true, DestKind: dwh.Label()})

// TODO: This is added because `default` is not technically a column that requires escaping, but it is required when it's in the where clause.
Expand Down
4 changes: 2 additions & 2 deletions lib/destination/ddl/ddl_bq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (d *DDLTestSuite) TestAlterTableDropColumnsBigQuery() {
}

func (d *DDLTestSuite) TestAlterTableAddColumns() {
tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols", true)
tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols")
fqName := tableID.FullyQualifiedName()
ts := time.Now()
existingColNameToKindDetailsMap := map[string]typing.KindDetails{
Expand Down Expand Up @@ -176,7 +176,7 @@ func (d *DDLTestSuite) TestAlterTableAddColumns() {
}

func (d *DDLTestSuite) TestAlterTableAddColumnsSomeAlreadyExist() {
tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols", true)
tableID := bigquery.NewTableIdentifier("", "mock_dataset", "add_cols")
fqName := tableID.FullyQualifiedName()
ts := time.Now()
existingColNameToKindDetailsMap := map[string]typing.KindDetails{
Expand Down
2 changes: 1 addition & 1 deletion lib/destination/ddl/ddl_create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
)

func (d *DDLTestSuite) Test_CreateTable() {
bqTableID := bigquery.NewTableIdentifier("", "mock_dataset", "mock_table", true)
bqTableID := bigquery.NewTableIdentifier("", "mock_dataset", "mock_table")
d.bigQueryStore.GetConfigMap().AddTableToConfig(bqTableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))

snowflakeTableID := snowflake.NewTableIdentifier("", "mock_dataset", "mock_table", true)
Expand Down
2 changes: 1 addition & 1 deletion lib/destination/ddl/ddl_temp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (d *DDLTestSuite) TestCreateTemporaryTable() {
}
{
// BigQuery
tableID := bigquery.NewTableIdentifier("db", "schema", "tempTableName", false)
tableID := bigquery.NewTableIdentifier("db", "schema", "tempTableName")
d.bigQueryStore.GetConfigMap().AddTableToConfig(tableID, types.NewDwhTableConfig(&columns.Columns{}, nil, true, true))
bqTc := d.bigQueryStore.GetConfigMap().TableConfig(tableID)
args := ddl.AlterTableArgs{
Expand Down
26 changes: 13 additions & 13 deletions lib/sql/escape.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,7 @@ type NameArgs struct {
// symbolsToEscape are additional keywords that we need to escape
var symbolsToEscape = []string{":"}

func EscapeNameIfNecessary(name string, uppercaseEscNames bool, args *NameArgs) string {
if args == nil || !args.Escape {
return name
}

if needsEscaping(name, args.DestKind) {
return escapeName(name, uppercaseEscNames, args.DestKind)
}

return name
}

func needsEscaping(name string, destKind constants.DestinationKind) bool {
func needsEscaping(name string, uppercaseEscNames bool, destKind constants.DestinationKind) bool {
var reservedKeywords []string
if destKind == constants.Redshift {
reservedKeywords = constants.RedshiftReservedKeywords
Expand Down Expand Up @@ -61,6 +49,18 @@ func needsEscaping(name string, destKind constants.DestinationKind) bool {
return needsEscaping
}

func EscapeNameIfNecessary(name string, uppercaseEscNames bool, args *NameArgs) string {
if args == nil || !args.Escape {
return name
}

if needsEscaping(name, uppercaseEscNames, args.DestKind) {
return escapeName(name, uppercaseEscNames, args.DestKind)
}

return name
}

func escapeName(name string, uppercaseEscNames bool, destKind constants.DestinationKind) string {
if uppercaseEscNames {
name = strings.ToUpper(name)
Expand Down
2 changes: 1 addition & 1 deletion lib/sql/escape_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert"
)

func TestEscapeNameIfNecessary(t *testing.T) {
func TestEscapeName(t *testing.T) {
type _testCase struct {
name string
nameToEscape string
Expand Down

0 comments on commit 3825947

Please sign in to comment.