diff --git a/Makefile b/Makefile index bffaf9a..b23cf47 100644 --- a/Makefile +++ b/Makefile @@ -49,8 +49,8 @@ fmt: -$(GO) fmt ./... test: gen-test generate - $(GO) test -v -race -coverprofile=coverage.out ./... - $(GO) test -v -race --tags=example ./example + $(GO) test -v -race -shuffle on -coverprofile=coverage.out ./... + $(GO) test -v -race -shuffle on --tags=example ./example cover: gen-test test $(GO) tool cover -html=coverage.out -o coverage.html diff --git a/example/diff_base.go b/example/diff_base.go new file mode 100644 index 0000000..860d4cc --- /dev/null +++ b/example/diff_base.go @@ -0,0 +1,20 @@ +package example + +//go:generate ../bin/go-enum --forcelower -b example + +/* +ENUM( + + B3 = 03 + B4 = 04 + B5 = 5 + B6 = 0b110 + B7 = 0b111 + B8 = 0x08 + B9 = 0x09 + B10 = 0x0B + B11 = 0x2B + +) +*/ +type DiffBase int diff --git a/example/diff_base_enum.go b/example/diff_base_enum.go new file mode 100644 index 0000000..050758d --- /dev/null +++ b/example/diff_base_enum.go @@ -0,0 +1,87 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: example +// Revision: example +// Build Date: example +// Built By: example + +//go:build example +// +build example + +package example + +import ( + "errors" + "fmt" +) + +const ( + // DiffBaseB3 is a DiffBase of type B3. + DiffBaseB3 DiffBase = iota + 3 + // DiffBaseB4 is a DiffBase of type B4. + DiffBaseB4 + // DiffBaseB5 is a DiffBase of type B5. + DiffBaseB5 + // DiffBaseB6 is a DiffBase of type B6. + DiffBaseB6 + // DiffBaseB7 is a DiffBase of type B7. + DiffBaseB7 + // DiffBaseB8 is a DiffBase of type B8. + DiffBaseB8 + // DiffBaseB9 is a DiffBase of type B9. + DiffBaseB9 + // DiffBaseB10 is a DiffBase of type B10. + DiffBaseB10 DiffBase = iota + 4 + // DiffBaseB11 is a DiffBase of type B11. + DiffBaseB11 DiffBase = iota + 35 +) + +var ErrInvalidDiffBase = errors.New("not a valid DiffBase") + +const _DiffBaseName = "b3b4b5b6b7b8b9b10b11" + +var _DiffBaseMap = map[DiffBase]string{ + DiffBaseB3: _DiffBaseName[0:2], + DiffBaseB4: _DiffBaseName[2:4], + DiffBaseB5: _DiffBaseName[4:6], + DiffBaseB6: _DiffBaseName[6:8], + DiffBaseB7: _DiffBaseName[8:10], + DiffBaseB8: _DiffBaseName[10:12], + DiffBaseB9: _DiffBaseName[12:14], + DiffBaseB10: _DiffBaseName[14:17], + DiffBaseB11: _DiffBaseName[17:20], +} + +// String implements the Stringer interface. +func (x DiffBase) String() string { + if str, ok := _DiffBaseMap[x]; ok { + return str + } + return fmt.Sprintf("DiffBase(%d)", x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x DiffBase) IsValid() bool { + _, ok := _DiffBaseMap[x] + return ok +} + +var _DiffBaseValue = map[string]DiffBase{ + _DiffBaseName[0:2]: DiffBaseB3, + _DiffBaseName[2:4]: DiffBaseB4, + _DiffBaseName[4:6]: DiffBaseB5, + _DiffBaseName[6:8]: DiffBaseB6, + _DiffBaseName[8:10]: DiffBaseB7, + _DiffBaseName[10:12]: DiffBaseB8, + _DiffBaseName[12:14]: DiffBaseB9, + _DiffBaseName[14:17]: DiffBaseB10, + _DiffBaseName[17:20]: DiffBaseB11, +} + +// ParseDiffBase attempts to convert a string to a DiffBase. +func ParseDiffBase(name string) (DiffBase, error) { + if x, ok := _DiffBaseValue[name]; ok { + return x, nil + } + return DiffBase(0), fmt.Errorf("%s is %w", name, ErrInvalidDiffBase) +} diff --git a/example/diff_base_test.go b/example/diff_base_test.go new file mode 100644 index 0000000..9f5625b --- /dev/null +++ b/example/diff_base_test.go @@ -0,0 +1,60 @@ +//go:build example +// +build example + +package example + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDiffBase(t *testing.T) { + tests := map[string]struct { + actual int + expected DiffBase + }{ + "DiffBaseB3": { + actual: 3, + expected: DiffBaseB3, + }, + "DiffBaseB4": { + actual: 4, + expected: DiffBaseB4, + }, + "DiffBaseB5": { + actual: 5, + expected: DiffBaseB5, + }, + "DiffBaseB6": { + actual: 6, + expected: DiffBaseB6, + }, + "DiffBaseB7": { + actual: 7, + expected: DiffBaseB7, + }, + "DiffBaseB8": { + actual: 8, + expected: DiffBaseB8, + }, + "DiffBaseB9": { + actual: 9, + expected: DiffBaseB9, + }, + "DiffBaseB10": { + actual: 11, + expected: DiffBaseB10, + }, + "DiffBaseB11": { + actual: 43, + expected: DiffBaseB11, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + assert.Equal(t, int(tc.expected), tc.actual) + }) + } +} diff --git a/generator/generator.go b/generator/generator.go index 3b9fef4..35e6094 100644 --- a/generator/generator.go +++ b/generator/generator.go @@ -25,8 +25,6 @@ const ( parseCommentPrefix = `//` ) -var replacementNames = map[string]string{} - // Generator is responsible for generating validation files for the given in a go source file. type Generator struct { Version string @@ -56,6 +54,7 @@ type Generator struct { forceUpper bool noComments bool buildTags []string + replacementNames map[string]string } // Enum holds data for a discovered enum in the parsed source @@ -90,6 +89,7 @@ func NewGenerator() *Generator { t: template.New("generator"), fileSet: token.NewFileSet(), noPrefix: false, + replacementNames: map[string]string{}, } funcs := sprig.TxtFuncMap() @@ -224,12 +224,21 @@ func (g *Generator) WithBuildTags(tags ...string) *Generator { return g } +// WithAliases will set up aliases for the generator. +func (g *Generator) WithAliases(aliases map[string]string) *Generator { + if aliases == nil { + return g + } + g.replacementNames = aliases + return g +} + func (g *Generator) anySQLEnabled() bool { return g.sql || g.sqlNullStr || g.sqlint || g.sqlNullInt } // ParseAliases is used to add aliases to replace during name sanitization. -func ParseAliases(aliases []string) error { +func ParseAliases(aliases []string) (map[string]string, error) { aliasMap := map[string]string{} for _, str := range aliases { @@ -237,17 +246,13 @@ func ParseAliases(aliases []string) error { for _, kvp := range kvps { parts := strings.Split(kvp, ":") if len(parts) != 2 { - return fmt.Errorf("invalid formatted alias entry %q, must be in the format \"key:value\"", kvp) + return nil, fmt.Errorf("invalid formatted alias entry %q, must be in the format \"key:value\"", kvp) } aliasMap[parts[0]] = parts[1] } } - for k, v := range aliasMap { - replacementNames[k] = v - } - - return nil + return aliasMap, nil } // WithTemplates is used to provide the filenames of additional templates. @@ -438,7 +443,7 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { valueStr = dataVal rawName = value[:equalIndex] if enum.Type == "string" { - if parsed, err := strconv.ParseInt(dataVal, 10, 64); err == nil { + if parsed, err := strconv.ParseInt(dataVal, 0, 64); err == nil { data = parsed valueStr = rawName } @@ -446,7 +451,7 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { valueStr = trimQuotes(dataVal) } } else if unsigned { - newData, err := strconv.ParseUint(dataVal, 10, 64) + newData, err := strconv.ParseUint(dataVal, 0, 64) if err != nil { err = fmt.Errorf("failed parsing the data part of enum value '%s': %w", value, err) fmt.Println(err) @@ -454,7 +459,7 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { } data = newData } else { - newData, err := strconv.ParseInt(dataVal, 10, 64) + newData, err := strconv.ParseInt(dataVal, 0, 64) if err != nil { err = fmt.Errorf("failed parsing the data part of enum value '%s': %w", value, err) fmt.Println(err) @@ -473,7 +478,7 @@ func (g *Generator) parseEnum(ts *ast.TypeSpec) (*Enum, error) { prefixedName := name if name != skipHolder { prefixedName = enum.Prefix + name - prefixedName = sanitizeValue(prefixedName) + prefixedName = g.sanitizeValue(prefixedName) if !g.leaveSnakeCase { prefixedName = snakeToCamelCase(prefixedName) } @@ -526,14 +531,14 @@ func unescapeComment(comment string) string { // identifier syntax as described here: https://golang.org/ref/spec#Identifiers // identifier = letter { letter | unicode_digit } // where letter can be unicode_letter or '_' -func sanitizeValue(value string) string { +func (g *Generator) sanitizeValue(value string) string { // Keep skip value holders if value == skipHolder { return skipHolder } replacedValue := value - for k, v := range replacementNames { + for k, v := range g.replacementNames { replacedValue = strings.ReplaceAll(replacedValue, k, v) } diff --git a/generator/generator_1.18_test.go b/generator/generator_1.18_test.go index fd536b3..1da1d71 100644 --- a/generator/generator_1.18_test.go +++ b/generator/generator_1.18_test.go @@ -217,10 +217,7 @@ func Test118AliasParsing(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - defer func() { - replacementNames = map[string]string{} - }() - err := ParseAliases(tc.input) + replacementNames, err := ParseAliases(tc.input) if tc.err != nil { require.Error(t, err) require.EqualError(t, err, tc.err.Error()) @@ -306,9 +303,11 @@ func Test118Aliasing(t *testing.T) { // ENUM(a,b,CDEF) with some extra text type Animal int ` + aliases, err := ParseAliases([]string{"CDEF:C"}) + require.NoError(t, err) g := NewGenerator(). - WithoutSnakeToCamel() - _ = ParseAliases([]string{"CDEF:C"}) + WithoutSnakeToCamel(). + WithAliases(aliases) f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) assert.Nil(t, err, "Error parsing no struct input") diff --git a/generator/generator_test.go b/generator/generator_test.go index 3732a0b..fcae16c 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -240,8 +240,7 @@ func TestAliasParsing(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - replacementNames = map[string]string{} - err := ParseAliases(tc.input) + replacementNames, err := ParseAliases(tc.input) if tc.err != nil { require.Error(t, err) require.EqualError(t, err, tc.err.Error()) @@ -327,9 +326,11 @@ func TestAliasing(t *testing.T) { // ENUM(a,b,CDEF) with some extra text type Animal int ` + aliases, err := ParseAliases([]string{"CDEF:C"}) + require.NoError(t, err) g := NewGenerator(). - WithoutSnakeToCamel() - _ = ParseAliases([]string{"CDEF:C"}) + WithoutSnakeToCamel(). + WithAliases(aliases) f, err := parser.ParseFile(g.fileSet, "TestRequiredErrors", input, parser.ParseComments) assert.Nil(t, err, "Error parsing no struct input") diff --git a/main.go b/main.go index d57e3b7..10c166e 100644 --- a/main.go +++ b/main.go @@ -176,7 +176,8 @@ func main() { }, }, Action: func(ctx *cli.Context) error { - if err := generator.ParseAliases(argv.Aliases.Value()); err != nil { + aliases, err := generator.ParseAliases(argv.Aliases.Value()) + if err != nil { return err } for _, fileOption := range argv.FileNames.Value() { @@ -188,6 +189,7 @@ func main() { g.BuiltBy = builtBy g.WithBuildTags(argv.BuildTags.Value()...) + g.WithAliases(aliases) if argv.NoPrefix { g.WithNoPrefix()