Skip to content

Commit

Permalink
Refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tang8330 committed Oct 7, 2024
1 parent 4a5ff3f commit 2036650
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 60 deletions.
33 changes: 23 additions & 10 deletions clients/redshift/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,59 @@ import (
"github.com/artie-labs/transfer/lib/typing/values"
)

// Result is what the staging cast helpers in this file return for a single column value.
type Result struct {
// Value is the string form of the column value, after any replacement or truncation has been applied.
Value string
// NewLength - If the value exceeded the maximum length, this will be the new length of the value; it is zero otherwise.
// This is only applicable if [expandStringPrecision] is enabled.
NewLength int32
}

// maxRedshiftLength is the default cap that replaceExceededValues applies to len(colVal) for
// string and struct values when the column has no OptionalStringPrecision set.
// NOTE(review): presumably this mirrors Redshift's VARCHAR byte limit — confirm against the Redshift docs.
const maxRedshiftLength int32 = 65535

func replaceExceededValues(colVal string, colKind typing.KindDetails, truncateExceededValue bool, expandStringPrecision bool) string {
func replaceExceededValues(colVal string, colKind typing.KindDetails, truncateExceededValue bool, expandStringPrecision bool) Result {
if colKind.Kind == typing.Struct.Kind || colKind.Kind == typing.String.Kind {
maxLength := maxRedshiftLength
// If the customer has specified the maximum string precision, let's use that as the max length.
if colKind.OptionalStringPrecision != nil {
maxLength = *colKind.OptionalStringPrecision
}

if shouldReplace := int32(len(colVal)) > maxLength; shouldReplace {
colValLength := int32(len(colVal))
// If [expandStringPrecision] is enabled and the value is greater than the maximum length, but less than the maximum Redshift length.
if expandStringPrecision && colValLength > maxLength && colValLength < maxRedshiftLength {
return Result{Value: colVal, NewLength: colValLength}
}

if shouldReplace := colValLength > maxLength; shouldReplace {
if colKind.Kind == typing.Struct.Kind {
return fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker)
return Result{Value: fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker)}
}

if truncateExceededValue {
return colVal[:maxLength]
return Result{Value: colVal[:maxLength]}
} else {
return constants.ExceededValueMarker
return Result{Value: constants.ExceededValueMarker}
}
}
}

return colVal
return Result{Value: colVal}
}

