Skip to content

Commit

Permalink
Split stringutil.Wrap into two functions (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored May 2, 2024
1 parent 4153191 commit efec8fd
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 66 deletions.
2 changes: 1 addition & 1 deletion lib/array/strings.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func InterfaceToArrayString(val any, recastAsArray bool) ([]string, error) {

vals = append(vals, string(bytes))
} else {
vals = append(vals, stringutil.Wrap(value, true))
vals = append(vals, stringutil.EscapeBackslashes(fmt.Sprint(value)))
}
}

Expand Down
17 changes: 8 additions & 9 deletions lib/sql/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@ import (
"strings"

"github.com/artie-labs/transfer/lib/config/constants"
"github.com/artie-labs/transfer/lib/stringutil"
)

type Dialect interface {
NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything
QuoteIdentifier(identifier string) string
EscapeStruct(value any) string
EscapeStruct(value string) string
}

type BigQueryDialect struct{}
Expand All @@ -26,8 +25,8 @@ func (BigQueryDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf("`%s`", identifier)
}

func (BigQueryDialect) EscapeStruct(value any) string {
return "JSON" + stringutil.Wrap(value, false)
func (BigQueryDialect) EscapeStruct(value string) string {
return "JSON" + QuoteLiteral(value)
}

type MSSQLDialect struct{}
Expand All @@ -38,7 +37,7 @@ func (MSSQLDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf(`"%s"`, identifier)
}

func (MSSQLDialect) EscapeStruct(value any) string {
func (MSSQLDialect) EscapeStruct(value string) string {
panic("not implemented") // We don't currently support backfills for MS SQL.
}

Expand All @@ -51,8 +50,8 @@ func (rd RedshiftDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf(`"%s"`, strings.ToLower(identifier))
}

func (RedshiftDialect) EscapeStruct(value any) string {
return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false))
func (RedshiftDialect) EscapeStruct(value string) string {
return fmt.Sprintf("JSON_PARSE(%s)", QuoteLiteral(value))
}

type SnowflakeDialect struct {
Expand Down Expand Up @@ -90,6 +89,6 @@ func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string {
return fmt.Sprintf(`"%s"`, identifier)
}

func (SnowflakeDialect) EscapeStruct(value any) string {
return stringutil.Wrap(value, false)
func (SnowflakeDialect) EscapeStruct(value string) string {
return QuoteLiteral(value)
}
15 changes: 15 additions & 0 deletions lib/sql/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package sql

import (
"fmt"
"strings"

"github.com/artie-labs/transfer/lib/stringutil"
)

// QuoteLiteral turns value into a single-quoted SQL string literal.
// Backslashes are doubled first (each [\] becomes [\\]); afterwards every
// single quote in the result is prefixed with a backslash (['] becomes [\']),
// and the escaped text is finally surrounded by single quotes.
func QuoteLiteral(value string) string {
	// Order matters: backslashes must be doubled before quotes are escaped,
	// otherwise the backslash added in front of each quote would be doubled too.
	escaped := stringutil.EscapeBackslashes(value)
	escaped = strings.ReplaceAll(escaped, "'", `\'`)
	return fmt.Sprintf("'%s'", escaped)
}
40 changes: 40 additions & 0 deletions lib/sql/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package sql

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestQuoteLiteral(t *testing.T) {
testCases := []struct {
name string
colVal string
expected string
}{
{
name: "string",
colVal: "hello",
expected: "'hello'",
},
{
name: "string that requires escaping",
colVal: "bobby o'reilly",
expected: `'bobby o\'reilly'`,
},
{
name: "string with line breaks",
colVal: "line1 \n line 2",
expected: "'line1 \n line 2'",
},
{
name: "string with existing backslash",
colVal: `hello \ there \ hh`,
expected: `'hello \\ there \\ hh'`,
},
}

for _, testCase := range testCases {
assert.Equal(t, testCase.expected, QuoteLiteral(testCase.colVal), testCase.name)
}
}
13 changes: 2 additions & 11 deletions lib/stringutil/strings.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package stringutil

import (
"fmt"
"math/rand"
"strings"
)
Expand All @@ -26,16 +25,8 @@ func Override(vals ...string) string {
return retVal
}

func Wrap(colVal any, noQuotes bool) string {
colVal = strings.ReplaceAll(fmt.Sprint(colVal), `\`, `\\`)
// The normal string escape is to do for O'Reilly is O\\'Reilly, but Snowflake escapes via \'
if noQuotes {
return fmt.Sprint(colVal)
}

// When there is quote wrapping `foo -> 'foo'`, we'll need to escape `'` so the value compiles.
// However, if there are no quote wrapping, we should not need to escape.
return fmt.Sprintf("'%s'", strings.ReplaceAll(fmt.Sprint(colVal), "'", `\'`))
// EscapeBackslashes returns a copy of value in which every backslash has been
// doubled, i.e. each [\] becomes [\\]. All other characters, including single
// quotes and line breaks, are left unchanged.
func EscapeBackslashes(value string) string {
	const (
		single  = `\`
		doubled = single + single
	)
	return strings.ReplaceAll(value, single, doubled)
}

func Empty(vals ...string) bool {
Expand Down
45 changes: 6 additions & 39 deletions lib/stringutil/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,69 +68,36 @@ func TestOverride(t *testing.T) {
}
}

func TestWrap(t *testing.T) {
type _testCase struct {
func TestEscapeBackslashes(t *testing.T) {
testCases := []struct {
name string
colVal any
noQuotes bool
colVal string
expectedString string
}

testCases := []_testCase{
}{
{
name: "string",
colVal: "hello",
expectedString: "'hello'",
},
{
name: "string (no quotes)",
colVal: "hello",
noQuotes: true,
expectedString: "hello",
},
{
name: "string (no quotes)",
name: "string",
colVal: "bobby o'reilly",
noQuotes: true,
expectedString: "bobby o'reilly",
},
{
name: "string that requires escaping",
colVal: "bobby o'reilly",
expectedString: `'bobby o\'reilly'`,
},
{
name: "string that requires escaping (no quotes)",
colVal: "bobby o'reilly",
expectedString: `bobby o'reilly`,
noQuotes: true,
},
{
name: "string with line breaks",
colVal: "line1 \n line 2",
expectedString: "'line1 \n line 2'",
},
{
name: "string with line breaks (no quotes)",
colVal: "line1 \n line 2",
expectedString: "line1 \n line 2",
noQuotes: true,
},
{
name: "string with existing backslash",
colVal: `hello \ there \ hh`,
expectedString: `'hello \\ there \\ hh'`,
},
{
name: "string with existing backslash (no quotes)",
colVal: `hello \ there \ hh`,
expectedString: `hello \\ there \\ hh`,
noQuotes: true,
},
}

for _, testCase := range testCases {
assert.Equal(t, testCase.expectedString, Wrap(testCase.colVal, testCase.noQuotes), testCase.name)
assert.Equal(t, testCase.expectedString, EscapeBackslashes(testCase.colVal), testCase.name)
}
}

Expand Down
9 changes: 4 additions & 5 deletions lib/typing/columns/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/artie-labs/transfer/lib/typing/ext"

"github.com/artie-labs/transfer/lib/stringutil"
"github.com/artie-labs/transfer/lib/typing"
"github.com/artie-labs/transfer/lib/typing/decimal"
)
Expand All @@ -23,7 +22,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string)

switch c.KindDetails.Kind {
case typing.Struct.Kind, typing.Array.Kind:
return dialect.EscapeStruct(c.defaultValue), nil
return dialect.EscapeStruct(fmt.Sprint(c.defaultValue)), nil
case typing.ETime.Kind:
if c.KindDetails.ExtendedTimeDetails == nil {
return nil, fmt.Errorf("column kind details for extended time is nil")
Expand All @@ -36,9 +35,9 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string)

switch c.KindDetails.ExtendedTimeDetails.Type {
case ext.TimeKindType:
return stringutil.Wrap(extTime.String(ext.PostgresTimeFormatNoTZ), false), nil
return sql.QuoteLiteral(extTime.String(ext.PostgresTimeFormatNoTZ)), nil
default:
return stringutil.Wrap(extTime.String(c.KindDetails.ExtendedTimeDetails.Format), false), nil
return sql.QuoteLiteral(extTime.String(c.KindDetails.ExtendedTimeDetails.Format)), nil
}
case typing.EDecimal.Kind:
val, isOk := c.defaultValue.(*decimal.Decimal)
Expand All @@ -48,7 +47,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string)

return val.Value(), nil
case typing.String.Kind:
return stringutil.Wrap(c.defaultValue, false), nil
return sql.QuoteLiteral(fmt.Sprint(c.defaultValue)), nil
}

return c.defaultValue, nil
Expand Down
2 changes: 1 addition & 1 deletion lib/typing/values/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func ToString(colVal any, colKind columns.Column, additionalDateFmts []string) (
return string(colValBytes), nil
}

return stringutil.Wrap(colVal, true), nil
return stringutil.EscapeBackslashes(fmt.Sprint(colVal)), nil
case typing.Struct.Kind:
if colKind.KindDetails == typing.Struct {
if strings.Contains(fmt.Sprint(colVal), constants.ToastUnavailableValuePlaceholder) {
Expand Down

0 comments on commit efec8fd

Please sign in to comment.