Merge branch 'master' into nv/always-uppercase-escaped-snowflake-names
nathan-artie committed May 2, 2024
2 parents 3ca6917 + 1306702 commit 9202df2
Showing 18 changed files with 130 additions and 149 deletions.
12 changes: 0 additions & 12 deletions clients/bigquery/append.go

This file was deleted.

4 changes: 4 additions & 0 deletions clients/bigquery/bigquery.go
@@ -41,6 +41,10 @@ type Store struct {
 	db.Store
 }
 
+func (s *Store) Append(tableData *optimization.TableData) error {
+	return shared.Append(s, tableData, types.AdditionalSettings{})
+}
+
 func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
 	if createTempTable {
 		tempAlterTableArgs := ddl.AlterTableArgs{
3 changes: 1 addition & 2 deletions clients/mssql/store.go
@@ -44,8 +44,7 @@ func (s *Store) Merge(tableData *optimization.TableData) error {
 }
 
 func (s *Store) Append(tableData *optimization.TableData) error {
-	tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name())
-	return shared.Append(s, tableData, types.AppendOpts{TempTableID: tableID})
+	return shared.Append(s, tableData, types.AdditionalSettings{})
 }
 
 // specificIdentifierFor returns a MS SQL [TableIdentifier] for a [TopicConfig] + table name.
13 changes: 13 additions & 0 deletions clients/redshift/redshift.go
@@ -30,6 +30,19 @@ type Store struct {
 	db.Store
 }
 
+func (s *Store) Append(tableData *optimization.TableData) error {
+	return shared.Append(s, tableData, types.AdditionalSettings{})
+}
+
+func (s *Store) Merge(tableData *optimization.TableData) error {
+	return shared.Merge(s, tableData, s.config, types.MergeOpts{
+		UseMergeParts: true,
+		// We are adding SELECT DISTINCT here for the temporary table as an extra guardrail.
+		// Redshift does not enforce any row uniqueness and there could be potential LOAD errors which will cause duplicate rows to arise.
+		SubQueryDedupe: true,
+	})
+}
+
 func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier {
 	return NewTableIdentifier(topicConfig.Schema, table)
 }
27 changes: 14 additions & 13 deletions clients/redshift/staging.go
@@ -15,20 +15,21 @@ import (
 	"github.com/artie-labs/transfer/lib/s3lib"
 )
 
-func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, _ bool) error {
-	// Redshift always creates a temporary table.
-	tempAlterTableArgs := ddl.AlterTableArgs{
-		Dwh:            s,
-		Tc:             tableConfig,
-		TableID:        tempTableID,
-		CreateTable:    true,
-		TemporaryTable: true,
-		ColumnOp:       constants.Add,
-		Mode:           tableData.Mode(),
-	}
+func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error {
+	if createTempTable {
+		tempAlterTableArgs := ddl.AlterTableArgs{
+			Dwh:            s,
+			Tc:             tableConfig,
+			TableID:        tempTableID,
+			CreateTable:    true,
+			TemporaryTable: true,
+			ColumnOp:       constants.Add,
+			Mode:           tableData.Mode(),
+		}
 
-	if err := tempAlterTableArgs.AlterTable(tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
-		return fmt.Errorf("failed to create temp table: %w", err)
+		if err := tempAlterTableArgs.AlterTable(tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil {
+			return fmt.Errorf("failed to create temp table: %w", err)
+		}
 	}
 
 	fp, err := s.loadTemporaryTable(tableData, tempTableID)
34 changes: 0 additions & 34 deletions clients/redshift/writes.go

This file was deleted.

27 changes: 17 additions & 10 deletions clients/shared/append.go
@@ -11,22 +11,27 @@ import (
 	"github.com/artie-labs/transfer/lib/typing/columns"
 )
 
-func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, opts types.AppendOpts) error {
+func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, opts types.AdditionalSettings) error {
 	if tableData.ShouldSkipUpdate() {
 		return nil
 	}
 
-	tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name())
 	tableConfig, err := dwh.GetTableConfig(tableData)
 	if err != nil {
 		return fmt.Errorf("failed to get table config: %w", err)
 	}
 
 	// We don't care about srcKeysMissing because we don't drop columns when we append.
-	_, targetKeysMissing := columns.Diff(tableData.ReadOnlyInMemoryCols(), tableConfig.Columns(),
-		tableData.TopicConfig().SoftDelete, tableData.TopicConfig().IncludeArtieUpdatedAt,
-		tableData.TopicConfig().IncludeDatabaseUpdatedAt, tableData.Mode())
+	_, targetKeysMissing := columns.Diff(
+		tableData.ReadOnlyInMemoryCols(),
+		tableConfig.Columns(),
+		tableData.TopicConfig().SoftDelete,
+		tableData.TopicConfig().IncludeArtieUpdatedAt,
+		tableData.TopicConfig().IncludeDatabaseUpdatedAt,
+		tableData.Mode(),
+	)
 
+	tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name())
 	createAlterTableArgs := ddl.AlterTableArgs{
 		Dwh: dwh,
 		Tc:  tableConfig,
@@ -46,9 +51,11 @@ func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, op
 		return fmt.Errorf("failed to merge columns from destination: %w", err)
 	}
 
-	additionalSettings := types.AdditionalSettings{
-		AdditionalCopyClause: opts.AdditionalCopyClause,
-	}
-
-	return dwh.PrepareTemporaryTable(tableData, tableConfig, opts.TempTableID, additionalSettings, false)
+	return dwh.PrepareTemporaryTable(
+		tableData,
+		tableConfig,
+		tableID,
+		opts,
+		false,
+	)
 }
4 changes: 3 additions & 1 deletion clients/snowflake/snowflake_suite_test.go
@@ -25,7 +25,9 @@ func (s *SnowflakeTestSuite) ResetStore() {
 	s.fakeStageStore = &mocks.FakeStore{}
 	stageStore := db.Store(s.fakeStageStore)
 	var err error
-	s.stageStore, err = LoadSnowflake(config.Config{}, &stageStore)
+	s.stageStore, err = LoadSnowflake(config.Config{
+		Snowflake: &config.Snowflake{},
+	}, &stageStore)
 	assert.NoError(s.T(), err)
 }

4 changes: 1 addition & 3 deletions clients/snowflake/writes.go
@@ -25,10 +25,8 @@ func (s *Store) Append(tableData *optimization.TableData) error {
 		}
 	}
 
-	tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name())
 	// TODO: For history mode - in the future, we could also have a separate stage name for history mode so we can enable parallel processing.
-	err = shared.Append(s, tableData, types.AppendOpts{
-		TempTableID: tableID,
+	err = shared.Append(s, tableData, types.AdditionalSettings{
 		AdditionalCopyClause: `FILE_FORMAT = (TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE) PURGE = TRUE`,
 	})
 }
2 changes: 1 addition & 1 deletion lib/array/strings.go
@@ -47,7 +47,7 @@ func InterfaceToArrayString(val any, recastAsArray bool) ([]string, error) {
 
 			vals = append(vals, string(bytes))
 		} else {
-			vals = append(vals, stringutil.Wrap(value, true))
+			vals = append(vals, stringutil.EscapeBackslashes(fmt.Sprint(value)))
 		}
 	}
 
7 changes: 0 additions & 7 deletions lib/destination/types/types.go
@@ -43,13 +43,6 @@ type AdditionalSettings struct {
 	AdditionalCopyClause string
 }
 
-type AppendOpts struct {
-	// TempTableID - sometimes the destination requires 2 steps to append to the table (e.g. Redshift), so we'll create and load the data into a staging table
-	// Redshift then has a separate step after `shared.Append(...)` to merge the two tables together.
-	TempTableID TableIdentifier
-	AdditionalCopyClause string
-}
-
 type TableIdentifier interface {
 	Table() string
 	WithTable(table string) TableIdentifier
18 changes: 8 additions & 10 deletions lib/sql/dialect.go
@@ -3,14 +3,12 @@ package sql
 import (
 	"fmt"
 	"strings"
-
-	"github.com/artie-labs/transfer/lib/stringutil"
 )
 
 type Dialect interface {
 	NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything
 	QuoteIdentifier(identifier string) string
-	EscapeStruct(value any) string
+	EscapeStruct(value string) string
 }
 
 type BigQueryDialect struct{}
@@ -22,8 +20,8 @@ func (BigQueryDialect) QuoteIdentifier(identifier string) string {
 	return fmt.Sprintf("`%s`", identifier)
 }
 
-func (BigQueryDialect) EscapeStruct(value any) string {
-	return "JSON" + stringutil.Wrap(value, false)
+func (BigQueryDialect) EscapeStruct(value string) string {
+	return "JSON" + QuoteLiteral(value)
 }
 
 type MSSQLDialect struct{}
@@ -34,7 +32,7 @@ func (MSSQLDialect) QuoteIdentifier(identifier string) string {
 	return fmt.Sprintf(`"%s"`, identifier)
 }
 
-func (MSSQLDialect) EscapeStruct(value any) string {
+func (MSSQLDialect) EscapeStruct(value string) string {
 	panic("not implemented") // We don't currently support backfills for MS SQL.
 }
 
@@ -47,8 +45,8 @@ func (rd RedshiftDialect) QuoteIdentifier(identifier string) string {
 	return fmt.Sprintf(`"%s"`, strings.ToLower(identifier))
 }
 
-func (RedshiftDialect) EscapeStruct(value any) string {
-	return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false))
+func (RedshiftDialect) EscapeStruct(value string) string {
+	return fmt.Sprintf("JSON_PARSE(%s)", QuoteLiteral(value))
 }
 
 type SnowflakeDialect struct{}
@@ -61,6 +59,6 @@ func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {
 	return fmt.Sprintf(`"%s"`, strings.ToUpper(identifier))
 }
 
-func (SnowflakeDialect) EscapeStruct(value any) string {
-	return stringutil.Wrap(value, false)
+func (SnowflakeDialect) EscapeStruct(value string) string {
+	return QuoteLiteral(value)
 }
15 changes: 15 additions & 0 deletions lib/sql/util.go
@@ -0,0 +1,15 @@ (new file)
package sql

import (
	"fmt"
	"strings"

	"github.com/artie-labs/transfer/lib/stringutil"
)

// QuoteLiteral wraps a string with single quotes so that it can be used in a SQL query.
// If there are backslashes in the string, then they will be escaped to [\\].
// After escaping backslashes, any remaining single quotes will be replaced with [\'].
func QuoteLiteral(value string) string {
	return fmt.Sprintf("'%s'", strings.ReplaceAll(stringutil.EscapeBackslashes(value), "'", `\'`))
}
40 changes: 40 additions & 0 deletions lib/sql/util_test.go
@@ -0,0 +1,40 @@ (new file)
package sql

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestQuoteLiteral(t *testing.T) {
	testCases := []struct {
		name     string
		colVal   string
		expected string
	}{
		{
			name:     "string",
			colVal:   "hello",
			expected: "'hello'",
		},
		{
			name:     "string that requires escaping",
			colVal:   "bobby o'reilly",
			expected: `'bobby o\'reilly'`,
		},
		{
			name:     "string with line breaks",
			colVal:   "line1 \n line 2",
			expected: "'line1 \n line 2'",
		},
		{
			name:     "string with existing backslash",
			colVal:   `hello \ there \ hh`,
			expected: `'hello \\ there \\ hh'`,
		},
	}

	for _, testCase := range testCases {
		assert.Equal(t, testCase.expected, QuoteLiteral(testCase.colVal), testCase.name)
	}
}
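
For orientation (not part of the commit): a minimal standalone sketch showing how the new QuoteLiteral helper feeds the updated EscapeStruct implementations from lib/sql/dialect.go above. The expected outputs in the comments follow directly from the code in this diff; the main package is hypothetical.

package main

import (
	"fmt"

	"github.com/artie-labs/transfer/lib/sql"
)

func main() {
	// Backslashes are escaped first, then single quotes, then the value is wrapped in single quotes.
	fmt.Println(sql.QuoteLiteral(`bobby o'reilly`)) // 'bobby o\'reilly'

	// Each dialect now builds its struct literal on top of QuoteLiteral instead of stringutil.Wrap.
	fmt.Println(sql.BigQueryDialect{}.EscapeStruct(`{"a": 1}`))  // JSON'{"a": 1}'
	fmt.Println(sql.RedshiftDialect{}.EscapeStruct(`{"a": 1}`))  // JSON_PARSE('{"a": 1}')
	fmt.Println(sql.SnowflakeDialect{}.EscapeStruct(`{"a": 1}`)) // '{"a": 1}'
}
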
13 changes: 2 additions & 11 deletions lib/stringutil/strings.go
@@ -1,7 +1,6 @@
 package stringutil
 
 import (
-	"fmt"
 	"math/rand"
 	"strings"
 )
@@ -26,16 +25,8 @@ func Override(vals ...string) string {
 	return retVal
 }
 
-func Wrap(colVal any, noQuotes bool) string {
-	colVal = strings.ReplaceAll(fmt.Sprint(colVal), `\`, `\\`)
-	// The normal string escape is to do for O'Reilly is O\\'Reilly, but Snowflake escapes via \'
-	if noQuotes {
-		return fmt.Sprint(colVal)
-	}
-
-	// When there is quote wrapping `foo -> 'foo'`, we'll need to escape `'` so the value compiles.
-	// However, if there are no quote wrapping, we should not need to escape.
-	return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprint(colVal), "'", `\'`))
+func EscapeBackslashes(value string) string {
+	return strings.ReplaceAll(value, `\`, `\\`)
 }
 
 func Empty(vals ...string) bool {