func castColValStaging(colVal any, colKind typing.KindDetails, truncateExceededValue bool, expandStringPrecision bool) (string, error) {
func castColValStaging(colVal any, colKind typing.KindDetails, truncateExceededValue bool, expandStringPrecision bool) (Result, error) {
if colVal == nil {
if colKind == typing.Struct {
// Returning empty here because if it's a struct, it will go through JSON PARSE and JSON_PARSE("") = null
return "", nil
return Result{}, nil
}

// This matches the COPY clause for NULL terminator.
return `\N`, nil
return Result{Value: `\N`}, nil
}

colValString, err := values.ToString(colVal, colKind)
if err != nil {
return "", err
return Result{}, err
}

// Checks for DDL overflow needs to be done at the end in case there are any conversions that need to be done.
Expand Down
86 changes: 40 additions & 46 deletions clients/redshift/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ func (r *RedshiftTestSuite) TestReplaceExceededValues() {
// Irrelevant data type
{
// Integer
assert.Equal(r.T(), "123", replaceExceededValues("123", typing.Integer, false, false))
result := replaceExceededValues("123", typing.Integer, false, false)
assert.Equal(r.T(), "123", result.Value)
assert.Zero(r.T(), result.NewLength)
}
{
// Returns the full value since it's not a struct or string
// This is invalid and should not happen, but it's here to ensure we're only checking for structs and strings.
value := stringutil.Random(int(maxRedshiftLength + 1))
assert.Equal(r.T(), value, replaceExceededValues(value, typing.Integer, false, false))
result := replaceExceededValues(value, typing.Integer, false, false)
assert.Equal(r.T(), value, result.Value)
assert.Zero(r.T(), result.NewLength)
}
}
{
Expand All @@ -31,7 +35,9 @@ func (r *RedshiftTestSuite) TestReplaceExceededValues() {
// String
{
// TruncateExceededValue = false
assert.Equal(r.T(), constants.ExceededValueMarker, replaceExceededValues(stringutil.Random(int(maxRedshiftLength)+1), typing.String, false, false))
result := replaceExceededValues(stringutil.Random(int(maxRedshiftLength)+1), typing.String, false, false)
assert.Equal(r.T(), constants.ExceededValueMarker, result.Value)
assert.Zero(r.T(), result.NewLength)
}
{
// TruncateExceededValue = false, string precision specified
Expand All @@ -40,12 +46,16 @@ func (r *RedshiftTestSuite) TestReplaceExceededValues() {
OptionalStringPrecision: typing.ToPtr(int32(3)),
}

assert.Equal(r.T(), constants.ExceededValueMarker, replaceExceededValues("hello", stringKd, false, false))
result := replaceExceededValues("hello", stringKd, false, false)
assert.Equal(r.T(), constants.ExceededValueMarker, result.Value)
assert.Zero(r.T(), result.NewLength)
}
{
// TruncateExceededValue = true
superLongString := stringutil.Random(int(maxRedshiftLength) + 1)
assert.Equal(r.T(), superLongString[:maxRedshiftLength], replaceExceededValues(superLongString, typing.String, true, false))
result := replaceExceededValues(superLongString, typing.String, true, false)
assert.Equal(r.T(), superLongString[:maxRedshiftLength], result.Value)
assert.Zero(r.T(), result.NewLength)
}
{
// TruncateExceededValue = true, string precision specified
Expand All @@ -54,67 +64,51 @@ func (r *RedshiftTestSuite) TestReplaceExceededValues() {
OptionalStringPrecision: typing.ToPtr(int32(3)),
}

assert.Equal(r.T(), "hel", replaceExceededValues("hello", stringKd, true, false))
result := replaceExceededValues("hello", stringKd, true, false)
assert.Equal(r.T(), "hel", result.Value)
assert.Zero(r.T(), result.NewLength)
}
}
{
// Struct and masked
assert.Equal(r.T(), fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker), replaceExceededValues(fmt.Sprintf(`{"foo": "%s"}`, stringutil.Random(int(maxRedshiftLength)+1)), typing.Struct, false, false))
result := replaceExceededValues(fmt.Sprintf(`{"foo": "%s"}`, stringutil.Random(int(maxRedshiftLength)+1)), typing.Struct, false, false)
assert.Equal(r.T(), fmt.Sprintf(`{"key":"%s"}`, constants.ExceededValueMarker), result.Value)
assert.Zero(r.T(), result.NewLength)
}
}
{
// Valid
{
// Not masked
assert.Equal(r.T(), `{"foo": "bar"}`, replaceExceededValues(`{"foo": "bar"}`, typing.Struct, false, false))
assert.Equal(r.T(), "hello world", replaceExceededValues("hello world", typing.String, false, false))
{
result := replaceExceededValues(`{"foo": "bar"}`, typing.Struct, false, false)
assert.Equal(r.T(), `{"foo": "bar"}`, result.Value)
assert.Zero(r.T(), result.NewLength)
}
{
result := replaceExceededValues("hello world", typing.String, false, false)
assert.Equal(r.T(), "hello world", result.Value)
assert.Zero(r.T(), result.NewLength)
}
}
}
}
}

// TestCastColValStaging covers how castColValStaging handles nil column values:
// a nil struct becomes an empty string and any other nil becomes the `\N` marker.
// Neither case is expected to report a new length.
func (r *RedshiftTestSuite) TestCastColValStaging() {
	{
		// nil
		{
			// Struct
			result, err := castColValStaging(nil, typing.Struct, false, false)
			assert.NoError(r.T(), err)
			assert.Empty(r.T(), result.Value)
			assert.Zero(r.T(), result.NewLength)
		}
		{
			// Not struct
			result, err := castColValStaging(nil, typing.String, false, false)
			assert.NoError(r.T(), err)
			assert.Equal(r.T(), `\N`, result.Value)
			assert.Zero(r.T(), result.NewLength)
		}
	}
}
8 changes: 4 additions & 4 deletions clients/redshift/staging.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ func (s *Store) loadTemporaryTable(tableData *optimization.TableData, newTableID
for _, value := range tableData.Rows() {
var row []string
for _, col := range columns {
castedValue, castErr := castColValStaging(value[col.Name()], col.KindDetails, s.config.SharedDestinationSettings.TruncateExceededValues, s.config.SharedDestinationSettings.ExpandStringPrecision)
if castErr != nil {
return "", castErr
result, err := castColValStaging(value[col.Name()], col.KindDetails, s.config.SharedDestinationSettings.TruncateExceededValues, s.config.SharedDestinationSettings.ExpandStringPrecision)
if err != nil {
return "", err
}

row = append(row, castedValue)
row = append(row, result.Value)
}

if err = writer.Write(row); err != nil {
Expand Down

0 comments on commit 2036650

Please sign in to comment.