From a8f99c063ad604ef081901dcd01f9b0b21104b7d Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Sun, 28 Apr 2024 22:07:26 -0700 Subject: [PATCH] [typing] Add `Columns.GetEscapedColumnsToUpdate` method (#507) --- clients/bigquery/bigquery.go | 2 +- clients/mssql/staging.go | 2 +- clients/redshift/staging.go | 2 +- clients/s3/s3.go | 2 +- clients/snowflake/staging.go | 6 +-- lib/destination/dml/merge.go | 12 ++---- lib/typing/columns/columns.go | 34 ++++++++++----- lib/typing/columns/columns_test.go | 68 ++++++++++++++++++++++++------ 8 files changed, 89 insertions(+), 39 deletions(-) diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index 2f983a3a9..3c649f56d 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -59,7 +59,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats for _, value := range tableData.Rows() { data := make(map[string]bigquery.Value) - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() { colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) colVal, err := castColVal(value[col], colKind, additionalDateFmts) if err != nil { diff --git a/clients/mssql/staging.go b/clients/mssql/staging.go index 9daf8faa1..f373fbfb6 100644 --- a/clients/mssql/staging.go +++ b/clients/mssql/staging.go @@ -42,7 +42,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo } }() - columns := tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) + columns := tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() stmt, err := tx.Prepare(mssql.CopyIn(tempTableID.FullyQualifiedName(), mssql.BulkOptions{}, columns...)) if err != nil { return fmt.Errorf("failed to prepare bulk insert: %w", err) diff --git a/clients/redshift/staging.go b/clients/redshift/staging.go index c7155ae47..bb6614baa 100644 --- a/clients/redshift/staging.go +++ b/clients/redshift/staging.go @@ -90,7 +90,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats for _, value := range tableData.Rows() { var row []string - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() { colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) castedValue, castErr := s.CastColValStaging(value[col], colKind, additionalDateFmts) if castErr != nil { diff --git a/clients/s3/s3.go b/clients/s3/s3.go index 6e85c429f..688d14a7e 100644 --- a/clients/s3/s3.go +++ b/clients/s3/s3.go @@ -113,7 +113,7 @@ func (s *Store) Merge(tableData *optimization.TableData) error { pw.CompressionType = parquet.CompressionCodec_GZIP for _, val := range tableData.Rows() { row := make(map[string]any) - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(false, nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() { colKind, isOk := tableData.ReadOnlyInMemoryCols().GetColumn(col) if !isOk { return fmt.Errorf("expected column: %v to exist in readOnlyInMemoryCols(...) but it does not", col) diff --git a/clients/snowflake/staging.go b/clients/snowflake/staging.go index e454723a4..91e07e286 100644 --- a/clients/snowflake/staging.go +++ b/clients/snowflake/staging.go @@ -85,9 +85,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo // COPY the CSV file (in Snowflake) into a table copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)", tempTableID.FullyQualifiedName(), - strings.Join(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{ - DestKind: s.Label(), - }), ","), + strings.Join(tableData.ReadOnlyInMemoryCols().GetEscapedColumnsToUpdate(s.ShouldUppercaseEscapedNames(), s.Label()), ","), escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%")) if additionalSettings.AdditionalCopyClause != "" { @@ -115,7 +113,7 @@ func (s *Store) writeTemporaryTableFile(tableData *optimization.TableData, newTa additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats for _, value := range tableData.Rows() { var row []string - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() { column, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) castedValue, castErr := castColValStaging(value[col], column, additionalDateFmts) if castErr != nil { diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index 0220e17d3..0a551beb0 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -97,9 +97,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{ - DestKind: m.DestKind, - }) + cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) if m.SoftDelete { return []string{ @@ -231,9 +229,7 @@ func (m *MergeArgument) GetStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, m.AdditionalEqualityStrings...) } - cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{ - DestKind: m.DestKind, - }) + cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) if m.SoftDelete { return fmt.Sprintf(` @@ -302,9 +298,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{ - DestKind: m.DestKind, - }) + cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) if m.SoftDelete { return fmt.Sprintf(` diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index abd6ea4a7..2ed85ff2f 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -83,10 +83,6 @@ func (c *Column) RawName() string { return c.name } -type NameArgs struct { - DestKind constants.DestinationKind -} - // Name will give you c.name // Plus we will escape it if the column name is part of the reserved words from destinations. // If so, it'll change from `start` => `"start"` as suggested by Snowflake. @@ -179,8 +175,8 @@ func (c *Columns) GetColumn(name string) (Column, bool) { } // GetColumnsToUpdate will filter all the `Invalid` columns so that we do not update it. -// It also has an option to escape the returned columns or not. This is used mostly for the SQL MERGE queries. -func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []string { +// This is used mostly for the SQL MERGE queries. +func (c *Columns) GetColumnsToUpdate() []string { if c == nil { return []string{} } @@ -194,11 +190,29 @@ func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []s continue } - if args == nil { - cols = append(cols, col.RawName()) - } else { - cols = append(cols, col.Name(uppercaseEscNames, args.DestKind)) + cols = append(cols, col.RawName()) + } + + return cols +} + +// GetEscapedColumnsToUpdate will filter all the `Invalid` columns so that we do not update it. +// It will escape the returned columns. +func (c *Columns) GetEscapedColumnsToUpdate(uppercaseEscNames bool, destKind constants.DestinationKind) []string { + if c == nil { + return []string{} + } + + c.RLock() + defer c.RUnlock() + + var cols []string + for _, col := range c.columns { + if col.KindDetails == typing.Invalid { + continue } + + cols = append(cols, col.Name(uppercaseEscNames, destKind)) } return cols diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index a67f09dab..8a2ef5b91 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -175,10 +175,63 @@ func TestColumn_Name(t *testing.T) { } func TestColumns_GetColumnsToUpdate(t *testing.T) { + type _testCase struct { + name string + cols []Column + expectedCols []string + } + + var ( + happyPathCols = []Column{ + { + name: "hi", + KindDetails: typing.String, + }, + { + name: "bye", + KindDetails: typing.String, + }, + { + name: "start", + KindDetails: typing.String, + }, + } + ) + + extraCols := happyPathCols + for i := 0; i < 100; i++ { + extraCols = append(extraCols, Column{ + name: fmt.Sprintf("hello_%v", i), + KindDetails: typing.Invalid, + }) + } + + testCases := []_testCase{ + { + name: "happy path", + cols: happyPathCols, + expectedCols: []string{"hi", "bye", "start"}, + }, + { + name: "happy path + extra col", + cols: extraCols, + expectedCols: []string{"hi", "bye", "start"}, + }, + } + + for _, testCase := range testCases { + cols := &Columns{ + columns: testCase.cols, + } + + assert.Equal(t, testCase.expectedCols, cols.GetColumnsToUpdate(), testCase.name) + } +} + +func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) { type _testCase struct { name string cols []Column - expectedCols []string expectedColsEsc []string expectedColsEscBq []string } @@ -212,14 +265,12 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) { { name: "happy path", cols: happyPathCols, - expectedCols: []string{"hi", "bye", "start"}, expectedColsEsc: []string{"hi", "bye", `"start"`}, expectedColsEscBq: []string{"hi", "bye", "`start`"}, }, { name: "happy path + extra col", cols: extraCols, - expectedCols: []string{"hi", "bye", "start"}, expectedColsEsc: []string{"hi", "bye", `"start"`}, expectedColsEscBq: []string{"hi", "bye", "`start`"}, }, @@ -230,15 +281,8 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) { columns: testCase.cols, } - assert.Equal(t, testCase.expectedCols, cols.GetColumnsToUpdate(false, nil), testCase.name) - - assert.Equal(t, testCase.expectedColsEsc, cols.GetColumnsToUpdate(false, &NameArgs{ - DestKind: constants.Snowflake, - }), testCase.name) - - assert.Equal(t, testCase.expectedColsEscBq, cols.GetColumnsToUpdate(false, &NameArgs{ - DestKind: constants.BigQuery, - }), testCase.name) + assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(false, constants.Snowflake), testCase.name) + assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(false, constants.BigQuery), testCase.name) } }