diff --git a/CHANGELOG.md b/CHANGELOG.md index c2b8b07b..e5967871 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - A new `river migrate-list` command is available which lists available migrations and which version a target database is migrated to. [PR #534](https://github.com/riverqueue/river/pull/534). - `river version` or `river --version` now prints River version information. [PR #537](https://github.com/riverqueue/river/pull/537). +### Changed + +⚠️ Version 0.12.0 has a small breaking change in `rivermigrate`. As before, we try never to make breaking changes, but this one was deemed worth it because it's quite small and may help avoid panics. + +- **Breaking change:** `rivermigrate.New` now returns a possible error along with a migrator. An error may be returned, for example, when a migration line is configured that doesn't exist. [PR #558](https://github.com/riverqueue/river/pull/558). + + ```go + // before + migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) + + // after + migrator, err := rivermigrate.New(riverpgxv5.New(dbPool), nil) + if err != nil { + // handle error + } + ``` + +- The migrator now produces a better error in case of a non-existent migration line, including suggestions for known migration lines that are similar in name to the invalid one. [PR #558](https://github.com/riverqueue/river/pull/558). + ## Fixed - Fixed a panic that'd occur if `StopAndCancel` was invoked before a client was started. [PR #557](https://github.com/riverqueue/river/pull/557). 
diff --git a/cmd/river/rivercli/command.go b/cmd/river/rivercli/command.go index 57e30691..d8c6a09b 100644 --- a/cmd/river/rivercli/command.go +++ b/cmd/river/rivercli/command.go @@ -56,7 +56,7 @@ type CommandBase struct { Out io.Writer GetBenchmarker func() BenchmarkerInterface - GetMigrator func(config *rivermigrate.Config) MigratorInterface + GetMigrator func(config *rivermigrate.Config) (MigratorInterface, error) } func (b *CommandBase) GetCommandBase() *CommandBase { return b } @@ -94,7 +94,7 @@ func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle // command doesn't take one. case bundle.DatabaseURL == nil: commandBase.GetBenchmarker = func() BenchmarkerInterface { panic("databaseURL was not set") } - commandBase.GetMigrator = func(config *rivermigrate.Config) MigratorInterface { panic("databaseURL was not set") } + commandBase.GetMigrator = func(config *rivermigrate.Config) (MigratorInterface, error) { panic("databaseURL was not set") } case strings.HasPrefix(*bundle.DatabaseURL, uriScheme) || strings.HasPrefix(*bundle.DatabaseURL, uriSchemeAlias): @@ -107,7 +107,7 @@ func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle driver := bundle.DriverProcurer.ProcurePgxV5(dbPool) commandBase.GetBenchmarker = func() BenchmarkerInterface { return riverbench.NewBenchmarker(driver, commandBase.Logger) } - commandBase.GetMigrator = func(config *rivermigrate.Config) MigratorInterface { return rivermigrate.New(driver, config) } + commandBase.GetMigrator = func(config *rivermigrate.Config) (MigratorInterface, error) { return rivermigrate.New(driver, config) } default: return false, fmt.Errorf( diff --git a/cmd/river/rivercli/river_cli.go b/cmd/river/rivercli/river_cli.go index 35d4a130..66728be1 100644 --- a/cmd/river/rivercli/river_cli.go +++ b/cmd/river/rivercli/river_cli.go @@ -387,7 +387,12 @@ type migrateDown struct { } func (c *migrateDown) Run(ctx context.Context, opts *migrateOpts) (bool, error) { - res, err := 
c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}).Migrate(ctx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ + migrator, err := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}) + if err != nil { + return false, err + } + + res, err := migrator.Migrate(ctx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ DryRun: opts.DryRun, MaxSteps: opts.MaxSteps, TargetVersion: opts.TargetVersion, @@ -470,7 +475,10 @@ func (c *migrateGet) Run(_ context.Context, opts *migrateGetOpts) (bool, error) // other databases is added in the future. Unlike other migrate commands, // this one doesn't take a `--database-url`, so we'd need a way of // detecting the database type. - migrator := rivermigrate.New(c.DriverProcurer.ProcurePgxV5(nil), &rivermigrate.Config{Line: opts.Line, Logger: c.Logger}) + migrator, err := rivermigrate.New(c.DriverProcurer.ProcurePgxV5(nil), &rivermigrate.Config{Line: opts.Line, Logger: c.Logger}) + if err != nil { + return false, err + } var migrations []rivermigrate.Migration if opts.All { @@ -534,7 +542,10 @@ type migrateList struct { } func (c *migrateList) Run(ctx context.Context, opts *migrateListOpts) (bool, error) { - migrator := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}) + migrator, err := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}) + if err != nil { + return false, err + } allMigrations := migrator.AllVersions() @@ -568,7 +579,12 @@ type migrateUp struct { } func (c *migrateUp) Run(ctx context.Context, opts *migrateOpts) (bool, error) { - res, err := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}).Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ + migrator, err := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}) + if err != nil { + return false, err + } + + res, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ DryRun: opts.DryRun, MaxSteps: 
opts.MaxSteps, TargetVersion: opts.TargetVersion, @@ -600,7 +616,12 @@ type validate struct { } func (c *validate) Run(ctx context.Context, opts *validateOpts) (bool, error) { - res, err := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}).Validate(ctx) + migrator, err := c.GetMigrator(&rivermigrate.Config{Line: opts.Line, Logger: c.Logger}) + if err != nil { + return false, err + } + + res, err := migrator.Validate(ctx) if err != nil { return false, err } diff --git a/cmd/river/rivercli/river_cli_test.go b/cmd/river/rivercli/river_cli_test.go index 2602ad02..7e08d210 100644 --- a/cmd/river/rivercli/river_cli_test.go +++ b/cmd/river/rivercli/river_cli_test.go @@ -177,7 +177,7 @@ func TestMigrateList(t *testing.T) { migratorStub.allVersionsStub = func() []rivermigrate.Migration { return testMigrationAll } migratorStub.existingVersionsStub = func(ctx context.Context) ([]rivermigrate.Migration, error) { return nil, nil } - cmd.GetCommandBase().GetMigrator = func(config *rivermigrate.Config) MigratorInterface { return migratorStub } + cmd.GetCommandBase().GetMigrator = func(config *rivermigrate.Config) (MigratorInterface, error) { return migratorStub, nil } return cmd, &testBundle{ out: out, @@ -274,7 +274,7 @@ func withCommandBase[TCommand Command[TOpts], TOpts CommandOpts](t *testing.T, c Logger: riversharedtest.Logger(t), Out: &out, - GetMigrator: func(config *rivermigrate.Config) MigratorInterface { return &MigratorStub{} }, + GetMigrator: func(config *rivermigrate.Config) (MigratorInterface, error) { return &MigratorStub{}, nil }, }) return cmd, &out } diff --git a/internal/cmd/testdbman/main.go b/internal/cmd/testdbman/main.go index 14996802..61660c3c 100644 --- a/internal/cmd/testdbman/main.go +++ b/internal/cmd/testdbman/main.go @@ -126,7 +126,11 @@ func createTestDatabases(ctx context.Context, out io.Writer) error { } defer dbPool.Close() - migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) + migrator, err := 
rivermigrate.New(riverpgxv5.New(dbPool), nil) + if err != nil { + return err + } + if _, err = migrator.Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{}); err != nil { return err } diff --git a/rivermigrate/example_migrate_database_sql_test.go b/rivermigrate/example_migrate_database_sql_test.go index bc91a469..9fd04cd7 100644 --- a/rivermigrate/example_migrate_database_sql_test.go +++ b/rivermigrate/example_migrate_database_sql_test.go @@ -30,7 +30,10 @@ func Example_migrateDatabaseSQL() { } defer tx.Rollback() - migrator := rivermigrate.New(riverdatabasesql.New(dbPool), nil) + migrator, err := rivermigrate.New(riverdatabasesql.New(dbPool), nil) + if err != nil { + panic(err) + } // Our test database starts with a full River schema. Drop it so that we can // demonstrate working migrations. This isn't necessary outside this test. diff --git a/rivermigrate/example_migrate_test.go b/rivermigrate/example_migrate_test.go index 1e4b2d54..47b572bb 100644 --- a/rivermigrate/example_migrate_test.go +++ b/rivermigrate/example_migrate_test.go @@ -29,7 +29,10 @@ func Example_migrate() { } defer tx.Rollback(ctx) - migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) + migrator, err := rivermigrate.New(riverpgxv5.New(dbPool), nil) + if err != nil { + panic(err) + } // Our test database starts with a full River schema. Drop it so that we can // demonstrate working migrations. This isn't necessary outside this test. 
diff --git a/rivermigrate/river_migrate.go b/rivermigrate/river_migrate.go index 204767b2..0e6702ce 100644 --- a/rivermigrate/river_migrate.go +++ b/rivermigrate/river_migrate.go @@ -19,6 +19,7 @@ import ( "github.com/riverqueue/river/internal/util/dbutil" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" + "github.com/riverqueue/river/rivershared/levenshtein" "github.com/riverqueue/river/rivershared/util/maputil" "github.com/riverqueue/river/rivershared/util/randutil" "github.com/riverqueue/river/rivershared/util/sliceutil" @@ -93,8 +94,11 @@ type Migrator[TTx any] struct { // } // defer dbPool.Close() // -// migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) -func New[TTx any](driver riverdriver.Driver[TTx], config *Config) *Migrator[TTx] { +// migrator, err := rivermigrate.New(riverpgxv5.New(dbPool), nil) +// if err != nil { +// // handle error +// } +func New[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Migrator[TTx], error) { if config == nil { config = &Config{} } @@ -115,7 +119,24 @@ func New[TTx any](driver riverdriver.Driver[TTx], config *Config) *Migrator[TTx] } if !slices.Contains(driver.GetMigrationLines(), line) { - panic("migration line does not exist: " + line) + const minLevenshteinDistance = 2 + + var suggestedLines []string + for _, existingLine := range driver.GetMigrationLines() { + if distance := levenshtein.ComputeDistance(existingLine, line); distance <= minLevenshteinDistance { + suggestedLines = append(suggestedLines, "`"+existingLine+"`") + } + } + + errorStr := "migration line does not exist: " + line + switch { + case len(suggestedLines) == 1: + errorStr += fmt.Sprintf(" (did you mean %s?)", suggestedLines[0]) + case len(suggestedLines) > 1: + errorStr += fmt.Sprintf(" (did you mean one of %v?)", strings.Join(suggestedLines, ", ")) + } + + return nil, errors.New(errorStr) } riverMigrations, err := migrationsFromFS(driver.GetMigrationFS(line), line) @@ -129,7 +150,7 @@ 
func New[TTx any](driver riverdriver.Driver[TTx], config *Config) *Migrator[TTx] driver: driver, line: line, migrations: validateAndInit(riverMigrations), - }) + }), nil } // ExistingVersions gets the existing set of versions that have been migrated in diff --git a/rivermigrate/river_migrate_test.go b/rivermigrate/river_migrate_test.go index 601115f0..b397d64f 100644 --- a/rivermigrate/river_migrate_test.go +++ b/rivermigrate/river_migrate_test.go @@ -48,12 +48,14 @@ func (d *driverWithAlternateLine) GetMigrationFS(line string) fs.FS { return d.Driver.GetMigrationFS(line) case migrationLineAlternate: return migrationFS + case migrationLineAlternate + "2": + panic(line + " is only meant for testing line suggestions") } panic("migration line does not exist: " + line) } func (d *driverWithAlternateLine) GetMigrationLines() []string { - return append(d.Driver.GetMigrationLines(), migrationLineAlternate) + return append(d.Driver.GetMigrationLines(), migrationLineAlternate, migrationLineAlternate+"2") } func TestMigrator(t *testing.T) { @@ -94,7 +96,8 @@ func TestMigrator(t *testing.T) { tx: tx, } - migrator := New(bundle.driver, &Config{Logger: bundle.logger}) + migrator, err := New(bundle.driver, &Config{Logger: bundle.logger}) + require.NoError(t, err) migrator.migrations = migrationsBundle.WithTestVersionsMap return migrator, bundle @@ -112,12 +115,41 @@ func TestMigrator(t *testing.T) { t.Cleanup(func() { require.NoError(t, tx.Rollback()) }) driver := riverdatabasesql.New(stdPool) - migrator := New(driver, &Config{Logger: bundle.logger}) + migrator, err := New(driver, &Config{Logger: bundle.logger}) + require.NoError(t, err) migrator.migrations = migrationsBundle.WithTestVersionsMap return migrator, tx } + t.Run("NewUnknownLine", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + + _, err := New(bundle.driver, &Config{Line: "unknown_line"}) + require.EqualError(t, err, "migration line does not exist: unknown_line") + + _, err = New(bundle.driver, 
&Config{Line: "mai"}) + require.EqualError(t, err, "migration line does not exist: mai (did you mean `main`?)") + + _, err = New(bundle.driver, &Config{Line: "maim"}) + require.EqualError(t, err, "migration line does not exist: maim (did you mean `main`?)") + + _, err = New(bundle.driver, &Config{Line: "maine"}) + require.EqualError(t, err, "migration line does not exist: maine (did you mean `main`?)") + + _, err = New(bundle.driver, &Config{Line: "ma"}) + require.EqualError(t, err, "migration line does not exist: ma (did you mean `main`?)") + + // Too far off. + _, err = New(bundle.driver, &Config{Line: "m"}) + require.EqualError(t, err, "migration line does not exist: m") + + _, err = New(bundle.driver, &Config{Line: "alternat"}) + require.EqualError(t, err, "migration line does not exist: alternat (did you mean one of `alternate`, `alternate2`?)") + }) + t.Run("AllVersions", func(t *testing.T) { t.Parallel() @@ -597,10 +629,11 @@ func TestMigrator(t *testing.T) { // We have to reinitialize the alternateMigrator because the migrations bundle is // set in the constructor. - alternateMigrator := New(bundle.driver, &Config{ + alternateMigrator, err := New(bundle.driver, &Config{ Line: migrationLineAlternate, Logger: bundle.logger, }) + require.NoError(t, err) res, err := alternateMigrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{}) require.NoError(t, err) @@ -633,10 +666,11 @@ func TestMigrator(t *testing.T) { _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{TargetVersion: 4}) require.NoError(t, err) - alternateMigrator := New(bundle.driver, &Config{ + alternateMigrator, err := New(bundle.driver, &Config{ Line: migrationLineAlternate, Logger: bundle.logger, }) + require.NoError(t, err) // Alternate line not allowed because `river_job.line` doesn't exist. 
_, err = alternateMigrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{}) diff --git a/rivershared/levenshtein/License.txt b/rivershared/levenshtein/License.txt new file mode 100644 index 00000000..a55defac --- /dev/null +++ b/rivershared/levenshtein/License.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2015 Agniva De Sarker + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/rivershared/levenshtein/levenshtein.go b/rivershared/levenshtein/levenshtein.go new file mode 100644 index 00000000..141d4e5a --- /dev/null +++ b/rivershared/levenshtein/levenshtein.go @@ -0,0 +1,86 @@ +// Package levenshtein is a Go implementation to calculate Levenshtein Distance. 
+// +// Vendored from this repository: +// https://github.com/agnivade/levenshtein +// +// Implementation taken from +// https://gist.github.com/andrei-m/982927#gistcomment-1931258 +package levenshtein + +import "unicode/utf8" + +// minLengthThreshold is the length of the string beyond which +// an allocation will be made. Strings smaller than this will be +// zero alloc. +const minLengthThreshold = 32 + +// ComputeDistance computes the levenshtein distance between the two +// strings passed as an argument. The return value is the levenshtein distance +// +// Works on runes (Unicode code points) but does not normalize +// the input strings. See https://blog.golang.org/normalization +// and the golang.org/x/text/unicode/norm package. +func ComputeDistance(str1, str2 string) int { + if len(str1) == 0 { + return utf8.RuneCountInString(str2) + } + + if len(str2) == 0 { + return utf8.RuneCountInString(str1) + } + + if str1 == str2 { + return 0 + } + + // We need to convert to []rune if the strings are non-ASCII. + // This could be avoided by using utf8.RuneCountInString + // and then doing some juggling with rune indices, + // but leads to far more bounds checks. It is a reasonable trade-off. + runeSlice1 := []rune(str1) + runeSlice2 := []rune(str2) + + // swap to save some memory O(min(a,b)) instead of O(a) + if len(runeSlice1) > len(runeSlice2) { + runeSlice1, runeSlice2 = runeSlice2, runeSlice1 + } + lenRuneSlice1 := len(runeSlice1) + lenRuneSlice2 := len(runeSlice2) + + // Init the row. + var distances []uint16 + if lenRuneSlice1+1 > minLengthThreshold { + distances = make([]uint16, lenRuneSlice1+1) + } else { + // We make a small optimization here for small strings. Because a slice + // of constant length is effectively an array, it does not allocate. So + // we can re-slice it to the right length as long as it is below a + // desired threshold. 
+ distances = make([]uint16, minLengthThreshold) + distances = distances[:lenRuneSlice1+1] + } + + // we start from 1 because index 0 is already 0. + for i := 1; i < len(distances); i++ { + distances[i] = uint16(i) + } + + // Make a dummy bounds check to prevent the 2 bounds check down below. The + // one inside the loop is particularly costly. + _ = distances[lenRuneSlice1] + + // fill in the rest + for i := 1; i <= lenRuneSlice2; i++ { + prev := uint16(i) + for j := 1; j <= lenRuneSlice1; j++ { + current := distances[j-1] // match + if runeSlice2[i-1] != runeSlice1[j-1] { + current = min(min(distances[j-1]+1, prev+1), distances[j]+1) + } + distances[j-1] = prev + prev = current + } + distances[lenRuneSlice1] = prev + } + return int(distances[lenRuneSlice1]) +} diff --git a/rivershared/levenshtein/levenshtein_test.go b/rivershared/levenshtein/levenshtein_test.go new file mode 100644 index 00000000..c842c640 --- /dev/null +++ b/rivershared/levenshtein/levenshtein_test.go @@ -0,0 +1,88 @@ +package levenshtein_test + +import ( + "testing" + + "github.com/riverqueue/river/rivershared/levenshtein" +) + +func TestSanity(t *testing.T) { + t.Parallel() + + tests := []struct { + str1, str2 string + want int + }{ + {"", "hello", 5}, + {"hello", "", 5}, + {"hello", "hello", 0}, + {"ab", "aa", 1}, + {"ab", "ba", 2}, + {"ab", "aaa", 2}, + {"bbb", "a", 3}, + {"kitten", "sitting", 3}, + {"distance", "difference", 5}, + {"levenshtein", "frankenstein", 6}, + {"resume and cafe", "resumes and cafes", 2}, + {"a very long string that is meant to exceed", "another very long string that is meant to exceed", 6}, + } + for i, d := range tests { + n := levenshtein.ComputeDistance(d.str1, d.str2) + if n != d.want { + t.Errorf("Test[%d]: ComputeDistance(%q,%q) returned %v, want %v", + i, d.str1, d.str2, n, d.want) + } + } +} + +func TestUnicode(t *testing.T) { + t.Parallel() + + tests := []struct { + str1, str2 string + want int + }{ + // Testing acutes and umlauts + {"resumé and café", 
"resumés and cafés", 2}, + {"resume and cafe", "resumé and café", 2}, + {"Hafþór Júlíus Björnsson", "Hafþor Julius Bjornsson", 4}, + // Only 2 characters are less in the 2nd string + {"།་གམ་འས་པ་་མ།", "།་གམའས་པ་་མ", 2}, + } + for i, d := range tests { + n := levenshtein.ComputeDistance(d.str1, d.str2) + if n != d.want { + t.Errorf("Test[%d]: ComputeDistance(%q,%q) returned %v, want %v", + i, d.str1, d.str2, n, d.want) + } + } +} + +// Benchmarks +// ---------------------------------------------- +var sink int //nolint:gochecknoglobals + +func BenchmarkSimple(b *testing.B) { + tests := []struct { + a, b string + name string + }{ + // ASCII + {"levenshtein", "frankenstein", "ASCII"}, + // Testing acutes and umlauts + {"resumé and café", "resumés and cafés", "French"}, + {"Hafþór Júlíus Björnsson", "Hafþor Julius Bjornsson", "Nordic"}, + {"a very long string that is meant to exceed", "another very long string that is meant to exceed", "long string"}, + // Only 2 characters are less in the 2nd string + {"།་གམ་འས་པ་་མ།", "།་གམའས་པ་་མ", "Tibetan"}, + } + tmp := 0 + for _, test := range tests { + b.Run(test.name, func(b *testing.B) { + for n := 0; n < b.N; n++ { + tmp = levenshtein.ComputeDistance(test.a, test.b) + } + }) + } + sink = tmp +}