Skip to content

Commit

Permalink
Merge branch 'master' into nv/always-uppercase-escaped-snowflake-names
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Apr 29, 2024
2 parents 9e42fd8 + a8f99c0 commit e01e7c5
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 234 deletions.
2 changes: 1 addition & 1 deletion clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
for _, value := range tableData.Rows() {
data := make(map[string]bigquery.Value)
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) {
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() {
colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col)
colVal, err := castColVal(value[col], colKind, additionalDateFmts)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion clients/mssql/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
}
}()

columns := tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil)
columns := tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate()
stmt, err := tx.Prepare(mssql.CopyIn(tempTableID.FullyQualifiedName(), mssql.BulkOptions{}, columns...))
if err != nil {
return fmt.Errorf("failed to prepare bulk insert: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion clients/redshift/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID
additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
for _, value := range tableData.Rows() {
var row []string
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) {
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() {
colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col)
castedValue, castErr := s.CastColValStaging(value[col], colKind, additionalDateFmts)
if castErr != nil {
Expand Down
2 changes: 1 addition & 1 deletion clients/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (s *Store) Merge(tableData *optimization.TableData) error {
pw.CompressionType = parquet.CompressionCodec_GZIP
for _, val := range tableData.Rows() {
row := make(map[string]any)
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(false, nil) {
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() {
colKind, isOk := tableData.ReadOnlyInMemoryCols().GetColumn(col)
if !isOk {
return fmt.Errorf("expected column: %v to exist in readOnlyInMemoryCols(...) but it does not", col)
Expand Down
6 changes: 2 additions & 4 deletions clients/snowflake/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
// COPY the CSV file (in Snowflake) into a table
copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)",
tempTableID.FullyQualifiedName(),
strings.Join(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{
DestKind: s.Label(),
}), ","),
strings.Join(tableData.ReadOnlyInMemoryCols().GetEscapedColumnsToUpdate(s.ShouldUppercaseEscapedNames(), s.Label()), ","),
escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%"))

if additionalSettings.AdditionalCopyClause != "" {
Expand Down Expand Up @@ -115,7 +113,7 @@ func (s *Store) writeTemporaryTableFile(tableData *optimization.TableData, newTa
additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
for _, value := range tableData.Rows() {
var row []string
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) {
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() {
column, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col)
castedValue, castErr := castColValStaging(value[col], column, additionalDateFmts)
if castErr != nil {
Expand Down
190 changes: 0 additions & 190 deletions lib/config/constants/mssql.go

This file was deleted.

12 changes: 3 additions & 9 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{
DestKind: m.DestKind,
})
cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind)

if m.SoftDelete {
return []string{
Expand Down Expand Up @@ -231,9 +229,7 @@ func (m *MergeArgument) GetStatement() (string, error) {
equalitySQLParts = append(equalitySQLParts, m.AdditionalEqualityStrings...)
}

cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{
DestKind: m.DestKind,
})
cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind)

if m.SoftDelete {
return fmt.Sprintf(`
Expand Down Expand Up @@ -302,9 +298,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) {
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{
DestKind: m.DestKind,
})
cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind)

if m.SoftDelete {
return fmt.Sprintf(`
Expand Down
8 changes: 4 additions & 4 deletions lib/destination/dml/merge_mssql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ func Test_GetMSSQLStatement(t *testing.T) {
mergeSQL, err := mergeArg.GetMSSQLStatement()
assert.NoError(t, err)
assert.Contains(t, mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable), mergeSQL)
assert.NotContains(t, mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at"), fmt.Sprintf("Idempotency key: %s", mergeSQL))
assert.NotContains(t, mergeSQL, fmt.Sprintf(`cc."%s" >= c."%s"`, "updated_at", "updated_at"), fmt.Sprintf("Idempotency key: %s", mergeSQL))
// Check primary keys clause
assert.Contains(t, mergeSQL, "AS cc ON c.id = cc.id", mergeSQL)
assert.Contains(t, mergeSQL, `AS cc ON c."id" = cc."id"`, mergeSQL)

assert.Contains(t, mergeSQL, `SET id=cc.id,bar=cc.bar,updated_at=cc.updated_at,start=cc.start`, mergeSQL)
assert.Contains(t, mergeSQL, `SET "id"=cc."id","bar"=cc."bar","updated_at"=cc."updated_at","start"=cc."start"`, mergeSQL)
assert.Contains(t, mergeSQL, `id,bar,updated_at,start`, mergeSQL)
assert.Contains(t, mergeSQL, `cc.id,cc.bar,cc.updated_at,cc.start`, mergeSQL)
assert.Contains(t, mergeSQL, `cc."id",cc."bar",cc."updated_at",cc."start"`, mergeSQL)
}
2 changes: 1 addition & 1 deletion lib/sql/escape.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func NeedsEscaping(name string, destKind constants.DestinationKind) bool {
if destKind == constants.Redshift {
reservedKeywords = constants.RedshiftReservedKeywords
} else if destKind == constants.MSSQL {
reservedKeywords = constants.MSSQLReservedKeywords
return !strings.HasPrefix(name, constants.ArtiePrefix)
} else {
reservedKeywords = constants.ReservedKeywords
}
Expand Down
34 changes: 24 additions & 10 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ func (c *Column) RawName() string {
return c.name
}

type NameArgs struct {
DestKind constants.DestinationKind
}

// Name will give you c.name
// Plus we will escape it if the column name is part of the reserved words from destinations.
// If so, it'll change from `start` => `"start"` as suggested by Snowflake.
Expand Down Expand Up @@ -179,8 +175,8 @@ func (c *Columns) GetColumn(name string) (Column, bool) {
}

// GetColumnsToUpdate will filter all the `Invalid` columns so that we do not update it.
// It also has an option to escape the returned columns or not. This is used mostly for the SQL MERGE queries.
func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []string {
// This is used mostly for the SQL MERGE queries.
func (c *Columns) GetColumnsToUpdate() []string {
if c == nil {
return []string{}
}
Expand All @@ -194,11 +190,29 @@ func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []s
continue
}

if args == nil {
cols = append(cols, col.RawName())
} else {
cols = append(cols, col.Name(uppercaseEscNames, args.DestKind))
cols = append(cols, col.RawName())
}

return cols
}

// GetEscapedColumnsToUpdate will filter all the `Invalid` columns so that we do not update it.
// It will escape the returned columns.
func (c *Columns) GetEscapedColumnsToUpdate(uppercaseEscNames bool, destKind constants.DestinationKind) []string {
if c == nil {
return []string{}
}

c.RLock()
defer c.RUnlock()

var cols []string
for _, col := range c.columns {
if col.KindDetails == typing.Invalid {
continue
}

cols = append(cols, col.Name(uppercaseEscNames, destKind))
}

return cols
Expand Down
Loading

0 comments on commit e01e7c5

Please sign in to comment.