From a2f48aeda8a6b3292c54ba36416dacbcb51dc6f6 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Sun, 28 Apr 2024 21:33:33 -0700 Subject: [PATCH 1/3] [mssql] Escape all columns that do not start with "__artie" (#501) --- lib/config/constants/mssql.go | 190 ------------------------ lib/destination/dml/merge_mssql_test.go | 6 +- lib/sql/escape.go | 2 +- 3 files changed, 4 insertions(+), 194 deletions(-) delete mode 100644 lib/config/constants/mssql.go diff --git a/lib/config/constants/mssql.go b/lib/config/constants/mssql.go deleted file mode 100644 index c3f8ca62f..000000000 --- a/lib/config/constants/mssql.go +++ /dev/null @@ -1,190 +0,0 @@ -package constants - -// MSSQLReservedKeywords https://learn.microsoft.com/en-us/sql/t-sql/language-elements/reserved-keywords-transact-sql?view=sql-server-ver16 -var MSSQLReservedKeywords = []string{ - "add", - "external", - "procedure", - "all", - "fetch", - "public", - "alter", - "file", - "raiserror", - "and", - "fillfactor", - "read", - "any", - "for", - "readtext", - "as", - "foreign", - "reconfigure", - "asc", - "freetext", - "references", - "authorization", - "freetexttable", - "replication", - "backup", - "from", - "restore", - "begin", - "full", - "restrict", - "between", - "function", - "return", - "break", - "goto", - "revert", - "browse", - "grant", - "revoke", - "bulk", - "group", - "right", - "by", - "having", - "rollback", - "cascade", - "holdlock", - "rowcount", - "case", - "identity", - "rowguidcol", - "check", - "identity_insert", - "rule", - "checkpoint", - "identitycol", - "save", - "close", - "if", - "schema", - "clustered", - "in", - "securityaudit", - "coalesce", - "index", - "select", - "collate", - "inner", - "semantickeyphrasetable", - "column", - "insert", - "semanticsimilaritydetailstable", - "commit", - "intersect", - "semanticsimilaritytable", - "compute", - "into", - "session_user", - "constraint", - "is", - "set", - "contains", - "join", - "setuser", - "containstable", - "key", - "shutdown", - "continue", - "kill", - "some", - "convert", - "left", - "statistics", - "create", - "like", - "system_user", - "cross", - "lineno", - "table", - "current", - "load", - "tablesample", - "current_date", - "merge", - "textsize", - "current_time", - "national", - "then", - "current_timestamp", - "nocheck", - "to", - "current_user", - "nonclustered", - "top", - "cursor", - "not", - "tran", - "database", - "null", - "transaction", - "dbcc", - "nullif", - "trigger", - "deallocate", - "of", - "truncate", - "declare", - "off", - "try_convert", - "default", - "offsets", - "tsequal", - "delete", - "on", - "union", - "deny", - "open", - "unique", - "desc", - "opendatasource", - "unpivot", - "disk", - "openquery", - "update", - "distinct", - "openrowset", - "updatetext", - "distributed", - "openxml", - "use", - "double", - "option", - "user", - "drop", - "or", - "values", - "dump", - "order", - "varying", - "else", - "outer", - "view", - "end", - "over", - "waitfor", - "errlvl", - "percent", - "when", - "escape", - "pivot", - "where", - "except", - "plan", - "while", - "exec", - "precision", - "with", - "execute", - "primary", - "within group", - "exists", - "print", - "writetext", - "exit", - "proc", -} diff --git a/lib/destination/dml/merge_mssql_test.go b/lib/destination/dml/merge_mssql_test.go index cd3f60a1c..7a254dd38 100644 --- a/lib/destination/dml/merge_mssql_test.go +++ b/lib/destination/dml/merge_mssql_test.go @@ -54,11 +54,11 @@ func Test_GetMSSQLStatement(t *testing.T) { mergeSQL, err := mergeArg.GetMSSQLStatement() assert.NoError(t, err) assert.Contains(t, mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable), mergeSQL) - assert.NotContains(t, mergeSQL, fmt.Sprintf("cc.%s >= c.%s", "updated_at", "updated_at"), fmt.Sprintf("Idempotency key: %s", mergeSQL)) + assert.NotContains(t, mergeSQL, fmt.Sprintf(`cc."%s" >= c."%s"`, "updated_at", "updated_at"), fmt.Sprintf("Idempotency key: %s", mergeSQL)) // Check primary keys clause assert.Contains(t, mergeSQL, "AS cc ON c.id = cc.id", mergeSQL) - assert.Contains(t, mergeSQL, `SET id=cc.id,bar=cc.bar,updated_at=cc.updated_at,start=cc.start`, mergeSQL) + assert.Contains(t, mergeSQL, `SET "id"=cc."id","bar"=cc."bar","updated_at"=cc."updated_at","start"=cc."start"`, mergeSQL) assert.Contains(t, mergeSQL, `id,bar,updated_at,start`, mergeSQL) - assert.Contains(t, mergeSQL, `cc.id,cc.bar,cc.updated_at,cc.start`, mergeSQL) + assert.Contains(t, mergeSQL, `cc."id",cc."bar",cc."updated_at",cc."start"`, mergeSQL) } diff --git a/lib/sql/escape.go b/lib/sql/escape.go index 3e7061103..180f303dc 100644 --- a/lib/sql/escape.go +++ b/lib/sql/escape.go @@ -25,7 +25,7 @@ func NeedsEscaping(name string, destKind constants.DestinationKind) bool { if destKind == constants.Redshift { reservedKeywords = constants.RedshiftReservedKeywords } else if destKind == constants.MSSQL { - reservedKeywords = constants.MSSQLReservedKeywords + return !strings.HasPrefix(name, constants.ArtiePrefix) } else { reservedKeywords = constants.ReservedKeywords } From e0eab1eb8515ff2a6085d5ac2825a50340c67ce6 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Sun, 28 Apr 2024 22:04:13 -0700 Subject: [PATCH 2/3] [mssql] Fix test (#509) --- lib/destination/dml/merge_mssql_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/destination/dml/merge_mssql_test.go b/lib/destination/dml/merge_mssql_test.go index 7a254dd38..385d4ace5 100644 --- a/lib/destination/dml/merge_mssql_test.go +++ b/lib/destination/dml/merge_mssql_test.go @@ -56,7 +56,7 @@ func Test_GetMSSQLStatement(t *testing.T) { assert.Contains(t, mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable), mergeSQL) assert.NotContains(t, mergeSQL, fmt.Sprintf(`cc."%s" >= c."%s"`, "updated_at", "updated_at"), fmt.Sprintf("Idempotency key: %s", mergeSQL)) // Check primary keys clause - assert.Contains(t, mergeSQL, "AS cc ON c.id = cc.id", mergeSQL) + assert.Contains(t, mergeSQL, `AS cc ON c."id" = cc."id"`, mergeSQL) assert.Contains(t, mergeSQL, `SET "id"=cc."id","bar"=cc."bar","updated_at"=cc."updated_at","start"=cc."start"`, mergeSQL) assert.Contains(t, mergeSQL, `id,bar,updated_at,start`, mergeSQL) From a8f99c063ad604ef081901dcd01f9b0b21104b7d Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Sun, 28 Apr 2024 22:07:26 -0700 Subject: [PATCH 3/3] [typing] Add `Columns.GetEscapedColumnsToUpdate` method (#507) --- clients/bigquery/bigquery.go | 2 +- clients/mssql/staging.go | 2 +- clients/redshift/staging.go | 2 +- clients/s3/s3.go | 2 +- clients/snowflake/staging.go | 6 +-- lib/destination/dml/merge.go | 12 ++---- lib/typing/columns/columns.go | 34 ++++++++++----- lib/typing/columns/columns_test.go | 68 ++++++++++++++++++++++++------ 8 files changed, 89 insertions(+), 39 deletions(-) diff --git a/clients/bigquery/bigquery.go b/clients/bigquery/bigquery.go index 2f983a3a9..3c649f56d 100644 --- a/clients/bigquery/bigquery.go +++ b/clients/bigquery/bigquery.go @@ -59,7 +59,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats for _, value := range tableData.Rows() { data := make(map[string]bigquery.Value) - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() { colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) colVal, err := castColVal(value[col], colKind, additionalDateFmts) if err != nil { diff --git a/clients/mssql/staging.go b/clients/mssql/staging.go index 9daf8faa1..f373fbfb6 100644 --- a/clients/mssql/staging.go +++ b/clients/mssql/staging.go @@ -42,7 +42,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo } }() - columns := tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) + columns := tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() stmt, err := tx.Prepare(mssql.CopyIn(tempTableID.FullyQualifiedName(), mssql.BulkOptions{}, columns...)) if err != nil { return fmt.Errorf("failed to prepare bulk insert: %w", err) diff --git a/clients/redshift/staging.go b/clients/redshift/staging.go index c7155ae47..bb6614baa 100644 --- a/clients/redshift/staging.go +++ b/clients/redshift/staging.go @@ -90,7 +90,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats for _, value := range tableData.Rows() { var row []string - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() { colKind, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) castedValue, castErr := s.CastColValStaging(value[col], colKind, additionalDateFmts) if castErr != nil { diff --git a/clients/s3/s3.go b/clients/s3/s3.go index 6e85c429f..688d14a7e 100644 --- a/clients/s3/s3.go +++ b/clients/s3/s3.go @@ -113,7 +113,7 @@ func (s *Store) Merge(tableData *optimization.TableData) error { pw.CompressionType = parquet.CompressionCodec_GZIP for _, val := range tableData.Rows() { row := make(map[string]any) - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(false, nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() { colKind, isOk := tableData.ReadOnlyInMemoryCols().GetColumn(col) if !isOk { return fmt.Errorf("expected column: %v to exist in readOnlyInMemoryCols(...) but it does not", col) diff --git a/clients/snowflake/staging.go b/clients/snowflake/staging.go index e454723a4..91e07e286 100644 --- a/clients/snowflake/staging.go +++ b/clients/snowflake/staging.go @@ -85,9 +85,7 @@ func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableCo // COPY the CSV file (in Snowflake) into a table copyCommand := fmt.Sprintf("COPY INTO %s (%s) FROM (SELECT %s FROM @%s)", tempTableID.FullyQualifiedName(), - strings.Join(tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), &columns.NameArgs{ - DestKind: s.Label(), - }), ","), + strings.Join(tableData.ReadOnlyInMemoryCols().GetEscapedColumnsToUpdate(s.ShouldUppercaseEscapedNames(), s.Label()), ","), escapeColumns(tableData.ReadOnlyInMemoryCols(), ","), addPrefixToTableName(tempTableID, "%")) if additionalSettings.AdditionalCopyClause != "" { @@ -115,7 +113,7 @@ func (s *Store) writeTemporaryTableFile(tableData *optimization.TableData, newTa additionalDateFmts := s.config.SharedTransferConfig.TypingSettings.AdditionalDateFormats for _, value := range tableData.Rows() { var row []string - for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate(s.ShouldUppercaseEscapedNames(), nil) { + for _, col := range tableData.ReadOnlyInMemoryCols().GetColumnsToUpdate() { column, _ := tableData.ReadOnlyInMemoryCols().GetColumn(col) castedValue, castErr := castColValStaging(value[col], column, additionalDateFmts) if castErr != nil { diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index 0220e17d3..0a551beb0 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -97,9 +97,7 @@ func (m *MergeArgument) GetParts() ([]string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{ - DestKind: m.DestKind, - }) + cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) if m.SoftDelete { return []string{ @@ -231,9 +229,7 @@ func (m *MergeArgument) GetStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, m.AdditionalEqualityStrings...) } - cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{ - DestKind: m.DestKind, - }) + cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) if m.SoftDelete { return fmt.Sprintf(` @@ -302,9 +298,7 @@ func (m *MergeArgument) GetMSSQLStatement() (string, error) { equalitySQLParts = append(equalitySQLParts, equalitySQL) } - cols := m.Columns.GetColumnsToUpdate(*m.UppercaseEscNames, &columns.NameArgs{ - DestKind: m.DestKind, - }) + cols := m.Columns.GetEscapedColumnsToUpdate(*m.UppercaseEscNames, m.DestKind) if m.SoftDelete { return fmt.Sprintf(` diff --git a/lib/typing/columns/columns.go b/lib/typing/columns/columns.go index abd6ea4a7..2ed85ff2f 100644 --- a/lib/typing/columns/columns.go +++ b/lib/typing/columns/columns.go @@ -83,10 +83,6 @@ func (c *Column) RawName() string { return c.name } -type NameArgs struct { - DestKind constants.DestinationKind -} - // Name will give you c.name // Plus we will escape it if the column name is part of the reserved words from destinations. // If so, it'll change from `start` => `"start"` as suggested by Snowflake. @@ -179,8 +175,8 @@ func (c *Columns) GetColumn(name string) (Column, bool) { } // GetColumnsToUpdate will filter all the `Invalid` columns so that we do not update it. -// It also has an option to escape the returned columns or not. This is used mostly for the SQL MERGE queries. -func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []string { +// This is used mostly for the SQL MERGE queries. +func (c *Columns) GetColumnsToUpdate() []string { if c == nil { return []string{} } @@ -194,11 +190,29 @@ func (c *Columns) GetColumnsToUpdate(uppercaseEscNames bool, args *NameArgs) []s continue } - if args == nil { - cols = append(cols, col.RawName()) - } else { - cols = append(cols, col.Name(uppercaseEscNames, args.DestKind)) + cols = append(cols, col.RawName()) + } + + return cols +} + +// GetEscapedColumnsToUpdate will filter all the `Invalid` columns so that we do not update it. +// It will escape the returned columns. +func (c *Columns) GetEscapedColumnsToUpdate(uppercaseEscNames bool, destKind constants.DestinationKind) []string { + if c == nil { + return []string{} + } + + c.RLock() + defer c.RUnlock() + + var cols []string + for _, col := range c.columns { + if col.KindDetails == typing.Invalid { + continue } + + cols = append(cols, col.Name(uppercaseEscNames, destKind)) } return cols diff --git a/lib/typing/columns/columns_test.go b/lib/typing/columns/columns_test.go index a67f09dab..8a2ef5b91 100644 --- a/lib/typing/columns/columns_test.go +++ b/lib/typing/columns/columns_test.go @@ -175,10 +175,63 @@ func TestColumn_Name(t *testing.T) { } func TestColumns_GetColumnsToUpdate(t *testing.T) { + type _testCase struct { + name string + cols []Column + expectedCols []string + } + + var ( + happyPathCols = []Column{ + { + name: "hi", + KindDetails: typing.String, + }, + { + name: "bye", + KindDetails: typing.String, + }, + { + name: "start", + KindDetails: typing.String, + }, + } + ) + + extraCols := happyPathCols + for i := 0; i < 100; i++ { + extraCols = append(extraCols, Column{ + name: fmt.Sprintf("hello_%v", i), + KindDetails: typing.Invalid, + }) + } + + testCases := []_testCase{ + { + name: "happy path", + cols: happyPathCols, + expectedCols: []string{"hi", "bye", "start"}, + }, + { + name: "happy path + extra col", + cols: extraCols, + expectedCols: []string{"hi", "bye", "start"}, + }, + } + + for _, testCase := range testCases { + cols := &Columns{ + columns: testCase.cols, + } + + assert.Equal(t, testCase.expectedCols, cols.GetColumnsToUpdate(), testCase.name) + } +} + +func TestColumns_GetEscapedColumnsToUpdate(t *testing.T) { type _testCase struct { name string cols []Column - expectedCols []string expectedColsEsc []string expectedColsEscBq []string } @@ -212,14 +265,12 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) { { name: "happy path", cols: happyPathCols, - expectedCols: []string{"hi", "bye", "start"}, expectedColsEsc: []string{"hi", "bye", `"start"`}, expectedColsEscBq: []string{"hi", "bye", "`start`"}, }, { name: "happy path + extra col", cols: extraCols, - expectedCols: []string{"hi", "bye", "start"}, expectedColsEsc: []string{"hi", "bye", `"start"`}, expectedColsEscBq: []string{"hi", "bye", "`start`"}, }, @@ -230,15 +281,8 @@ func TestColumns_GetColumnsToUpdate(t *testing.T) { columns: testCase.cols, } - assert.Equal(t, testCase.expectedCols, cols.GetColumnsToUpdate(false, nil), testCase.name) - - assert.Equal(t, testCase.expectedColsEsc, cols.GetColumnsToUpdate(false, &NameArgs{ - DestKind: constants.Snowflake, - }), testCase.name) - - assert.Equal(t, testCase.expectedColsEscBq, cols.GetColumnsToUpdate(false, &NameArgs{ - DestKind: constants.BigQuery, - }), testCase.name) + assert.Equal(t, testCase.expectedColsEsc, cols.GetEscapedColumnsToUpdate(false, constants.Snowflake), testCase.name) + assert.Equal(t, testCase.expectedColsEscBq, cols.GetEscapedColumnsToUpdate(false, constants.BigQuery), testCase.name) } }