From d12e091ae60047571a03ceb58a1e94a7e3f2aa90 Mon Sep 17 00:00:00 2001 From: Robin Tang Date: Tue, 10 Dec 2024 14:12:01 -0800 Subject: [PATCH] [MySQL] Support ENUM and SET (#584) --- lib/mysql/schema/convert.go | 95 +++++++++++++++++++++++++++++++- lib/mysql/schema/convert_test.go | 74 +++++++++++++++++++++---- lib/mysql/schema/schema.go | 25 ++++++--- lib/mysql/schema/schema_test.go | 51 +++++++++++++++++ 4 files changed, 223 insertions(+), 22 deletions(-) diff --git a/lib/mysql/schema/convert.go b/lib/mysql/schema/convert.go index 089ac19d..a7e34e2f 100644 --- a/lib/mysql/schema/convert.go +++ b/lib/mysql/schema/convert.go @@ -9,6 +9,38 @@ import ( "time" ) +// asSet will parse values that come from streaming and snapshot processes +// - Snapshot will emit the value as a string, and we'll return it directly. +// - Streaming will emit the value as an int64 (which is the bitset), and we'll convert it to a string. +func asSet(val any, opts []string) (string, error) { + if castedValue, ok := val.(int64); ok { + var out []string + for i, opt := range opts { + if castedValue&(1< 0 { + out = append(out, opt) + } + } + + return strings.Join(out, ","), nil + } + + return asString(val) +} + +// asEnum will parse values that come from streaming and snapshot processes +// - Snapshot will emit the value as a string, and we'll return it directly. +// - Streaming will emit the value as an int64 (which is the index), and we'll convert it to a string. +func asEnum(val any, opts []string) (string, error) { + if castedValue, ok := val.(int64); ok { + if int(castedValue) >= len(opts) { + return "", fmt.Errorf("enum value %d not in range [0, %d]", castedValue, len(opts)-1) + } + return opts[castedValue], nil + } + + return asString(val) +} + func asInt64(val any) (int64, error) { switch castedValue := val.(type) { case int64: @@ -133,6 +165,18 @@ func ConvertValue(value any, colType DataType, opts *Opts) (any, error) { return nil, err } return timeValue, nil + case Enum: + if opts == nil { + return nil, fmt.Errorf("enum column has no options") + } + + return asEnum(value, opts.EnumValues) + case Set: + if opts == nil { + return nil, fmt.Errorf("set column has no options") + } + + return asSet(value, opts.EnumValues) case Decimal, Time, Char, @@ -141,8 +185,6 @@ func ConvertValue(value any, colType DataType, opts *Opts) (any, error) { TinyText, MediumText, LongText, - Enum, - Set, JSON: // Types that we expect as a byte array that will be converted to strings return asString(value) @@ -240,3 +282,52 @@ func hasNonStrictModeInvalidDate(d string) bool { } return false } + +func peek(s string, position uint) (byte, bool) { + if len(s) <= int(position) { + return 0, false + } + + return s[position], true +} + +// parseEnumValues will parse the metadata string for an ENUM or SET column and return the values. +// Note: This was not implemented using Go's CSV stdlib as we cannot modify the quote char from `"` to `'`. Ref: https://github.com/golang/go/issues/8458 +func parseEnumValues(metadata string) ([]string, error) { + var quoteByte byte = '\'' + var result []string + var current strings.Builder + var inQuotes bool + + for i := 0; i < len(metadata); i++ { + char := metadata[i] + switch char { + case quoteByte: + if inQuotes { + if nextChar, ok := peek(metadata, uint(i+1)); ok && nextChar == quoteByte { + current.WriteByte(quoteByte) + i++ + } else { + inQuotes = false + } + } else { + inQuotes = true + } + case ',': + if inQuotes { + current.WriteByte(char) + } else { + result = append(result, current.String()) + current.Reset() + } + default: + current.WriteByte(char) + } + } + + if current.Len() > 0 { + result = append(result, current.String()) + } + + return result, nil +} diff --git a/lib/mysql/schema/convert_test.go b/lib/mysql/schema/convert_test.go index 0e84c79e..e5e64c57 100644 --- a/lib/mysql/schema/convert_test.go +++ b/lib/mysql/schema/convert_test.go @@ -58,6 +58,68 @@ func TestConvertValue(t *testing.T) { assert.ErrorContains(t, err, "value overflows float32") } } + { + // Set + { + // Passed in as a string + value, err := ConvertValue("dogs,cats,mouse", Set, &Opts{EnumValues: []string{"dogs", "cats", "mouse"}}) + assert.NoError(t, err) + assert.Equal(t, "dogs,cats,mouse", value) + } + { + // Passed in as int64 + opts := &Opts{EnumValues: []string{"dogs", "cats", "mouse"}} + { + value, err := ConvertValue(int64(0), Set, opts) + assert.NoError(t, err) + assert.Equal(t, "", value) + } + { + value, err := ConvertValue(int64(1), Set, opts) + assert.NoError(t, err) + assert.Equal(t, "dogs", value) + } + { + value, err := ConvertValue(int64(2), Set, opts) + assert.NoError(t, err) + assert.Equal(t, "cats", value) + } + { + value, err := ConvertValue(int64(5), Set, opts) + assert.NoError(t, err) + assert.Equal(t, "dogs,mouse", value) + } + { + value, err := ConvertValue(int64(7), Set, opts) + assert.NoError(t, err) + assert.Equal(t, "dogs,cats,mouse", value) + } + } + } + { + // Enum + { + // Passed in as a string + value, err := ConvertValue("dogs", Enum, &Opts{EnumValues: []string{"dogs", "cats", "mouse"}}) + assert.NoError(t, err) + assert.Equal(t, "dogs", value) + } + { + // Passed in as int64 + opts := &Opts{EnumValues: []string{"dogs", "cats", "mouse"}} + { + // Valid + value, err := ConvertValue(int64(0), Enum, opts) + assert.NoError(t, err) + assert.Equal(t, "dogs", value) + } + { + // Invalid + _, err := ConvertValue(int64(3), Enum, opts) + assert.ErrorContains(t, err, "enum value 3 not in range [0, 2]") + } + } + } tests := []struct { name string @@ -263,18 +325,6 @@ func TestConvertValue(t *testing.T) { value: []byte("hello world"), expected: "hello world", }, - { - name: "enum", - dataType: Enum, - value: []byte("orange"), - expected: "orange", - }, - { - name: "set", - dataType: Set, - value: []byte("orange"), - expected: "orange", - }, { name: "json", dataType: JSON, diff --git a/lib/mysql/schema/schema.go b/lib/mysql/schema/schema.go index 8a634633..e7cd644e 100644 --- a/lib/mysql/schema/schema.go +++ b/lib/mysql/schema/schema.go @@ -57,9 +57,10 @@ const ( ) type Opts struct { - Scale *uint16 - Precision *int - Size *int + Scale *uint16 + Precision *int + Size *int + EnumValues []string } type Column = column.Column[DataType, Opts] @@ -130,7 +131,7 @@ func ParseColumnDataType(originalS string) (DataType, *Opts, error) { // Make sure the format looks like int (n) unsigned return -1, nil, fmt.Errorf("malformed data type: %q", originalS) } - metadata = s[parenIndex+1 : len(s)-1] + metadata = originalS[parenIndex+1 : len(s)-1] s = s[:parenIndex] } @@ -223,10 +224,18 @@ func ParseColumnDataType(originalS string) (DataType, *Opts, error) { return MediumText, nil, nil case "longtext": return LongText, nil, nil - case "enum": - return Enum, nil, nil - case "set": - return Set, nil, nil + case "enum", "set": + dataType := Enum + if s == "set" { + dataType = Set + } + + values, err := parseEnumValues(metadata) + if err != nil { + return -1, nil, fmt.Errorf("failed to parse enum values: %w", err) + } + + return dataType, &Opts{EnumValues: values}, nil case "json": return JSON, nil, nil case "point": diff --git a/lib/mysql/schema/schema_test.go b/lib/mysql/schema/schema_test.go index 93350ac4..7e6828d3 100644 --- a/lib/mysql/schema/schema_test.go +++ b/lib/mysql/schema/schema_test.go @@ -92,6 +92,57 @@ func TestParseColumnDataType(t *testing.T) { assert.Equal(t, Decimal, dataType) assert.Equal(t, &Opts{Precision: typing.ToPtr(5), Scale: typing.ToPtr(uint16(2))}, opts) } + { + // Enum + { + // No need to escape + dataType, opts, err := ParseColumnDataType("enum('a','b','c')") + assert.NoError(t, err) + assert.Equal(t, Enum, dataType) + assert.Equal(t, &Opts{EnumValues: []string{"a", "b", "c"}}, opts) + } + { + // No need to escape, testing for capitalization + dataType, opts, err := ParseColumnDataType("ENUM('A','B','C')") + assert.NoError(t, err) + assert.Equal(t, Enum, dataType) + assert.Equal(t, &Opts{EnumValues: []string{"A", "B", "C"}}, opts) + } + { + // Need to escape + dataType, opts, err := ParseColumnDataType(`enum('newline\n','tab ','backslash\\','quote''s')`) + assert.NoError(t, err) + assert.Equal(t, Enum, dataType) + assert.Equal(t, &Opts{EnumValues: []string{"newline\\n", "tab\t", "backslash\\\\", "quote's"}}, opts) + assert.Equal(t, &Opts{EnumValues: []string{"newline\\n", `tab `, `backslash\\`, "quote's"}}, opts) + + } + { + // Need to escape another one + dataType, opts, err := ParseColumnDataType("ENUM('active','inactive','on hold','approved by ''manager''','needs \\\\review')") + assert.NoError(t, err) + assert.Equal(t, Enum, dataType) + assert.Equal(t, &Opts{EnumValues: []string{"active", "inactive", "on hold", "approved by 'manager'", "needs \\\\review"}}, opts) + assert.Equal(t, &Opts{EnumValues: []string{"active", "inactive", "on hold", `approved by 'manager'`, `needs \\review`}}, opts) + } + } + { + // Set + { + // No need to escape + dataType, opts, err := ParseColumnDataType("set('a','b','c')") + assert.NoError(t, err) + assert.Equal(t, Set, dataType) + assert.Equal(t, &Opts{EnumValues: []string{"a", "b", "c"}}, opts) + } + { + // No need to escape, testing for capitalization + dataType, opts, err := ParseColumnDataType("SET('A','B','C')") + assert.NoError(t, err) + assert.Equal(t, Set, dataType) + assert.Equal(t, &Opts{EnumValues: []string{"A", "B", "C"}}, opts) + } + } { // Blob for _, blob := range []string{"blob", "tinyblob", "mediumblob", "longblob"} {