diff --git a/lib/array/strings.go b/lib/array/strings.go index 4d66c844b..69269c75d 100644 --- a/lib/array/strings.go +++ b/lib/array/strings.go @@ -47,7 +47,7 @@ func InterfaceToArrayString(val any, recastAsArray bool) ([]string, error) { vals = append(vals, string(bytes)) } else { - vals = append(vals, stringutil.Wrap(value, true)) + vals = append(vals, stringutil.EscapeBackslashes(fmt.Sprint(value))) } } diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index ae39b254e..399f13271 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/artie-labs/transfer/lib/config/constants" - "github.com/artie-labs/transfer/lib/stringutil" ) type Dialect interface { @@ -27,7 +26,7 @@ func (BigQueryDialect) QuoteIdentifier(identifier string) string { } func (BigQueryDialect) EscapeStruct(value any) string { - return "JSON" + stringutil.Wrap(value, false) + return "JSON" + QuoteLiteral(fmt.Sprint(value)) } type MSSQLDialect struct{} @@ -52,7 +51,7 @@ func (rd RedshiftDialect) QuoteIdentifier(identifier string) string { } func (RedshiftDialect) EscapeStruct(value any) string { - return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false)) + return fmt.Sprintf("JSON_PARSE(%s)", QuoteLiteral(fmt.Sprint(value))) } type SnowflakeDialect struct { @@ -91,5 +90,5 @@ func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string { } func (SnowflakeDialect) EscapeStruct(value any) string { - return stringutil.Wrap(value, false) + return QuoteLiteral(fmt.Sprint(value)) } diff --git a/lib/sql/escape_test.go b/lib/sql/escape_test.go new file mode 100644 index 000000000..e4b317b49 --- /dev/null +++ b/lib/sql/escape_test.go @@ -0,0 +1 @@ +package sql diff --git a/lib/sql/util.go b/lib/sql/util.go new file mode 100644 index 000000000..e1bf77ed1 --- /dev/null +++ b/lib/sql/util.go @@ -0,0 +1,14 @@ +package sql + +import ( + "fmt" + "strings" + + "github.com/artie-labs/transfer/lib/stringutil" +) + +func QuoteLiteral(value string) string { + // When there is quote wrapping `foo -> 'foo'`, we'll need to escape `'` so the value compiles. + // However, if there are no quote wrapping, we should not need to escape. + return fmt.Sprintf("'%s'", strings.ReplaceAll(stringutil.EscapeBackslashes(value), "'", `\'`)) +} diff --git a/lib/sql/util_test.go b/lib/sql/util_test.go new file mode 100644 index 000000000..1b2f24b80 --- /dev/null +++ b/lib/sql/util_test.go @@ -0,0 +1,42 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestQuoteLiteral(t *testing.T) { + type _testCase struct { + name string + colVal string + expectedString string + } + + testCases := []_testCase{ + { + name: "string", + colVal: "hello", + expectedString: "'hello'", + }, + { + name: "string that requires escaping", + colVal: "bobby o'reilly", + expectedString: `'bobby o\'reilly'`, + }, + { + name: "string with line breaks", + colVal: "line1 \n line 2", + expectedString: "'line1 \n line 2'", + }, + { + name: "string with existing backslash", + colVal: `hello \ there \ hh`, + expectedString: `'hello \\ there \\ hh'`, + }, + } + + for _, testCase := range testCases { + assert.Equal(t, testCase.expectedString, QuoteLiteral(testCase.colVal), testCase.name) + } +} diff --git a/lib/stringutil/strings.go b/lib/stringutil/strings.go index c5083f884..fbff8b649 100644 --- a/lib/stringutil/strings.go +++ b/lib/stringutil/strings.go @@ -1,7 +1,6 @@ package stringutil import ( - "fmt" "math/rand" "strings" ) @@ -26,16 +25,8 @@ func Override(vals ...string) string { return retVal } -func Wrap(colVal any, noQuotes bool) string { - colVal = strings.ReplaceAll(fmt.Sprint(colVal), `\`, `\\`) - // The normal string escape is to do for O'Reilly is O\\'Reilly, but Snowflake escapes via \' - if noQuotes { - return fmt.Sprint(colVal) - } - - // When there is quote wrapping `foo -> 'foo'`, we'll need to escape `'` so the value compiles. - // However, if there are no quote wrapping, we should not need to escape. - return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprint(colVal), "'", `\'`)) +func EscapeBackslashes(value string) string { + return strings.ReplaceAll(value, `\`, `\\`) } func Empty(vals ...string) bool { diff --git a/lib/stringutil/strings_test.go b/lib/stringutil/strings_test.go index 344e5dd69..a953d81d8 100644 --- a/lib/stringutil/strings_test.go +++ b/lib/stringutil/strings_test.go @@ -68,11 +68,10 @@ func TestOverride(t *testing.T) { } } -func TestWrap(t *testing.T) { +func TestEscapeBackslashes(t *testing.T) { type _testCase struct { name string - colVal any - noQuotes bool + colVal string expectedString string } @@ -80,57 +79,27 @@ func TestWrap(t *testing.T) { { name: "string", colVal: "hello", - expectedString: "'hello'", - }, - { - name: "string (no quotes)", - colVal: "hello", - noQuotes: true, expectedString: "hello", }, { - name: "string (no quotes)", + name: "string", colVal: "bobby o'reilly", - noQuotes: true, expectedString: "bobby o'reilly", }, - { - name: "string that requires escaping", - colVal: "bobby o'reilly", - expectedString: `'bobby o\'reilly'`, - }, - { - name: "string that requires escaping (no quotes)", - colVal: "bobby o'reilly", - expectedString: `bobby o'reilly`, - noQuotes: true, - }, { name: "string with line breaks", colVal: "line1 \n line 2", - expectedString: "'line1 \n line 2'", - }, - { - name: "string with line breaks (no quotes)", - colVal: "line1 \n line 2", expectedString: "line1 \n line 2", - noQuotes: true, }, { name: "string with existing backslash", colVal: `hello \ there \ hh`, - expectedString: `'hello \\ there \\ hh'`, - }, - { - name: "string with existing backslash (no quotes)", - colVal: `hello \ there \ hh`, expectedString: `hello \\ there \\ hh`, - noQuotes: true, }, } for _, testCase := range testCases { - assert.Equal(t, testCase.expectedString, Wrap(testCase.colVal, testCase.noQuotes), testCase.name) + assert.Equal(t, testCase.expectedString, EscapeBackslashes(testCase.colVal), testCase.name) } } diff --git a/lib/typing/columns/default.go b/lib/typing/columns/default.go index a5155777d..783c2cd83 100644 --- a/lib/typing/columns/default.go +++ b/lib/typing/columns/default.go @@ -7,7 +7,6 @@ import ( "github.com/artie-labs/transfer/lib/typing/ext" - "github.com/artie-labs/transfer/lib/stringutil" "github.com/artie-labs/transfer/lib/typing" "github.com/artie-labs/transfer/lib/typing/decimal" ) @@ -36,9 +35,9 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) switch c.KindDetails.ExtendedTimeDetails.Type { case ext.TimeKindType: - return stringutil.Wrap(extTime.String(ext.PostgresTimeFormatNoTZ), false), nil + return sql.QuoteLiteral(extTime.String(ext.PostgresTimeFormatNoTZ)), nil default: - return stringutil.Wrap(extTime.String(c.KindDetails.ExtendedTimeDetails.Format), false), nil + return sql.QuoteLiteral(extTime.String(c.KindDetails.ExtendedTimeDetails.Format)), nil } case typing.EDecimal.Kind: val, isOk := c.defaultValue.(*decimal.Decimal) @@ -48,7 +47,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) return val.Value(), nil case typing.String.Kind: - return stringutil.Wrap(c.defaultValue, false), nil + return sql.QuoteLiteral(fmt.Sprint(c.defaultValue)), nil } return c.defaultValue, nil diff --git a/lib/typing/values/string.go b/lib/typing/values/string.go index 9597717f0..60502fada 100644 --- a/lib/typing/values/string.go +++ b/lib/typing/values/string.go @@ -56,7 +56,7 @@ func ToString(colVal any, colKind columns.Column, additionalDateFmts []string) ( return string(colValBytes), nil } - return stringutil.Wrap(colVal, true), nil + return stringutil.EscapeBackslashes(fmt.Sprint(colVal)), nil case typing.Struct.Kind: if colKind.KindDetails == typing.Struct { if strings.Contains(fmt.Sprint(colVal), constants.ToastUnavailableValuePlaceholder) {