Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[snowflake] Pass TableIdentifier to addPrefixToTableName #480

Merged
merged 9 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clients/shared/sweep.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type GetQueryFunc func(dbAndSchemaPair kafkalib.DatabaseSchemaPair) (string, []a

func Sweep(dwh destination.DataWarehouse, topicConfigs []*kafkalib.TopicConfig, getQueryFunc GetQueryFunc) error {
slog.Info("Looking to see if there are any dangling artie temporary tables to delete...")
// TODO: Rewrite this to use [DataWarehouse.IdentifierFor]
dbAndSchemaPairs := kafkalib.GetUniqueDatabaseAndSchema(topicConfigs)
for _, dbAndSchemaPair := range dbAndSchemaPairs {
query, args := getQueryFunc(dbAndSchemaPair)
Expand All @@ -29,6 +30,7 @@ func Sweep(dwh destination.DataWarehouse, topicConfigs []*kafkalib.TopicConfig,
}

if ddl.ShouldDeleteFromName(tableName) {
// TODO: Rewrite this to pass a [types.TableIdentifiers] to [DropTemporaryTable]
err = ddl.DropTemporaryTable(dwh, fmt.Sprintf("%s.%s.%s", dbAndSchemaPair.Database, tableSchema, tableName), true)
if err != nil {
return err
Expand Down
9 changes: 4 additions & 5 deletions clients/snowflake/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ func castColValStaging(colVal any, colKind columns.Column, additionalDateFmts []
}

func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error {
tempTableName := tempTableID.FullyQualifiedName()

if createTempTable {
tempAlterTableArgs := ddl.AlterTableArgs{
Dwh: s,
Expand Down Expand Up @@ -81,17 +79,18 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
}()

// Upload the CSV file to Snowflake
if _, err = s.Exec(fmt.Sprintf("PUT file://%s @%s AUTO_COMPRESS=TRUE", fp, addPrefixToTableName(tempTableName, "%"))); err != nil {
if _, err = s.Exec(fmt.Sprintf("PUT file://%s @%s AUTO_COMPRESS=TRUE", fp, addPrefixToTableName(tempTableID, "%"))); err != nil {
return fmt.Errorf("failed to run PUT for temporary table: %w", err)
}

// COPY the CSV file (in Snowflake) into a table
copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)",
tempTableName, strings.Join(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), &sql.NameArgs{
tempTableID.FullyQualifiedName(),
strings.Join(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), &sql.NameArgs{
Escape: true,
DestKind: s.Label(),
}), ","),
escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableName, "%"))
escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%"))

if additionalSettings.AdditionalCopyClause != "" {
copyCommand += " " + additionalSettings.AdditionalCopyClause
Expand Down
2 changes: 1 addition & 1 deletion clients/snowflake/staging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (s *SnowflakeTestSuite) TestPrepareTempTable() {
`CREATE TABLE IF NOT EXISTS %s (user_id string,first_name string,last_name string,dusty string) STAGE_COPY_OPTIONS = ( PURGE = TRUE ) STAGE_FILE_FORMAT = ( TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE)`, tempTableName)
containsPrefix := strings.HasPrefix(createQuery, prefixQuery)
assert.True(s.T(), containsPrefix, fmt.Sprintf("createQuery:%v, prefixQuery:%s", createQuery, prefixQuery))
resourceName := addPrefixToTableName(tempTableName, "%")
resourceName := addPrefixToTableName(tempTableID, "%")
// Second call is a PUT
putQuery, _ := s.fakeStageStore.ExecArgsForCall(1)
assert.Contains(s.T(), putQuery, "PUT file://", putQuery)
Expand Down
19 changes: 6 additions & 13 deletions clients/snowflake/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,15 @@ import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/typing/columns"

"github.com/artie-labs/transfer/lib/destination/types"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/columns"
)

// addPrefixToTableName will take the fully qualified table name and add a prefix in front of the table
// This is necessary for `PUT` commands. The fq name looks like <namespace>.<tableName>
// Namespace may contain both database and schema.
func addPrefixToTableName(fqTableName string, prefix string) string {
tableParts := strings.Split(fqTableName, ".")
if len(tableParts) == 1 {
return prefix + fqTableName
}

return fmt.Sprintf("%s.%s%s",
strings.Join(tableParts[0:len(tableParts)-1], "."), prefix, tableParts[len(tableParts)-1])
// addPrefixToTableName will take a [types.TableIdentifier] and add a prefix in front of the table name.
// This is necessary for `PUT` commands.
func addPrefixToTableName(tableID types.TableIdentifier, prefix string) string {
return tableID.WithTable(prefix + tableID.Table()).FullyQualifiedName()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much nicer. Nice one

}

// escapeColumns will take columns, filter out invalid ones, escape them, and return them in the order received.
Expand Down
18 changes: 9 additions & 9 deletions clients/snowflake/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,35 @@ func TestAddPrefixToTableName(t *testing.T) {
const prefix = "%"
type _testCase struct {
name string
fqTableName string
tableID TableIdentifier
expectedFqTableName string
}

testCases := []_testCase{
{
name: "happy path",
fqTableName: "database.schema.tableName",
tableID: NewTableIdentifier("database", "schema", "tableName", true),
expectedFqTableName: "database.schema.%tableName",
},
{
name: "tableName only",
fqTableName: "orders",
expectedFqTableName: "%orders",
tableID: NewTableIdentifier("", "", "orders", true),
expectedFqTableName: "..%orders",
},
{
name: "schema and tableName only",
fqTableName: "public.orders",
expectedFqTableName: "public.%orders",
tableID: NewTableIdentifier("", "public", "orders", true),
expectedFqTableName: ".public.%orders",
},
{
name: "db and tableName only",
fqTableName: "db.tableName",
expectedFqTableName: "db.%tableName",
tableID: NewTableIdentifier("db", "", "tableName", true),
expectedFqTableName: "db..%tableName",
},
}

for _, testCase := range testCases {
assert.Equal(t, addPrefixToTableName(testCase.fqTableName, prefix), testCase.expectedFqTableName, testCase.name)
assert.Equal(t, testCase.expectedFqTableName, addPrefixToTableName(testCase.tableID, prefix), testCase.name)
}
}

Expand Down