From efec8fdaaf93229b2b065c0f1f8085804026b243 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Wed, 1 May 2024 18:42:20 -0700 Subject: [PATCH] Split `stringutil.Wrap` into two functions (#538) --- lib/array/strings.go | 2 +- lib/sql/dialect.go | 17 ++++++------- lib/sql/util.go | 15 ++++++++++++ lib/sql/util_test.go | 40 ++++++++++++++++++++++++++++++ lib/stringutil/strings.go | 13 ++-------- lib/stringutil/strings_test.go | 45 +++++----------------------------- lib/typing/columns/default.go | 9 +++---- lib/typing/values/string.go | 2 +- 8 files changed, 77 insertions(+), 66 deletions(-) create mode 100644 lib/sql/util.go create mode 100644 lib/sql/util_test.go diff --git a/lib/array/strings.go b/lib/array/strings.go index 4d66c844b..69269c75d 100644 --- a/lib/array/strings.go +++ b/lib/array/strings.go @@ -47,7 +47,7 @@ func InterfaceToArrayString(val any, recastAsArray bool) ([]string, error) { vals = append(vals, string(bytes)) } else { - vals = append(vals, stringutil.Wrap(value, true)) + vals = append(vals, stringutil.EscapeBackslashes(fmt.Sprint(value))) } } diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index ae39b254e..b19de75b1 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -8,13 +8,12 @@ import ( "strings" "github.com/artie-labs/transfer/lib/config/constants" - "github.com/artie-labs/transfer/lib/stringutil" ) type Dialect interface { NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything QuoteIdentifier(identifier string) string - EscapeStruct(value any) string + EscapeStruct(value string) string } type BigQueryDialect struct{} @@ -26,8 +25,8 @@ func (BigQueryDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf("`%s`", identifier) } -func (BigQueryDialect) EscapeStruct(value any) string { - return "JSON" + stringutil.Wrap(value, false) +func (BigQueryDialect) EscapeStruct(value string) string { + return "JSON" + QuoteLiteral(value) } type MSSQLDialect struct{} @@ -38,7 +37,7 @@ func (MSSQLDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, identifier) } -func (MSSQLDialect) EscapeStruct(value any) string { +func (MSSQLDialect) EscapeStruct(value string) string { panic("not implemented") // We don't currently support backfills for MS SQL. } @@ -51,8 +50,8 @@ func (rd RedshiftDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, strings.ToLower(identifier)) } -func (RedshiftDialect) EscapeStruct(value any) string { - return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false)) +func (RedshiftDialect) EscapeStruct(value string) string { + return fmt.Sprintf("JSON_PARSE(%s)", QuoteLiteral(value)) } type SnowflakeDialect struct { @@ -90,6 +89,6 @@ func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, identifier) } -func (SnowflakeDialect) EscapeStruct(value any) string { - return stringutil.Wrap(value, false) +func (SnowflakeDialect) EscapeStruct(value string) string { + return QuoteLiteral(value) } diff --git a/lib/sql/util.go b/lib/sql/util.go new file mode 100644 index 000000000..9a8150f89 --- /dev/null +++ b/lib/sql/util.go @@ -0,0 +1,15 @@ +package sql + +import ( + "fmt" + "strings" + + "github.com/artie-labs/transfer/lib/stringutil" +) + +// QuoteLiteral wraps a string with single quotes so that it can be used in a SQL query. +// If there are backslashes in the string, then they will be escaped to [\\]. +// After escaping backslashes, any remaining single quotes will be replaced with [\']. +func QuoteLiteral(value string) string { + return fmt.Sprintf("'%s'", strings.ReplaceAll(stringutil.EscapeBackslashes(value), "'", `\'`)) +} diff --git a/lib/sql/util_test.go b/lib/sql/util_test.go new file mode 100644 index 000000000..89ea11320 --- /dev/null +++ b/lib/sql/util_test.go @@ -0,0 +1,40 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQuoteLiteral(t *testing.T) { + testCases := []struct { + name string + colVal string + expected string + }{ + { + name: "string", + colVal: "hello", + expected: "'hello'", + }, + { + name: "string that requires escaping", + colVal: "bobby o'reilly", + expected: `'bobby o\'reilly'`, + }, + { + name: "string with line breaks", + colVal: "line1 \n line 2", + expected: "'line1 \n line 2'", + }, + { + name: "string with existing backslash", + colVal: `hello \ there \ hh`, + expected: `'hello \\ there \\ hh'`, + }, + } + + for _, testCase := range testCases { + assert.Equal(t, testCase.expected, QuoteLiteral(testCase.colVal), testCase.name) + } +} diff --git a/lib/stringutil/strings.go b/lib/stringutil/strings.go index c5083f884..fbff8b649 100644 --- a/lib/stringutil/strings.go +++ b/lib/stringutil/strings.go @@ -1,7 +1,6 @@ package stringutil import ( - "fmt" "math/rand" "strings" ) @@ -26,16 +25,8 @@ func Override(vals ...string) string { return retVal } -func Wrap(colVal any, noQuotes bool) string { - colVal = strings.ReplaceAll(fmt.Sprint(colVal), `\`, `\\`) - // The normal string escape is to do for O'Reilly is O\\'Reilly, but Snowflake escapes via \' - if noQuotes { - return fmt.Sprint(colVal) - } - - // When there is quote wrapping `foo -> 'foo'`, we'll need to escape `'` so the value compiles. - // However, if there are no quote wrapping, we should not need to escape. - return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprint(colVal), "'", `\'`)) +func EscapeBackslashes(value string) string { + return strings.ReplaceAll(value, `\`, `\\`) } func Empty(vals ...string) bool { diff --git a/lib/stringutil/strings_test.go b/lib/stringutil/strings_test.go index 344e5dd69..b2d728a02 100644 --- a/lib/stringutil/strings_test.go +++ b/lib/stringutil/strings_test.go @@ -68,69 +68,36 @@ func TestOverride(t *testing.T) { } } -func TestWrap(t *testing.T) { - type _testCase struct { +func TestEscapeBackslashes(t *testing.T) { + testCases := []struct { name string - colVal any - noQuotes bool + colVal string expectedString string - } - - testCases := []_testCase{ + }{ { name: "string", colVal: "hello", - expectedString: "'hello'", - }, - { - name: "string (no quotes)", - colVal: "hello", - noQuotes: true, expectedString: "hello", }, { - name: "string (no quotes)", + name: "string", colVal: "bobby o'reilly", - noQuotes: true, expectedString: "bobby o'reilly", }, - { - name: "string that requires escaping", - colVal: "bobby o'reilly", - expectedString: `'bobby o\'reilly'`, - }, - { - name: "string that requires escaping (no quotes)", - colVal: "bobby o'reilly", - expectedString: `bobby o'reilly`, - noQuotes: true, - }, { name: "string with line breaks", colVal: "line1 \n line 2", - expectedString: "'line1 \n line 2'", - }, - { - name: "string with line breaks (no quotes)", - colVal: "line1 \n line 2", expectedString: "line1 \n line 2", - noQuotes: true, }, { name: "string with existing backslash", colVal: `hello \ there \ hh`, - expectedString: `'hello \\ there \\ hh'`, - }, - { - name: "string with existing backslash (no quotes)", - colVal: `hello \ there \ hh`, expectedString: `hello \\ there \\ hh`, - noQuotes: true, }, } for _, testCase := range testCases { - assert.Equal(t, testCase.expectedString, Wrap(testCase.colVal, testCase.noQuotes), testCase.name) + assert.Equal(t, testCase.expectedString, EscapeBackslashes(testCase.colVal), testCase.name) } } diff --git a/lib/typing/columns/default.go b/lib/typing/columns/default.go index a5155777d..6624c4278 100644 --- a/lib/typing/columns/default.go +++ b/lib/typing/columns/default.go @@ -7,7 +7,6 @@ import ( "github.com/artie-labs/transfer/lib/typing/ext" - "github.com/artie-labs/transfer/lib/stringutil" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/decimal" ) @@ -23,7 +22,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) switch c.KindDetails.Kind { case typing.Struct.Kind, typing.Array.Kind: - return dialect.EscapeStruct(c.defaultValue), nil + return dialect.EscapeStruct(fmt.Sprint(c.defaultValue)), nil case typing.ETime.Kind: if c.KindDetails.ExtendedTimeDetails == nil { return nil, fmt.Errorf("column kind details for extended time is nil") @@ -36,9 +35,9 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) switch c.KindDetails.ExtendedTimeDetails.Type { case ext.TimeKindType: - return stringutil.Wrap(extTime.String(ext.PostgresTimeFormatNoTZ), false), nil + return sql.QuoteLiteral(extTime.String(ext.PostgresTimeFormatNoTZ)), nil default: - return stringutil.Wrap(extTime.String(c.KindDetails.ExtendedTimeDetails.Format), false), nil + return sql.QuoteLiteral(extTime.String(c.KindDetails.ExtendedTimeDetails.Format)), nil } case typing.EDecimal.Kind: val, isOk := c.defaultValue.(*decimal.Decimal) @@ -48,7 +47,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) return val.Value(), nil case typing.String.Kind: - return stringutil.Wrap(c.defaultValue, false), nil + return sql.QuoteLiteral(fmt.Sprint(c.defaultValue)), nil } return c.defaultValue, nil diff --git a/lib/typing/values/string.go b/lib/typing/values/string.go index 9597717f0..60502fada 100644 --- a/lib/typing/values/string.go +++ b/lib/typing/values/string.go @@ -56,7 +56,7 @@ func ToString(colVal any, colKind columns.Column, additionalDateFmts []string) ( return string(colValBytes), nil } - return stringutil.Wrap(colVal, true), nil + return stringutil.EscapeBackslashes(fmt.Sprint(colVal)), nil case typing.Struct.Kind: if colKind.KindDetails == typing.Struct { if strings.Contains(fmt.Sprint(colVal), constants.ToastUnavailableValuePlaceholder) {