From 415319168ed070283228491683db5067620ebf56 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Wed, 1 May 2024 17:41:37 -0700 Subject: [PATCH] [sql] Add `Dialect.EscapeStruct` method (#537) --- lib/sql/dialect.go | 18 ++++++++++++++++++ lib/typing/columns/default.go | 12 +----------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/lib/sql/dialect.go b/lib/sql/dialect.go index 9a7fcfa22..ae39b254e 100644 --- a/lib/sql/dialect.go +++ b/lib/sql/dialect.go @@ -8,11 +8,13 @@ import ( "strings" "github.com/artie-labs/transfer/lib/config/constants" + "github.com/artie-labs/transfer/lib/stringutil" ) type Dialect interface { NeedsEscaping(identifier string) bool // TODO: Remove this when we escape everything QuoteIdentifier(identifier string) string + EscapeStruct(value any) string } type BigQueryDialect struct{} @@ -24,6 +26,10 @@ func (BigQueryDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf("`%s`", identifier) } +func (BigQueryDialect) EscapeStruct(value any) string { + return "JSON" + stringutil.Wrap(value, false) +} + type MSSQLDialect struct{} func (MSSQLDialect) NeedsEscaping(_ string) bool { return true } @@ -32,6 +38,10 @@ func (MSSQLDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, identifier) } +func (MSSQLDialect) EscapeStruct(value any) string { + panic("not implemented") // We don't currently support backfills for MS SQL. +} + type RedshiftDialect struct{} func (RedshiftDialect) NeedsEscaping(_ string) bool { return true } @@ -41,6 +51,10 @@ func (rd RedshiftDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, strings.ToLower(identifier)) } +func (RedshiftDialect) EscapeStruct(value any) string { + return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(value, false)) +} + type SnowflakeDialect struct { UppercaseEscNames bool } @@ -75,3 +89,7 @@ func (sd SnowflakeDialect) QuoteIdentifier(identifier string) string { return fmt.Sprintf(`"%s"`, identifier) } + +func (SnowflakeDialect) EscapeStruct(value any) string { + return stringutil.Wrap(value, false) +} diff --git a/lib/typing/columns/default.go b/lib/typing/columns/default.go index da058a56e..a5155777d 100644 --- a/lib/typing/columns/default.go +++ b/lib/typing/columns/default.go @@ -23,17 +23,7 @@ func (c *Column) DefaultValue(dialect sql.Dialect, additionalDateFmts []string) switch c.KindDetails.Kind { case typing.Struct.Kind, typing.Array.Kind: - switch dialect.(type) { - case sql.BigQueryDialect: - return "JSON" + stringutil.Wrap(c.defaultValue, false), nil - case sql.RedshiftDialect: - return fmt.Sprintf("JSON_PARSE(%s)", stringutil.Wrap(c.defaultValue, false)), nil - case sql.SnowflakeDialect: - return stringutil.Wrap(c.defaultValue, false), nil - default: - // Note that we don't currently support backfills for MS SQL. - return nil, fmt.Errorf("not implemented for %v dialect", dialect) - } + return dialect.EscapeStruct(c.defaultValue), nil case typing.ETime.Kind: if c.KindDetails.ExtendedTimeDetails == nil { return nil, fmt.Errorf("column kind details for extended time is nil")