diff --git a/clients/bigquery/append.go b/clients/bigquery/append.go deleted file mode 100644 index 8e4e221b0..000000000 --- a/clients/bigquery/append.go +++ /dev/null @@ -1,12 +0,0 @@ -package bigquery - -import ( - "github.com/artie-labs/transfer/clients/shared" - "github.com/artie-labs/transfer/lib/destination/types" - "github.com/artie-labs/transfer/lib/optimization" -) - -func (s *Store) Append(tableData *optimization.TableData) error { - tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - return shared.Append(s, tableData, types.AppendOpts{TempTableID: tableID}) -} diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index f17802fc3..5d6fec060 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -41,6 +41,10 @@ type Store struct { db.Store } +func (s *Store) Append(tableData *optimization.TableData) error { + return shared.Append(s, tableData, types.AdditionalSettings{}) +} + func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error { if createTempTable { tempAlterTableArgs := ddl.AlterTableArgs{ diff --git a/clients/mssql/store.go b/clients/mssql/store.go index 73c6abb05..81265d188 100644 --- a/clients/mssql/store.go +++ b/clients/mssql/store.go @@ -44,8 +44,7 @@ func (s *Store) Merge(tableData *optimization.TableData) error { } func (s *Store) Append(tableData *optimization.TableData) error { - tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - return shared.Append(s, tableData, types.AppendOpts{TempTableID: tableID}) + return shared.Append(s, tableData, types.AdditionalSettings{}) } // specificIdentifierFor returns a MS SQL [TableIdentifier] for a [TopicConfig] + table name. diff --git a/clients/redshift/redshift.go b/clients/redshift/redshift.go index 65720e005..4626f6cc3 100644 --- a/clients/redshift/redshift.go +++ b/clients/redshift/redshift.go @@ -30,6 +30,19 @@ type Store struct { db.Store } +func (s *Store) Append(tableData *optimization.TableData) error { + return shared.Append(s, tableData, types.AdditionalSettings{}) +} + +func (s *Store) Merge(tableData *optimization.TableData) error { + return shared.Merge(s, tableData, s.config, types.MergeOpts{ + UseMergeParts: true, + // We are adding SELECT DISTINCT here for the temporary table as an extra guardrail. + // Redshift does not enforce any row uniqueness and there could be potential LOAD errors which will cause duplicate rows to arise. + SubQueryDedupe: true, + }) +} + func (s *Store) IdentifierFor(topicConfig kafkalib.TopicConfig, table string) types.TableIdentifier { return NewTableIdentifier(topicConfig.Schema, table) } diff --git a/clients/redshift/staging.go b/clients/redshift/staging.go index 5c562034c..68d8311d2 100644 --- a/clients/redshift/staging.go +++ b/clients/redshift/staging.go @@ -15,20 +15,21 @@ import ( "github.com/artie-labs/transfer/lib/s3lib" ) -func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, _ bool) error { - // Redshift always creates a temporary table. - tempAlterTableArgs := ddl.AlterTableArgs{ - Dwh: s, - Tc: tableConfig, - TableID: tempTableID, - CreateTable: true, - TemporaryTable: true, - ColumnOp: constants.Add, - Mode: tableData.Mode(), - } +func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID types.TableIdentifier, _ types.AdditionalSettings, createTempTable bool) error { + if createTempTable { + tempAlterTableArgs := ddl.AlterTableArgs{ + Dwh: s, + Tc: tableConfig, + TableID: tempTableID, + CreateTable: true, + TemporaryTable: true, + ColumnOp: constants.Add, + Mode: tableData.Mode(), + } - if err := tempAlterTableArgs.AlterTable(tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil { - return fmt.Errorf("failed to create temp table: %w", err) + if err := tempAlterTableArgs.AlterTable(tableData.ReadOnlyInMemoryCols().GetColumns()...); err != nil { + return fmt.Errorf("failed to create temp table: %w", err) + } } fp, err := s.loadTemporaryTable(tableData, tempTableID) diff --git a/clients/redshift/writes.go b/clients/redshift/writes.go deleted file mode 100644 index fcaf08e67..000000000 --- a/clients/redshift/writes.go +++ /dev/null @@ -1,34 +0,0 @@ -package redshift - -import ( - "fmt" - - "github.com/artie-labs/transfer/clients/shared" - "github.com/artie-labs/transfer/lib/destination/types" - "github.com/artie-labs/transfer/lib/optimization" -) - -func (s *Store) Append(tableData *optimization.TableData) error { - tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name()) - - // Redshift is slightly different, we'll load and create the temporary table via shared.Append - // Then, we'll invoke `ALTER TABLE target APPEND FROM staging` to combine the diffs. - temporaryTableID := shared.TempTableID(tableID, tableData.TempTableSuffix()) - if err := shared.Append(s, tableData, types.AppendOpts{TempTableID: temporaryTableID}); err != nil { - return err - } - - _, err := s.Exec( - fmt.Sprintf(`ALTER TABLE %s APPEND FROM %s;`, tableID.FullyQualifiedName(), temporaryTableID.FullyQualifiedName()), - ) - return err -} - -func (s *Store) Merge(tableData *optimization.TableData) error { - return shared.Merge(s, tableData, s.config, types.MergeOpts{ - UseMergeParts: true, - // We are adding SELECT DISTINCT here for the temporary table as an extra guardrail. - // Redshift does not enforce any row uniqueness and there could be potential LOAD errors which will cause duplicate rows to arise. - SubQueryDedupe: true, - }) -} diff --git a/clients/shared/append.go b/clients/shared/append.go index 217758183..d3d9bfecb 100644 --- a/clients/shared/append.go +++ b/clients/shared/append.go @@ -11,22 +11,27 @@ import ( "github.com/artie-labs/transfer/lib/typing/columns" ) -func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, opts types.AppendOpts) error { +func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, opts types.AdditionalSettings) error { if tableData.ShouldSkipUpdate() { return nil } - tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name()) tableConfig, err := dwh.GetTableConfig(tableData) if err != nil { return fmt.Errorf("failed to get table config: %w", err) } // We don't care about srcKeysMissing because we don't drop columns when we append. - _, targetKeysMissing := columns.Diff(tableData.ReadOnlyInMemoryCols(), tableConfig.Columns(), - tableData.TopicConfig().SoftDelete, tableData.TopicConfig().IncludeArtieUpdatedAt, - tableData.TopicConfig().IncludeDatabaseUpdatedAt, tableData.Mode()) + _, targetKeysMissing := columns.Diff( + tableData.ReadOnlyInMemoryCols(), + tableConfig.Columns(), + tableData.TopicConfig().SoftDelete, + tableData.TopicConfig().IncludeArtieUpdatedAt, + tableData.TopicConfig().IncludeDatabaseUpdatedAt, + tableData.Mode(), + ) + tableID := dwh.IdentifierFor(tableData.TopicConfig(), tableData.Name()) createAlterTableArgs := ddl.AlterTableArgs{ Dwh: dwh, Tc: tableConfig, @@ -46,9 +51,11 @@ func Append(dwh destination.DataWarehouse, tableData *optimization.TableData, op return fmt.Errorf("failed to merge columns from destination: %w", err) } - additionalSettings := types.AdditionalSettings{ - AdditionalCopyClause: opts.AdditionalCopyClause, - } - - return dwh.PrepareTemporaryTable(tableData, tableConfig, opts.TempTableID, additionalSettings, false) + return dwh.PrepareTemporaryTable( + tableData, + tableConfig, + tableID, + opts, + false, + ) } diff --git a/clients/snowflake/snowflake_suite_test.go b/clients/snowflake/snowflake_suite_test.go index db66cf5df..77467b786 100644 --- a/clients/snowflake/snowflake_suite_test.go +++ b/clients/snowflake/snowflake_suite_test.go @@ -25,7 +25,9 @@ func (s *SnowflakeTestSuite) ResetStore() { s.fakeStageStore = &mocks.FakeStore{} stageStore := db.Store(s.fakeStageStore) var err error - s.stageStore, err = LoadSnowflake(config.Config{}, &stageStore) + s.stageStore, err = LoadSnowflake(config.Config{ + Snowflake: &config.Snowflake{}, + }, &stageStore) assert.NoError(s.T(), err) } diff --git a/clients/snowflake/writes.go b/clients/snowflake/writes.go index f5f48f0b9..b5ca06e98 100644 --- a/clients/snowflake/writes.go +++ b/clients/snowflake/writes.go @@ -25,10 +25,8 @@ func (s *Store) Append(tableData *optimization.TableData) error { } } - tableID := s.IdentifierFor(tableData.TopicConfig(), tableData.Name()) // TODO: For history mode - in the future, we could also have a separate stage name for history mode so we can enable parallel processing. - err = shared.Append(s, tableData, types.AppendOpts{ - TempTableID: tableID, + err = shared.Append(s, tableData, types.AdditionalSettings{ AdditionalCopyClause: `FILE_FORMAT = (TYPE = 'csv' FIELD_DELIMITER= '\t' FIELD_OPTIONALLY_ENCLOSED_BY='"' NULL_IF='\\N' EMPTY_FIELD_AS_NULL=FALSE) PURGE = TRUE`, }) } diff --git a/lib/array/strings.go b/lib/array/strings.go index 4d66c844b..69269c75d 100644 --- a/lib/array/strings.go +++ b/lib/array/strings.go @@ -47,7 +47,7 @@ func InterfaceToArrayString(val any, recastAsArray bool) ([]string, error) { vals = append(vals, string(bytes)) } else { - vals = append(vals, stringutil.Wrap(value, true)) + vals = append(vals, stringutil.EscapeBackslashes(fmt.Sprint(value))) } } diff --git a/lib/destination/types/types.go b/lib/destination/types/types.go index 79f5df7a8..9b79b62d8 100644 --- a/lib/destination/types/types.go +++ b/lib/destination/types/types.go @@ -43,13 +43,6 @@ type AdditionalSettings struct { AdditionalCopyClause string } -type AppendOpts struct { - // TempTableID - sometimes the destination requires 2 steps to append to the table (e.g. Redshift), so we'll create and load the data into a staging table - // Redshift then has a separate step after `shared.Append(...)` to merge the two tables together. - TempTableID TableIdentifier - AdditionalCopyClause string -} - type TableIdentifier interface { Table() string WithTable(table string) TableIdentifier diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index 30f3b9ebf..31f8b0c52 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -3,14 +3,12 @@ package sql import ( "fmt" "strings" - - "github.com/artie-labs/transfer/lib/stringutil" ) type Dialect interface { NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything QuoteIdentifier(identifier string) string - EscapeStruct(value any) string + EscapeStruct(value string) string } type BigQueryDialect struct{} @@ -22,8 +20,8 @@ func (BigQueryDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf("`%s`", identifier) } -func (BigQueryDialect) EscapeStruct(value any) string { - return "JSON" + stringutil.Wrap(value, false) +func (BigQueryDialect) EscapeStruct(value string) string { + return "JSON" + QuoteLiteral(value) } type MSSQLDialect struct{} @@ -34,7 +32,7 @@ func (MSSQLDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, identifier) } -func (MSSQLDialect) EscapeStruct(value any) string { +func (MSSQLDialect) EscapeStruct(value string) string { panic("not implemented") // We don't currently support backfills for MS SQL. } @@ -47,8 +45,8 @@ func (rd RedshiftDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, strings.ToLower(identifier)) } -func (RedshiftDialect) EscapeStruct(value any) string { - return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false)) +func (RedshiftDialect) EscapeStruct(value string) string { + return fmt.Sprintf("JSON_PARSE(%s)", QuoteLiteral(value)) } type SnowflakeDialect struct{} @@ -61,6 +59,6 @@ func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, strings.ToUpper(identifier)) } -func (SnowflakeDialect) EscapeStruct(value any) string { - return stringutil.Wrap(value, false) +func (SnowflakeDialect) EscapeStruct(value string) string { + return QuoteLiteral(value) } diff --git a/lib/sql/util.go b/lib/sql/util.go new file mode 100644 index 000000000..9a8150f89 --- /dev/null +++ b/lib/sql/util.go @@ -0,0 +1,15 @@ +package sql + +import ( + "fmt" + "strings" + + "github.com/artie-labs/transfer/lib/stringutil" +) + +// QuoteLiteral wraps a string with single quotes so that it can be used in a SQL query. +// If there are backslashes in the string, then they will be escaped to [\\]. +// After escaping backslashes, any remaining single quotes will be replaced with [\']. +func QuoteLiteral(value string) string { + return fmt.Sprintf("'%s'", strings.ReplaceAll(stringutil.EscapeBackslashes(value), "'", `\'`)) +} diff --git a/lib/sql/util_test.go b/lib/sql/util_test.go new file mode 100644 index 000000000..89ea11320 --- /dev/null +++ b/lib/sql/util_test.go @@ -0,0 +1,40 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQuoteLiteral(t *testing.T) { + testCases := []struct { + name string + colVal string + expected string + }{ + { + name: "string", + colVal: "hello", + expected: "'hello'", + }, + { + name: "string that requires escaping", + colVal: "bobby o'reilly", + expected: `'bobby o\'reilly'`, + }, + { + name: "string with line breaks", + colVal: "line1 \n line 2", + expected: "'line1 \n line 2'", + }, + { + name: "string with existing backslash", + colVal: `hello \ there \ hh`, + expected: `'hello \\ there \\ hh'`, + }, + } + + for _, testCase := range testCases { + assert.Equal(t, testCase.expected, QuoteLiteral(testCase.colVal), testCase.name) + } +} diff --git a/lib/stringutil/strings.go b/lib/stringutil/strings.go index c5083f884..fbff8b649 100644 --- a/lib/stringutil/strings.go +++ b/lib/stringutil/strings.go @@ -1,7 +1,6 @@ package stringutil import ( - "fmt" "math/rand" "strings" ) @@ -26,16 +25,8 @@ func Override(vals ...string) string { return retVal } -func Wrap(colVal any, noQuotes bool) string { - colVal = strings.ReplaceAll(fmt.Sprint(colVal), `\`, `\\`) - // The normal string escape is to do for O'Reilly is O\\'Reilly, but Snowflake escapes via \' - if noQuotes { - return fmt.Sprint(colVal) - } - - // When there is quote wrapping `foo -> 'foo'`, we'll need to escape `'` so the value compiles. - // However, if there are no quote wrapping, we should not need to escape. - return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprint(colVal), "'", `\'`)) +func EscapeBackslashes(value string) string { + return strings.ReplaceAll(value, `\`, `\\`) } func Empty(vals ...string) bool { diff --git a/lib/stringutil/strings_test.go b/lib/stringutil/strings_test.go index 344e5dd69..b2d728a02 100644 --- a/lib/stringutil/strings_test.go +++ b/lib/stringutil/strings_test.go @@ -68,69 +68,36 @@ func TestOverride(t *testing.T) { } } -func TestWrap(t *testing.T) { - type _testCase struct { +func TestEscapeBackslashes(t *testing.T) { + testCases := []struct { name string - colVal any - noQuotes bool + colVal string expectedString string - } - - testCases := []_testCase{ + }{ { name: "string", colVal: "hello", - expectedString: "'hello'", - }, - { - name: "string (no quotes)", - colVal: "hello", - noQuotes: true, expectedString: "hello", }, { - name: "string (no quotes)", + name: "string", colVal: "bobby o'reilly", - noQuotes: true, expectedString: "bobby o'reilly", }, - { - name: "string that requires escaping", - colVal: "bobby o'reilly", - expectedString: `'bobby o\'reilly'`, - }, - { - name: "string that requires escaping (no quotes)", - colVal: "bobby o'reilly", - expectedString: `bobby o'reilly`, - noQuotes: true, - }, { name: "string with line breaks", colVal: "line1 \n line 2", - expectedString: "'line1 \n line 2'", - }, - { - name: "string with line breaks (no quotes)", - colVal: "line1 \n line 2", expectedString: "line1 \n line 2", - noQuotes: true, }, { name: "string with existing backslash", colVal: `hello \ there \ hh`, - expectedString: `'hello \\ there \\ hh'`, - }, - { - name: "string with existing backslash (no quotes)", - colVal: `hello \ there \ hh`, expectedString: `hello \\ there \\ hh`, - noQuotes: true, }, } for _, testCase := range testCases { - assert.Equal(t, testCase.expectedString, Wrap(testCase.colVal, testCase.noQuotes), testCase.name) + assert.Equal(t, testCase.expectedString, EscapeBackslashes(testCase.colVal), testCase.name) } } diff --git a/lib/typing/columns/default.go b/lib/typing/columns/default.go index a5155777d..6624c4278 100644 --- a/lib/typing/columns/default.go +++ b/lib/typing/columns/default.go @@ -7,7 +7,6 @@ import ( "github.com/artie-labs/transfer/lib/typing/ext" - "github.com/artie-labs/transfer/lib/stringutil" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/decimal" ) @@ -23,7 +22,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) switch c.KindDetails.Kind { case typing.Struct.Kind, typing.Array.Kind: - return dialect.EscapeStruct(c.defaultValue), nil + return dialect.EscapeStruct(fmt.Sprint(c.defaultValue)), nil case typing.ETime.Kind: if c.KindDetails.ExtendedTimeDetails == nil { return nil, fmt.Errorf("column kind details for extended time is nil") @@ -36,9 +35,9 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) switch c.KindDetails.ExtendedTimeDetails.Type { case ext.TimeKindType: - return stringutil.Wrap(extTime.String(ext.PostgresTimeFormatNoTZ), false), nil + return sql.QuoteLiteral(extTime.String(ext.PostgresTimeFormatNoTZ)), nil default: - return stringutil.Wrap(extTime.String(c.KindDetails.ExtendedTimeDetails.Format), false), nil + return sql.QuoteLiteral(extTime.String(c.KindDetails.ExtendedTimeDetails.Format)), nil } case typing.EDecimal.Kind: val, isOk := c.defaultValue.(*decimal.Decimal) @@ -48,7 +47,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) return val.Value(), nil case typing.String.Kind: - return stringutil.Wrap(c.defaultValue, false), nil + return sql.QuoteLiteral(fmt.Sprint(c.defaultValue)), nil } return c.defaultValue, nil diff --git a/lib/typing/values/string.go b/lib/typing/values/string.go index 9597717f0..60502fada 100644 --- a/lib/typing/values/string.go +++ b/lib/typing/values/string.go @@ -56,7 +56,7 @@ func ToString(colVal any, colKind columns.Column, additionalDateFmts []string) ( return string(colValBytes), nil } - return stringutil.Wrap(colVal, true), nil + return stringutil.EscapeBackslashes(fmt.Sprint(colVal)), nil case typing.Struct.Kind: if colKind.KindDetails == typing.Struct { if strings.Contains(fmt.Sprint(colVal), constants.ToastUnavailableValuePlaceholder) {