Skip to content

Commit

Permalink
Split stringutil.Wrap into two functions
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie committed May 2, 2024
1 parent 4153191 commit 6371fdb
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 56 deletions.
2 changes: 1 addition & 1 deletion lib/array/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func InterfaceToArrayString(val any, recastAsArray bool) ([]string, error) {

vals = append(vals, string(bytes))
} else {
vals = append(vals, stringutil.Wrap(value, true))
vals = append(vals, stringutil.EscapeBackslashes(fmt.Sprint(value)))
}
}

Expand Down
7 changes: 3 additions & 4 deletions lib/sql/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/stringutil"
)

type Dialect interface {
Expand All @@ -27,7 +26,7 @@ func (BigQueryDialect) QuoteIdentifier(identifier string) string {
}

func (BigQueryDialect) EscapeStruct(value any) string {
return "JSON" + stringutil.Wrap(value, false)
return "JSON" + QuoteLiteral(fmt.Sprint(value))
}

type MSSQLDialect struct{}
Expand All @@ -52,7 +51,7 @@ func (rd RedshiftDialect) QuoteIdentifier(identifier string) string {
}

func (RedshiftDialect) EscapeStruct(value any) string {
return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false))
return fmt.Sprintf("JSON_PARSE(%s)", QuoteLiteral(fmt.Sprint(value)))
}

type SnowflakeDialect struct {
Expand Down Expand Up @@ -91,5 +90,5 @@ func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {
}

func (SnowflakeDialect) EscapeStruct(value any) string {
return stringutil.Wrap(value, false)
return QuoteLiteral(fmt.Sprint(value))
}
1 change: 1 addition & 0 deletions lib/sql/escape_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package sql
14 changes: 14 additions & 0 deletions lib/sql/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package sql

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/stringutil"
)

func QuoteLiteral(value string) string {
// When there is quote wrapping `foo -> 'foo'`, we'll need to escape `'` so the value compiles.
// However, if there are no quote wrapping, we should not need to escape.
return fmt.Sprintf("'%s'", strings.ReplaceAll(stringutil.EscapeBackslashes(value), "'", `\'`))
}
42 changes: 42 additions & 0 deletions lib/sql/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package sql

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestQuoteLiteral(t *testing.T) {
type _testCase struct {
name string
colVal string
expectedString string
}

testCases := []_testCase{
{
name: "string",
colVal: "hello",
expectedString: "'hello'",
},
{
name: "string that requires escaping",
colVal: "bobby o'reilly",
expectedString: `'bobby o\'reilly'`,
},
{
name: "string with line breaks",
colVal: "line1 \n line 2",
expectedString: "'line1 \n line 2'",
},
{
name: "string with existing backslash",
colVal: `hello \ there \ hh`,
expectedString: `'hello \\ there \\ hh'`,
},
}

for _, testCase := range testCases {
assert.Equal(t, testCase.expectedString, QuoteLiteral(testCase.colVal), testCase.name)
}
}
13 changes: 2 additions & 11 deletions lib/stringutil/strings.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package stringutil

import (
"fmt"
"math/rand"
"strings"
)
Expand All @@ -26,16 +25,8 @@ func Override(vals ...string) string {
return retVal
}

func Wrap(colVal any, noQuotes bool) string {
colVal = strings.ReplaceAll(fmt.Sprint(colVal), `\`, `\\`)
// The normal string escape is to do for O'Reilly is O\\'Reilly, but Snowflake escapes via \'
if noQuotes {
return fmt.Sprint(colVal)
}

// When there is quote wrapping `foo -> 'foo'`, we'll need to escape `'` so the value compiles.
// However, if there are no quote wrapping, we should not need to escape.
return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprint(colVal), "'", `\'`))
func EscapeBackslashes(value string) string {
return strings.ReplaceAll(value, `\`, `\\`)
}

func Empty(vals ...string) bool {
Expand Down
39 changes: 4 additions & 35 deletions lib/stringutil/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,69 +68,38 @@ func TestOverride(t *testing.T) {
}
}

func TestWrap(t *testing.T) {
func TestEscapeBackslashes(t *testing.T) {
type _testCase struct {
name string
colVal any
noQuotes bool
colVal string
expectedString string
}

testCases := []_testCase{
{
name: "string",
colVal: "hello",
expectedString: "'hello'",
},
{
name: "string (no quotes)",
colVal: "hello",
noQuotes: true,
expectedString: "hello",
},
{
name: "string (no quotes)",
name: "string",
colVal: "bobby o'reilly",
noQuotes: true,
expectedString: "bobby o'reilly",
},
{
name: "string that requires escaping",
colVal: "bobby o'reilly",
expectedString: `'bobby o\'reilly'`,
},
{
name: "string that requires escaping (no quotes)",
colVal: "bobby o'reilly",
expectedString: `bobby o'reilly`,
noQuotes: true,
},
{
name: "string with line breaks",
colVal: "line1 \n line 2",
expectedString: "'line1 \n line 2'",
},
{
name: "string with line breaks (no quotes)",
colVal: "line1 \n line 2",
expectedString: "line1 \n line 2",
noQuotes: true,
},
{
name: "string with existing backslash",
colVal: `hello \ there \ hh`,
expectedString: `'hello \\ there \\ hh'`,
},
{
name: "string with existing backslash (no quotes)",
colVal: `hello \ there \ hh`,
expectedString: `hello \\ there \\ hh`,
noQuotes: true,
},
}

for _, testCase := range testCases {
assert.Equal(t, testCase.expectedString, Wrap(testCase.colVal, testCase.noQuotes), testCase.name)
assert.Equal(t, testCase.expectedString, EscapeBackslashes(testCase.colVal), testCase.name)
}
}

Expand Down
7 changes: 3 additions & 4 deletions lib/typing/columns/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/artie-labs/transfer/lib/typing/ext"

"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/decimal"
)
Expand Down Expand Up @@ -36,9 +35,9 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string)

switch c.KindDetails.ExtendedTimeDetails.Type {
case ext.TimeKindType:
return stringutil.Wrap(extTime.String(ext.PostgresTimeFormatNoTZ), false), nil
return sql.QuoteLiteral(extTime.String(ext.PostgresTimeFormatNoTZ)), nil
default:
return stringutil.Wrap(extTime.String(c.KindDetails.ExtendedTimeDetails.Format), false), nil
return sql.QuoteLiteral(extTime.String(c.KindDetails.ExtendedTimeDetails.Format)), nil
}
case typing.EDecimal.Kind:
val, isOk := c.defaultValue.(*decimal.Decimal)
Expand All @@ -48,7 +47,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string)

return val.Value(), nil
case typing.String.Kind:
return stringutil.Wrap(c.defaultValue, false), nil
return sql.QuoteLiteral(fmt.Sprint(c.defaultValue)), nil
}

return c.defaultValue, nil
Expand Down
2 changes: 1 addition & 1 deletion lib/typing/values/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func ToString(colVal any, colKind columns.Column, additionalDateFmts []string) (
return string(colValBytes), nil
}

return stringutil.Wrap(colVal, true), nil
return stringutil.EscapeBackslashes(fmt.Sprint(colVal)), nil
case typing.Struct.Kind:
if colKind.KindDetails == typing.Struct {
if strings.Contains(fmt.Sprint(colVal), constants.ToastUnavailableValuePlaceholder) {
Expand Down

0 comments on commit 6371fdb

Please sign in to comment.