diff --git a/clients/redshift/cast.go b/clients/redshift/cast.go index 3e4920ea2..5383ce008 100644 --- a/clients/redshift/cast.go +++ b/clients/redshift/cast.go @@ -5,25 +5,22 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/typing" - "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/values" ) const maxRedshiftLength int32 = 65535 -// replaceExceededValues - takes `colVal` any and `colKind` columns.Column and replaces the value with an empty string if it exceeds the max length. -// This currently only works for STRING and SUPER data types. -func replaceExceededValues(colVal string, colKind columns.Column) string { - structOrString := colKind.KindDetails.Kind == typing.Struct.Kind || colKind.KindDetails.Kind == typing.String.Kind +func replaceExceededValues(colVal string, colKind typing.KindDetails) string { + structOrString := colKind.Kind == typing.Struct.Kind || colKind.Kind == typing.String.Kind if structOrString { maxLength := maxRedshiftLength // If the customer has specified the maximum string precision, let's use that as the max length. - if colKind.KindDetails.OptionalStringPrecision != nil { - maxLength = *colKind.KindDetails.OptionalStringPrecision + if colKind.OptionalStringPrecision != nil { + maxLength = *colKind.OptionalStringPrecision } if shouldReplace := int32(len(colVal)) > maxLength; shouldReplace { - if colKind.KindDetails.Kind == typing.Struct.Kind { + if colKind.Kind == typing.Struct.Kind { return fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker) } @@ -34,11 +31,9 @@ func replaceExceededValues(colVal string, colKind columns.Column) string { return colVal } -// CastColValStaging - takes `colVal` any and `colKind` typing.Column and converts the value into a string value -// This is necessary because CSV writers require values to in `string`. -func (s *Store) CastColValStaging(colVal any, colKind columns.Column, additionalDateFmts []string) (string, error) { +func castColValStaging(colVal any, colKind typing.KindDetails, additionalDateFmts []string) (string, error) { if colVal == nil { - if colKind.KindDetails == typing.Struct { + if colKind == typing.Struct { // Returning empty here because if it's a struct, it will go through JSON PARSE and JSON_PARSE("") = null return "", nil } diff --git a/clients/redshift/cast_test.go b/clients/redshift/cast_test.go index 75bf2d8df..fdab5d974 100644 --- a/clients/redshift/cast_test.go +++ b/clients/redshift/cast_test.go @@ -2,18 +2,11 @@ package redshift import ( "fmt" - "testing" "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/stringutil" - "github.com/artie-labs/transfer/lib/db" - - "github.com/artie-labs/transfer/lib/config" - - "github.com/artie-labs/transfer/lib/typing/columns" - "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/typing" @@ -21,136 +14,53 @@ import ( ) func (r *RedshiftTestSuite) TestReplaceExceededValues() { - type _tc struct { - name string - colVal string - colKind columns.Column - expectedResult string + { + // Masked, reached the DDL limit + assert.Equal(r.T(), constants.ExceededValueMarker, replaceExceededValues(stringutil.Random(int(maxRedshiftLength)+1), typing.String)) } - - tcs := []_tc{ - { - name: "string", - colVal: stringutil.Random(int(maxRedshiftLength) + 1), - colKind: columns.Column{ - KindDetails: typing.String, - }, - expectedResult: constants.ExceededValueMarker, - }, - { - name: "string (specified string precision)", - colVal: "hello dusty", - colKind: columns.Column{ - KindDetails: typing.KindDetails{ - Kind: typing.String.Kind, - OptionalStringPrecision: ptr.ToInt32(3), - }, - }, - expectedResult: constants.ExceededValueMarker, - }, - { - name: "string - not masked", - colVal: "thisissuperlongbutnotlongenoughtogetmasked", - colKind: columns.Column{ - KindDetails: typing.String, - }, - expectedResult: "thisissuperlongbutnotlongenoughtogetmasked", - }, - { - name: "struct", - colVal: fmt.Sprintf(`{"foo": "%s"}`, stringutil.Random(int(maxRedshiftLength)+1)), - colKind: columns.Column{ - KindDetails: typing.Struct, - }, - expectedResult: fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker), - }, - { - name: "string, but the data type is a SUPER", - colVal: stringutil.Random(int(maxRedshiftLength) + 1), - colKind: columns.Column{ - KindDetails: typing.Struct, - }, - expectedResult: fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker), - }, - { - name: "struct - not masked", - colVal: `{"foo": "bar"}`, - colKind: columns.Column{ - KindDetails: typing.Struct, - }, - expectedResult: `{"foo": "bar"}`, - }, + { + // Masked, reached the string precision limit + stringKd := typing.KindDetails{ + Kind: typing.String.Kind, + OptionalStringPrecision: ptr.ToInt32(3), + } + + assert.Equal(r.T(), constants.ExceededValueMarker, replaceExceededValues("hello", stringKd)) } - - for _, tc := range tcs { - assert.Equal(r.T(), tc.expectedResult, replaceExceededValues(tc.colVal, tc.colKind), tc.name) + { + // Struct and masked + assert.Equal(r.T(), fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker), replaceExceededValues(fmt.Sprintf(`{"foo": "%s"}`, stringutil.Random(int(maxRedshiftLength)+1)), typing.Struct)) } -} - -type _testCase struct { - name string - colVal any - colKind columns.Column - - expectedString string - errorMessage string -} - -func evaluateTestCase(t *testing.T, store *Store, testCase _testCase) { - actualString, actualErr := store.CastColValStaging(testCase.colVal, testCase.colKind, nil) - if len(testCase.errorMessage) > 0 { - assert.ErrorContains(t, actualErr, testCase.errorMessage, testCase.name) - } else { - assert.NoError(t, actualErr, testCase.name) - assert.Equal(t, testCase.expectedString, actualString, testCase.name) + { + // Not masked + assert.Equal(r.T(), `{"foo": "bar"}`, replaceExceededValues(`{"foo": "bar"}`, typing.Struct)) + assert.Equal(r.T(), "hello world", replaceExceededValues("hello world", typing.String)) } } -func (r *RedshiftTestSuite) TestCastColValStaging_ExceededValues() { - testCases := []_testCase{ - { - name: "string", - colVal: stringutil.Random(int(maxRedshiftLength) + 1), - colKind: columns.Column{ - KindDetails: typing.String, - }, - expectedString: "__artie_exceeded_value", - }, - { - name: "string", - colVal: "thisissuperlongbutnotlongenoughtogetmasked", - colKind: columns.Column{ - KindDetails: typing.String, - }, - expectedString: "thisissuperlongbutnotlongenoughtogetmasked", - }, - { - name: "struct", - colVal: map[string]any{"foo": stringutil.Random(int(maxRedshiftLength) + 1)}, - colKind: columns.Column{ - KindDetails: typing.Struct, - }, - expectedString: `{"key":"__artie_exceeded_value"}`, - }, - { - name: "struct", - colVal: map[string]any{"foo": stringutil.Random(int(maxRedshiftLength) + 1)}, - colKind: columns.Column{ - KindDetails: typing.Struct, - }, - expectedString: `{"key":"__artie_exceeded_value"}`, - }, +func (r *RedshiftTestSuite) TestCastColValStaging() { + { + // Masked + value, err := castColValStaging(stringutil.Random(int(maxRedshiftLength)+1), typing.String, nil) + assert.NoError(r.T(), err) + assert.Equal(r.T(), constants.ExceededValueMarker, value) } - - cfg := config.Config{ - Redshift: &config.Redshift{}, + { + // Valid + value, err := castColValStaging("thisissuperlongbutnotlongenoughtogetmasked", typing.String, nil) + assert.NoError(r.T(), err) + assert.Equal(r.T(), "thisissuperlongbutnotlongenoughtogetmasked", value) } - - store := db.Store(r.fakeStore) - skipLargeRowsStore, err := LoadRedshift(cfg, &store) - assert.NoError(r.T(), err) - - for _, testCase := range testCases { - evaluateTestCase(r.T(), skipLargeRowsStore, testCase) + { + // Masked struct + value, err := castColValStaging(fmt.Sprintf(`{"foo": "%s"}`, stringutil.Random(int(maxRedshiftLength)+1)), typing.Struct, nil) + assert.NoError(r.T(), err) + assert.Equal(r.T(), fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker), value) + } + { + // Valid struct + value, err := castColValStaging(`{"foo": "bar"}`, typing.Struct, nil) + assert.NoError(r.T(), err) + assert.Equal(r.T(), `{"foo": "bar"}`, value) } } diff --git a/clients/redshift/staging.go b/clients/redshift/staging.go index e7bb19993..e8c551fc5 100644 --- a/clients/redshift/staging.go +++ b/clients/redshift/staging.go @@ -95,7 +95,7 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID for _, value := range tableData.Rows() { var row []string for _, col := range columns { - castedValue, castErr := s.CastColValStaging(value[col.Name()], col, additionalDateFmts) + castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails, additionalDateFmts) if castErr != nil { return "", castErr } diff --git a/clients/snowflake/staging.go b/clients/snowflake/staging.go index b86a1ba76..645c89ebf 100644 --- a/clients/snowflake/staging.go +++ b/clients/snowflake/staging.go @@ -14,7 +14,6 @@ import ( "github.com/artie-labs/transfer/lib/optimization" "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" - "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/values" ) @@ -41,9 +40,7 @@ func replaceExceededValues(colVal string, kindDetails typing.KindDetails) string return colVal } -// castColValStaging - takes `colVal` any and `colKind` typing.Column and converts the value into a string value -// This is necessary because CSV writers require values to in `string`. -func castColValStaging(colVal any, colKind columns.Column, additionalDateFmts []string) (string, error) { +func castColValStaging(colVal any, colKind typing.KindDetails, additionalDateFmts []string) (string, error) { if colVal == nil { // \\N needs to match NULL_IF(...) from ddl.go return `\\N`, nil @@ -54,7 +51,7 @@ func castColValStaging(colVal any, colKind columns.Column, additionalDateFmts [] return "", err } - return replaceExceededValues(value, colKind.KindDetails), nil + return replaceExceededValues(value, colKind), nil } func (s *Store) PrepareTemporaryTable(tableData *optimization.TableData, tableConfig *types.DwhTableConfig, tempTableID sql.TableIdentifier, additionalSettings types.AdditionalSettings, createTempTable bool) error { @@ -125,7 +122,7 @@ func (s *Store) writeTemporaryTableFile(tableData *optimization.TableData, newTa for _, value := range tableData.Rows() { var row []string for _, col := range columns { - castedValue, castErr := castColValStaging(value[col.Name()], col, additionalDateFmts) + castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails, additionalDateFmts) if castErr != nil { return "", castErr } diff --git a/clients/snowflake/staging_test.go b/clients/snowflake/staging_test.go index 7590aef97..2aee228fa 100644 --- a/clients/snowflake/staging_test.go +++ b/clients/snowflake/staging_test.go @@ -42,7 +42,7 @@ func (s *SnowflakeTestSuite) TestReplaceExceededValues() { func (s *SnowflakeTestSuite) TestCastColValStaging() { { // Null - value, err := castColValStaging(nil, columns.Column{KindDetails: typing.String}, nil) + value, err := castColValStaging(nil, typing.String, nil) assert.NoError(s.T(), err) assert.Equal(s.T(), `\\N`, value) } @@ -50,18 +50,18 @@ func (s *SnowflakeTestSuite) TestCastColValStaging() { // Struct field // Did not exceed lob size - value, err := castColValStaging(map[string]any{"key": "value"}, columns.Column{KindDetails: typing.Struct}, nil) + value, err := castColValStaging(map[string]any{"key": "value"}, typing.Struct, nil) assert.NoError(s.T(), err) assert.Equal(s.T(), `{"key":"value"}`, value) // Did exceed lob size - value, err = castColValStaging(map[string]any{"key": strings.Repeat("a", 16777216)}, columns.Column{KindDetails: typing.Struct}, nil) + value, err = castColValStaging(map[string]any{"key": strings.Repeat("a", 16777216)}, typing.Struct, nil) assert.NoError(s.T(), err) assert.Equal(s.T(), `{"key":"__artie_exceeded_value"}`, value) } { // String field - value, err := castColValStaging("foo", columns.Column{KindDetails: typing.String}, nil) + value, err := castColValStaging("foo", typing.String, nil) assert.NoError(s.T(), err) assert.Equal(s.T(), "foo", value) } diff --git a/lib/optimization/event_update_test.go b/lib/optimization/event_update_test.go index 4ece9e91e..f7e4322b1 100644 --- a/lib/optimization/event_update_test.go +++ b/lib/optimization/event_update_test.go @@ -24,6 +24,30 @@ func TestTableData_UpdateInMemoryColumnsFromDestination(t *testing.T) { invalidCol := columns.NewColumn("foo", typing.Invalid) assert.ErrorContains(t, tableData.MergeColumnsFromDestination(invalidCol), `column "foo" is invalid`) } + { + // If the in-memory column is a string and the destination column is Date + // We should mark the in-memory column as date and try to parse it accordingly. + tableDataCols := &columns.Columns{} + tableData := &TableData{ + inMemoryColumns: tableDataCols, + } + + tableData.AddInMemoryCol(columns.NewColumn("foo", typing.String)) + + extTime := typing.ETime + extTime.ExtendedTimeDetails = &ext.NestedKind{ + Type: ext.DateKindType, + } + + tsCol := columns.NewColumn("foo", extTime) + assert.NoError(t, tableData.MergeColumnsFromDestination(tsCol)) + + col, isOk := tableData.inMemoryColumns.GetColumn("foo") + assert.True(t, isOk) + assert.Equal(t, typing.ETime.Kind, col.KindDetails.Kind) + assert.Equal(t, ext.DateKindType, col.KindDetails.ExtendedTimeDetails.Type) + assert.Equal(t, extTime.ExtendedTimeDetails, col.KindDetails.ExtendedTimeDetails) + } { tableDataCols := &columns.Columns{} tableData := &TableData{ diff --git a/lib/typing/values/string.go b/lib/typing/values/string.go index 60502fada..6fffc1e0d 100644 --- a/lib/typing/values/string.go +++ b/lib/typing/values/string.go @@ -9,7 +9,6 @@ import ( "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/stringutil" "github.com/artie-labs/transfer/lib/typing" - "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/decimal" "github.com/artie-labs/transfer/lib/typing/ext" ) @@ -22,27 +21,27 @@ func BooleanToBit(val bool) int { } } -func ToString(colVal any, colKind columns.Column, additionalDateFmts []string) (string, error) { +func ToString(colVal any, colKind typing.KindDetails, additionalDateFmts []string) (string, error) { if colVal == nil { return "", fmt.Errorf("colVal is nil") } - switch colKind.KindDetails.Kind { + switch colKind.Kind { case typing.ETime.Kind: extTime, err := ext.ParseFromInterface(colVal, additionalDateFmts) if err != nil { return "", fmt.Errorf("failed to cast colVal as time.Time, colVal: %v, err: %w", colVal, err) } - if colKind.KindDetails.ExtendedTimeDetails == nil { + if colKind.ExtendedTimeDetails == nil { return "", fmt.Errorf("column kind details for extended time details is null") } - if colKind.KindDetails.ExtendedTimeDetails.Type == ext.TimeKindType { + if colKind.ExtendedTimeDetails.Type == ext.TimeKindType { return extTime.String(ext.PostgresTimeFormatNoTZ), nil } - return extTime.String(colKind.KindDetails.ExtendedTimeDetails.Format), nil + return extTime.String(colKind.ExtendedTimeDetails.Format), nil case typing.String.Kind: isArray := reflect.ValueOf(colVal).Kind() == reflect.Slice _, isMap := colVal.(map[string]any) @@ -58,7 +57,7 @@ func ToString(colVal any, colKind columns.Column, additionalDateFmts []string) ( return stringutil.EscapeBackslashes(fmt.Sprint(colVal)), nil case typing.Struct.Kind: - if colKind.KindDetails == typing.Struct { + if colKind == typing.Struct { if strings.Contains(fmt.Sprint(colVal), constants.ToastUnavailableValuePlaceholder) { colVal = map[string]any{ "key": constants.ToastUnavailableValuePlaceholder, diff --git a/lib/typing/values/string_test.go b/lib/typing/values/string_test.go index 85fb782d5..c08c8b678 100644 --- a/lib/typing/values/string_test.go +++ b/lib/typing/values/string_test.go @@ -22,18 +22,18 @@ func TestBooleanToBit(t *testing.T) { func TestToString(t *testing.T) { { // Nil value - _, err := ToString(nil, columns.Column{}, nil) + _, err := ToString(nil, typing.KindDetails{}, nil) assert.ErrorContains(t, err, "colVal is nil") } { // ETime - eTimeCol := columns.NewColumn("time", typing.ETime) - _, err := ToString("2021-01-01T00:00:00Z", eTimeCol, nil) + _, err := ToString("2021-01-01T00:00:00Z", typing.ETime, nil) assert.ErrorContains(t, err, "column kind details for extended time details is null") + eTimeCol := columns.NewColumn("time", typing.ETime) eTimeCol.KindDetails.ExtendedTimeDetails = &ext.NestedKind{Type: ext.TimeKindType} // Using `string` - val, err := ToString("2021-01-01T03:52:00Z", eTimeCol, nil) + val, err := ToString("2021-01-01T03:52:00Z", eTimeCol.KindDetails, nil) assert.NoError(t, err) assert.Equal(t, "03:52:00", val) @@ -43,87 +43,87 @@ func TestToString(t *testing.T) { extendedTime := ext.NewExtendedTime(dustyBirthday, ext.DateTimeKindType, originalFmt) eTimeCol.KindDetails.ExtendedTimeDetails = &ext.NestedKind{Type: ext.DateTimeKindType} - actualValue, err := ToString(extendedTime, eTimeCol, nil) + actualValue, err := ToString(extendedTime, eTimeCol.KindDetails, nil) assert.NoError(t, err) assert.Equal(t, extendedTime.String(originalFmt), actualValue) } { // String // JSON - val, err := ToString(map[string]any{"foo": "bar"}, columns.Column{KindDetails: typing.String}, nil) + val, err := ToString(map[string]any{"foo": "bar"}, typing.String, nil) assert.NoError(t, err) assert.Equal(t, `{"foo":"bar"}`, val) // Array - val, err = ToString([]string{"foo", "bar"}, columns.Column{KindDetails: typing.String}, nil) + val, err = ToString([]string{"foo", "bar"}, typing.String, nil) assert.NoError(t, err) assert.Equal(t, `["foo","bar"]`, val) // Normal strings - val, err = ToString("foo", columns.Column{KindDetails: typing.String}, nil) + val, err = ToString("foo", typing.String, nil) assert.NoError(t, err) assert.Equal(t, "foo", val) } { // Struct - val, err := ToString(map[string]any{"foo": "bar"}, columns.Column{KindDetails: typing.Struct}, nil) + val, err := ToString(map[string]any{"foo": "bar"}, typing.Struct, nil) assert.NoError(t, err) assert.Equal(t, `{"foo":"bar"}`, val) - val, err = ToString(constants.ToastUnavailableValuePlaceholder, columns.Column{KindDetails: typing.Struct}, nil) + val, err = ToString(constants.ToastUnavailableValuePlaceholder, typing.Struct, nil) assert.NoError(t, err) assert.Equal(t, `{"key":"__debezium_unavailable_value"}`, val) } { // Array - val, err := ToString([]string{"foo", "bar"}, columns.Column{KindDetails: typing.Array}, nil) + val, err := ToString([]string{"foo", "bar"}, typing.Array, nil) assert.NoError(t, err) assert.Equal(t, `["foo","bar"]`, val) } { // Integer // Floats first. - val, err := ToString(float32(45452.999991), columns.Column{KindDetails: typing.Integer}, nil) + val, err := ToString(float32(45452.999991), typing.Integer, nil) assert.NoError(t, err) assert.Equal(t, "45453", val) - val, err = ToString(45452.999991, columns.Column{KindDetails: typing.Integer}, nil) + val, err = ToString(45452.999991, typing.Integer, nil) assert.NoError(t, err) assert.Equal(t, "45453", val) // Integer - val, err = ToString(32, columns.Column{KindDetails: typing.Integer}, nil) + val, err = ToString(32, typing.Integer, nil) assert.NoError(t, err) assert.Equal(t, "32", val) // Booleans - val, err = ToString(true, columns.Column{KindDetails: typing.Integer}, nil) + val, err = ToString(true, typing.Integer, nil) assert.NoError(t, err) assert.Equal(t, "1", val) - val, err = ToString(false, columns.Column{KindDetails: typing.Integer}, nil) + val, err = ToString(false, typing.Integer, nil) assert.NoError(t, err) assert.Equal(t, "0", val) } { // Extended Decimal // Floats - val, err := ToString(float32(123.45), columns.Column{KindDetails: typing.EDecimal}, nil) + val, err := ToString(float32(123.45), typing.EDecimal, nil) assert.NoError(t, err) assert.Equal(t, "123.45", val) - val, err = ToString(123.45, columns.Column{KindDetails: typing.EDecimal}, nil) + val, err = ToString(123.45, typing.EDecimal, nil) assert.NoError(t, err) assert.Equal(t, "123.45", val) // String - val, err = ToString("123.45", columns.Column{KindDetails: typing.EDecimal}, nil) + val, err = ToString("123.45", typing.EDecimal, nil) assert.NoError(t, err) assert.Equal(t, "123.45", val) // Decimals value := decimal.NewDecimalWithPrecision(numbers.MustParseDecimal("585692791691858.25"), 38) - val, err = ToString(value, columns.Column{KindDetails: typing.EDecimal}, nil) + val, err = ToString(value, typing.EDecimal, nil) assert.NoError(t, err) assert.Equal(t, "585692791691858.25", val) }