Skip to content

Commit

Permalink
[typing] Add Columns.GetEscapedColumnsToUpdate method (#507)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathan-artie authored Apr 29, 2024
1 parent e0eab1e commit a8f99c0
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 39 deletions.
2 changes: 1 addition & 1 deletion clients/bigquery/bigquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
for _, value := range tableData.Rows() {
data := make(map[string]bigquery.Value)
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) {
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() {
colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col)
colVal, err := castColVal(value[col], colKind, additionalDateFmts)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion clients/mssql/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
}
}()

columns := tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil)
columns := tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate()
stmt, err := tx.Prepare(mssql.CopyIn(tempTableID.FullyQualifiedName(), mssql.BulkOptions{}, columns...))
if err != nil {
return fmt.Errorf("failed to prepare bulk insert: %w", err)
Expand Down
2 changes: 1 addition & 1 deletion clients/redshift/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID
additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
for _, value := range tableData.Rows() {
var row []string
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) {
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() {
colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col)
castedValue, castErr := s.CastColValStaging(value[col], colKind, additionalDateFmts)
if castErr != nil {
Expand Down
2 changes: 1 addition & 1 deletion clients/s3/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (s *Store) Merge(tableData *optimization.TableData) error {
pw.CompressionType = parquet.CompressionCodec_GZIP
for _, val := range tableData.Rows() {
row := make(map[string]any)
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(false, nil) {
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() {
colKind, isOk := tableData.ReadOnlyInMemoryCols().GetColumn(col)
if !isOk {
return fmt.Errorf("expected column: %v to exist in readOnlyInMemoryCols(...) but it does not", col)
Expand Down
6 changes: 2 additions & 4 deletions clients/snowflake/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo
// COPY the CSV file (in Snowflake) into a table
copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)",
tempTableID.FullyQualifiedName(),
strings.Join(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{
DestKind: s.Label(),
}), ","),
strings.Join(tableData.ReadOnlyInMemoryCols().GetEscapedColumnsToUpdate(s.ShouldUppercaseEscapedNames(), s.Label()), ","),
escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%"))

if additionalSettings.AdditionalCopyClause != "" {
Expand Down Expand Up @@ -115,7 +113,7 @@ func (s *Store) writeTemporaryTableFile(tableData *optimization.TableData, newTa
additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats
for _, value := range tableData.Rows() {
var row []string
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) {
for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() {
column, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col)
castedValue, castErr := castColValStaging(value[col], column, additionalDateFmts)
if castErr != nil {
Expand Down
12 changes: 3 additions & 9 deletions lib/destination/dml/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ func (m *MergeArgument) GetParts() ([]string, error) {
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{
DestKind: m.DestKind,
})
cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind)

if m.SoftDelete {
return []string{
Expand Down Expand Up @@ -231,9 +229,7 @@ func (m *MergeArgument) GetStatement() (string, error) {
equalitySQLParts = append(equalitySQLParts, m.AdditionalEqualityStrings...)
}

cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{
DestKind: m.DestKind,
})
cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind)

if m.SoftDelete {
return fmt.Sprintf(`
Expand Down Expand Up @@ -302,9 +298,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) {
equalitySQLParts = append(equalitySQLParts, equalitySQL)
}

cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{
DestKind: m.DestKind,
})
cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind)

if m.SoftDelete {
return fmt.Sprintf(`
Expand Down
34 changes: 24 additions & 10 deletions lib/typing/columns/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ func (c *Column) RawName() string {
return c.name
}

type NameArgs struct {
DestKind constants.DestinationKind
}

// Name will give you c.name
// Plus we will escape it if the column name is part of the reserved words from destinations.
// If so, it'll change from `start` => `"start"` as suggested by Snowflake.
Expand Down Expand Up @@ -179,8 +175,8 @@ func (c *Columns) GetColumn(name string) (Column, bool) {
}

// GetColumnsToUpdate will filter all the `Invalid` columns so that we do not update it.
// It also has an option to escape the returned columns or not. This is used mostly for the SQL MERGE queries.
func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []string {
// This is used mostly for the SQL MERGE queries.
func (c *Columns) GetColumnsToUpdate() []string {
if c == nil {
return []string{}
}
Expand All @@ -194,11 +190,29 @@ func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []s
continue
}

if args == nil {
cols = append(cols, col.RawName())
} else {
cols = append(cols, col.Name(uppercaseEscNames, args.DestKind))
cols = append(cols, col.RawName())
}

return cols
}

// GetEscapedColumnsToUpdate will filter all the `Invalid` columns so that we do not update it.
// It will escape the returned columns.
func (c *Columns) GetEscapedColumnsToUpdate(uppercaseEscNames bool, destKind constants.DestinationKind) []string {
if c == nil {
return []string{}
}

c.RLock()
defer c.RUnlock()

var cols []string
for _, col := range c.columns {
if col.KindDetails == typing.Invalid {
continue
}

cols = append(cols, col.Name(uppercaseEscNames, destKind))
}

return cols
Expand Down
68 changes: 56 additions & 12 deletions lib/typing/columns/columns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,63 @@ func TestColumn_Name(t *testing.T) {
}

func TestColumns_GetColumnsToUpdate(t *testing.T) {
type _testCase struct {
name string
cols []Column
expectedCols []string
}

var (
happyPathCols = []Column{
{
name: "hi",
KindDetails: typing.String,
},
{
name: "bye",
KindDetails: typing.String,
},
{
name: "start",
KindDetails: typing.String,
},
}
)

extraCols := happyPathCols
for i := 0; i < 100; i++ {
extraCols = append(extraCols, Column{
name: fmt.Sprintf("hello_%v", i),
KindDetails: typing.Invalid,
})
}

testCases := []_testCase{
{
name: "happy path",
cols: happyPathCols,
expectedCols: []string{"hi", "bye", "start"},
},
{
name: "happy path + extra col",
cols: extraCols,
expectedCols: []string{"hi", "bye", "start"},
},
}

for _, testCase := range testCases {
cols := &Columns{
columns: testCase.cols,
}

assert.Equal(t, testCase.expectedCols, cols.GetColumnsToUpdate(), testCase.name)
}
}

func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) {
type _testCase struct {
name string
cols []Column
expectedCols []string
expectedColsEsc []string
expectedColsEscBq []string
}
Expand Down Expand Up @@ -212,14 +265,12 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) {
{
name: "happy path",
cols: happyPathCols,
expectedCols: []string{"hi", "bye", "start"},
expectedColsEsc: []string{"hi", "bye", `"start"`},
expectedColsEscBq: []string{"hi", "bye", "`start`"},
},
{
name: "happy path + extra col",
cols: extraCols,
expectedCols: []string{"hi", "bye", "start"},
expectedColsEsc: []string{"hi", "bye", `"start"`},
expectedColsEscBq: []string{"hi", "bye", "`start`"},
},
Expand All @@ -230,15 +281,8 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) {
columns: testCase.cols,
}

assert.Equal(t, testCase.expectedCols, cols.GetColumnsToUpdate(false, nil), testCase.name)

assert.Equal(t, testCase.expectedColsEsc, cols.GetColumnsToUpdate(false, &NameArgs{
DestKind: constants.Snowflake,
}), testCase.name)

assert.Equal(t, testCase.expectedColsEscBq, cols.GetColumnsToUpdate(false, &NameArgs{
DestKind: constants.BigQuery,
}), testCase.name)
assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(false, constants.Snowflake), testCase.name)
assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(false, constants.BigQuery), testCase.name)
}
}

Expand Down

0 comments on commit a8f99c0

Please sign in to comment.