diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml new file mode 100644 index 0000000..5b3eb3b --- /dev/null +++ b/.github/workflows/go.yml @@ -0,0 +1,33 @@ +name: Go + +on: + push: + branches: [ "**" ] + pull_request: + branches: [ master ] + +jobs: + + build: + name: Build + runs-on: ubuntu-latest + + strategy: + matrix: + go: ['1.21', '1.22'] + + steps: + - name: Check out code into the Go module directory + uses: actions/checkout@v3 + + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + id: go + + - name: Vet + run: make vet + + - name: Test + run: make test diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6924d34 --- /dev/null +++ b/Makefile @@ -0,0 +1,23 @@ +PROJECT_NAME := "lu" +PKG := "github.com/luno/$(PROJECT_NAME)" +PKG_LIST := $(shell go list ${PKG}/... | grep -v /vendor/) + +.PHONY: vet fmt checkfmt test race + +vet: ## Lint the files + @go vet ${PKG_LIST} + +fmt: ## Format the files + @gofumpt -w . + +checkfmt: ## Check that files are formatted + @./checkfmt.sh + +test: ## Run unittests + @go test -short ${PKG_LIST} + +race: ## Run data race detector + @go test -race -short ${PKG_LIST} + +help: ## Display this help screen + @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/app.go b/app.go new file mode 100644 index 0000000..2422218 --- /dev/null +++ b/app.go @@ -0,0 +1,395 @@ +// Package lu is an application framework for Luno microservices +package lu + +import ( + "context" + "runtime/pprof" + "sync" + "time" + + "github.com/luno/jettison/errors" + "github.com/luno/jettison/j" + "github.com/luno/jettison/log" + "golang.org/x/sync/errgroup" + "k8s.io/utils/clock" +) + +var errProcessStillRunning = errors.New("process still running after shutdown", j.C("ERR_fa232f807b75bab6")) + +// App will manage the lifecycle of the service. Emitting events for each stage of the application. 
+type App struct { + // StartupTimeout is the deadline for running the start-up hooks and starting all the Processes + // Defaults to 15 seconds. + StartupTimeout time.Duration + + // ShutdownTimeout is the deadline for stopping all the app Processes and + // running the shutdown hooks. + // Defaults to 15 seconds. + ShutdownTimeout time.Duration + + // OnEvent will be called for every lifecycle event in the app. See EventType for details. + OnEvent OnEvent + + // UseProcessFile will write a file at /tmp/lu.pid whilst the app is still running. + // The file will be removed after a graceful shutdown. + UseProcessFile bool + + // OnShutdownErr is called after failing to shut down cleanly. + // You can use this hook to change the error or do last minute reporting. + // This hook is only called when using Run not when using Shutdown + OnShutdownErr func(ctx context.Context, err error) error + + startupHooks []hook + shutdownHooks []hook + + processes []Process + processRunning []chan struct{} + ctx context.Context + eg *errgroup.Group + cancel context.CancelFunc +} + +func (a *App) setDefaults() { + if a.StartupTimeout == 0 { + a.StartupTimeout = 15 * time.Second + } + if a.ShutdownTimeout == 0 { + a.ShutdownTimeout = 15 * time.Second + } + if a.OnEvent == nil { + a.OnEvent = func(context.Context, Event) {} + } +} + +// OnStartUp will call f before the app starts working +func (a *App) OnStartUp(f ProcessFunc, opts ...HookOption) { + h := hook{F: f, createOrder: len(a.startupHooks)} + applyHookOptions(&h, opts) + a.startupHooks = append(a.startupHooks, h) + sortHooks(a.startupHooks) +} + +// OnShutdown will call f just before the application terminates +// Use this to close database connections or release resources +func (a *App) OnShutdown(f ProcessFunc, opts ...HookOption) { + h := hook{F: f, createOrder: len(a.shutdownHooks)} + applyHookOptions(&h, opts) + a.shutdownHooks = append(a.shutdownHooks, h) + sortHooks(a.shutdownHooks) +} + +// AddProcess adds a Process 
that is started in parallel after start up. +// If any Process finish with an error, then the application will be stopped. +func (a *App) AddProcess(processes ...Process) { + a.processes = append(a.processes, processes...) +} + +// GetProcesses returns all the configured processes for the App +func (a *App) GetProcesses() []Process { + ret := make([]Process, len(a.processes)) + copy(ret, a.processes) + return ret +} + +func (a *App) startup(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, a.StartupTimeout) + defer cancel() + // Revert the labels after running all the hooks + defer pprof.SetGoroutineLabels(ctx) + + for idx, h := range a.startupHooks { + if ctx.Err() != nil { + return ctx.Err() + } + a.OnEvent(ctx, Event{Type: PreHookStart, Name: h.Name}) + hookCtx := ctx + if h.Name != "" { + hookCtx = log.ContextWith(hookCtx, j.MKV{"hook_idx": idx, "hook_name": h.Name}) + hookCtx = pprof.WithLabels(hookCtx, pprof.Labels("lu_hook", h.Name)) + pprof.SetGoroutineLabels(hookCtx) + } + + if err := h.F(hookCtx); err != nil { + return errors.Wrap(err, "start hook") + } + a.OnEvent(ctx, Event{Type: PostHookStart, Name: h.Name}) + } + return ctx.Err() +} + +func (a *App) cleanup(ctx context.Context) error { + var errs []error + for idx, h := range a.shutdownHooks { + if ctx.Err() != nil { + return ctx.Err() + } + a.OnEvent(ctx, Event{Type: PreHookStop, Name: h.Name}) + hookCtx := log.ContextWith(ctx, j.MKV{"hook_idx": idx, "hook_name": h.Name}) + err := h.F(hookCtx) + if err != nil { + // NoReturnErr: Collect errors + errs = append(errs, errors.Wrap(err, "stop hook", j.KV("hook_name", h.Name))) + } + a.OnEvent(ctx, Event{Type: PostHookStop, Name: h.Name}) + } + // TODO(adam): Return all the errors + if len(errs) > 0 { + for i := 1; i < len(errs); i++ { + log.Error(ctx, errs[i]) + } + return errs[0] + } + return nil +} + +// Run will start the App, running the startup Hooks, then the Processes. 
+// It will wait for any signals before shutting down first the Processes then the shutdown Hooks. +// This behaviour can be customised by using Launch, WaitForShutdown, and Shutdown. +func (a *App) Run() int { + ac := NewAppContext(context.Background()) + defer ac.Stop() + + ctx := ac.AppContext + + if err := a.Launch(ctx); err != nil { + // NoReturnErr: Log + log.Error(ctx, errors.Wrap(err, "app launch")) + return 1 + } + <-a.WaitForShutdown() + var exit int + err := a.Shutdown() + if err != nil { + // NoReturnErr: Log + err = handleShutdownErr(a, ac, err) + log.Error(ctx, errors.Wrap(err, "app shutdown")) + exit = 1 + } + + // TODO(adam): Move pid removal into Shutdown + + // This should be called in Shutdown so that clients which call that instead of + // Run can get the right behaviour + if a.UseProcessFile { + removePIDFile(ctx) + } + + // Wait for termination in case we've only been told to quit + <-ac.TerminationContext.Done() + + log.Info(ctx, "App terminated", j.MKV{"exit_code": exit}) + + return exit +} + +// Launch will run all the startup hooks and launch all the processes. +// If any hook returns an error, we will return early, processes will not be started. +// ctx will be used for startup and also the main application context. +// If the hooks take longer than StartupTimeout then launch will return a deadline exceeded error. 
+func (a *App) Launch(ctx context.Context) error { + a.setDefaults() + + if a.UseProcessFile { + if err := createPIDFile(); err != nil { + return err + } + } + + a.OnEvent(ctx, Event{Type: AppStartup}) + + if err := a.startup(ctx); err != nil { + return err + } + + a.OnEvent(ctx, Event{Type: AppRunning}) + + // Create the app context now + appCtx, appCancel := context.WithCancel(ctx) + eg, appCtx := errgroup.WithContext(appCtx) + + a.ctx = appCtx + a.cancel = appCancel + a.eg = eg + + a.processRunning = make([]chan struct{}, len(a.processes)) + for i := range a.processes { + p := &a.processes[i] + p.app = a + + doneCh := make(chan struct{}) + a.processRunning[i] = doneCh + if p.Run == nil { + close(doneCh) + continue + } + ctx := a.ctx + if p.Name != "" { + ctx = log.ContextWith(ctx, j.KV("process", p.Name)) + ctx = pprof.WithLabels(ctx, pprof.Labels("lu_process", p.Name)) + } + + eg.Go(func() error { + pprof.SetGoroutineLabels(ctx) + defer close(doneCh) + a.OnEvent(ctx, Event{Type: ProcessStart, Name: p.Name}) + defer a.OnEvent(ctx, Event{Type: ProcessEnd, Name: p.Name}) + // NOTE: Any error returned by any of the processes will cause the entire App to terminate + return p.Run(ctx) + }) + } + return ctx.Err() +} + +// WaitForShutdown returns a channel that waits for the application to be cancelled. +// Note the application has not finished terminating when this channel is closed. +// Shutdown should be called after waiting on the channel from this function. +func (a *App) WaitForShutdown() <-chan struct{} { + return a.ctx.Done() +} + +// Shutdown will synchronously stop all the resources running in the app. 
+func (a *App) Shutdown() error { + ctx, cancel := context.WithTimeout(context.Background(), a.ShutdownTimeout) + defer cancel() + + a.OnEvent(ctx, Event{Type: AppTerminating}) + defer a.OnEvent(ctx, Event{Type: AppTerminated}) + + defer func() { + err := a.cleanup(ctx) + if err != nil { + // NoReturnErr: Log + log.Error(ctx, errors.Wrap(err, "")) + } + }() + + shutErrs := make(chan error) + var shutCount int + // Shutdown processes which need shutting down explicitly first + for i := range a.processes { + p := &a.processes[i] + if p.Shutdown != nil { + shutCount++ + go func() { + if err := p.Shutdown(ctx); err != nil { + // NoReturnErr: Send error to collector + shutErrs <- errors.Wrap(err, "", j.KV("process", p.Name)) + } + shutErrs <- nil + }() + } + } + + var errs []error + for i := 0; i < shutCount; i++ { + shutErr, err := WaitFor(ctx, shutErrs) + if err != nil { + return err + } + if shutErr != nil { + // NoReturnErr: Collect for later + errs = append(errs, shutErr) + } + } + + // Cancel the context for all the other processes + a.cancel() + + groupErr, err := WaitFor(ctx, ErrGroupWait(a.eg)) + if err != nil { + return err + } + if groupErr != nil && !errors.Is(groupErr, context.Canceled) { + // NoReturnErr: Store them up + errs = append(errs, groupErr) + } + + if len(errs) > 0 { + for i := 1; i < len(errs); i++ { + log.Error(ctx, errs[i]) + } + return errs[0] + } + + return nil +} + +func (a *App) RunningProcesses() []string { + var ret []string + for idx, p := range a.processes { + select { + case <-a.processRunning[idx]: + default: + ret = append(ret, p.Name) + } + } + return ret +} + +// Wait is a cancellable wait, it will return either when +// d has passed or ctx is cancelled. +// It will return an error if cancelled early. 
+func Wait(ctx context.Context, cl clock.Clock, d time.Duration) error { + if d <= 0 { + return ctx.Err() + } + ti := cl.NewTimer(d) + defer ti.Stop() + _, err := WaitFor(ctx, ti.C()) + return err +} + +func WaitUntil(ctx context.Context, cl clock.Clock, t time.Time) error { + return Wait(ctx, cl, t.Sub(cl.Now())) +} + +func ErrGroupWait(eg *errgroup.Group) <-chan error { + ch := make(chan error) + go func() { + ch <- eg.Wait() + }() + return ch +} + +func WaitFor[T any](ctx context.Context, ch <-chan T) (T, error) { + select { + case v := <-ch: + return v, nil + case <-ctx.Done(): + var v T + return v, ctx.Err() + } +} + +// SyncGroupWait wait for the wait group (websocket connections) to finalize +func SyncGroupWait(wg *sync.WaitGroup) <-chan struct{} { + ch := make(chan struct{}) + go func() { + defer close(ch) + wg.Wait() + }() + return ch +} + +func handleShutdownErr(a *App, ac AppContext, err error) error { + if !errors.Is(err, context.DeadlineExceeded) { + return err + } + running := a.RunningProcesses() + if len(running) == 0 { + return err + } + errs := make([]error, 0, len(running)) + for _, p := range running { + err := errors.Wrap(errProcessStillRunning, "", j.KV("process", p)) + errs = append(errs, err) + } + err = errors.Join(errs...) 
+ if ac.TerminationContext.Err() != nil { + return err + } + if a.OnShutdownErr != nil { + return a.OnShutdownErr(ac.TerminationContext, err) + } + return err +} diff --git a/app_test.go b/app_test.go new file mode 100644 index 0000000..964a77f --- /dev/null +++ b/app_test.go @@ -0,0 +1,237 @@ +package lu_test + +import ( + "context" + "testing" + "time" + + "github.com/luno/jettison/jtest" + "github.com/luno/jettison/log" + "github.com/stretchr/testify/require" + + "github.com/luno/lu" + "github.com/luno/lu/process" + "github.com/luno/lu/test" +) + +func TestLifecycle(t *testing.T) { + ev := make(test.EventLog, 100) + a := &lu.App{OnEvent: ev.Append} + + a.OnStartUp(func(ctx context.Context) error { + log.Info(ctx, "starting up") + return nil + }, lu.WithHookName("basic start hook")) + + a.OnShutdown(func(ctx context.Context) error { + log.Info(ctx, "stopping") + return nil + }, lu.WithHookName("basic stop hook")) + + a.AddProcess( + lu.Process{ + Name: "one", + Run: func(ctx context.Context) error { + log.Info(ctx, "one") + <-ctx.Done() + return ctx.Err() + }, + }, + lu.Process{ + Name: "two", + Run: func(ctx context.Context) error { + log.Info(ctx, "two") + <-ctx.Done() + return ctx.Err() + }, + }, + lu.Process{ + Name: "three", + Run: func(ctx context.Context) error { + log.Info(ctx, "three") + <-ctx.Done() + return ctx.Err() + }, + }, + process.ContextLoop( + func(ctx context.Context) (context.Context, context.CancelFunc, error) { return ctx, func() {}, nil }, + func(ctx context.Context) error { return process.ErrBreakContextLoop }, + process.WithName("break loop"), + process.WithBreakableLoop()), + ) + + err := a.Launch(context.Background()) + jtest.AssertNil(t, err) + + time.Sleep(250 * time.Millisecond) + + err = a.Shutdown() + jtest.AssertNil(t, err) + + close(ev) + test.AssertEvents(t, ev, + test.Event{Type: lu.AppStartup}, + test.Event{Type: lu.PreHookStart, Name: "basic start hook"}, + test.Event{Type: lu.PostHookStart, Name: "basic start hook"}, + 
test.Event{Type: lu.AppRunning}, + test.AnyOrder( + test.Event{Type: lu.ProcessStart, Name: "one"}, + test.Event{Type: lu.ProcessStart, Name: "two"}, + test.Event{Type: lu.ProcessStart, Name: "three"}, + test.Event{Type: lu.ProcessStart, Name: "break loop"}, + test.Event{Type: lu.ProcessEnd, Name: "break loop"}, + ), + test.Event{Type: lu.AppTerminating}, + test.AnyOrder( + test.Event{Type: lu.ProcessEnd, Name: "one"}, + test.Event{Type: lu.ProcessEnd, Name: "two"}, + test.Event{Type: lu.ProcessEnd, Name: "three"}, + ), + test.Event{Type: lu.PreHookStop, Name: "basic stop hook"}, + test.Event{Type: lu.PostHookStop, Name: "basic stop hook"}, + test.Event{Type: lu.AppTerminated}, + ) +} + +func TestShutdownWithParentContext(t *testing.T) { + var a lu.App + a.AddProcess(lu.Process{ + Run: func(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + err := a.Launch(ctx) + jtest.AssertNil(t, err) + + require.Eventually(t, func() bool { + select { + case <-a.WaitForShutdown(): + return true + default: + return false + } + }, 2*time.Second, 100*time.Millisecond) + + err = a.Shutdown() + jtest.Assert(t, context.DeadlineExceeded, err) +} + +func TestProcessShutdown(t *testing.T) { + testCases := []struct { + name string + setupApp func(a *lu.App) + + expErr error + }{ + {name: "empty"}, + { + name: "cancellable", + setupApp: func(a *lu.App) { + a.ShutdownTimeout = 100 * time.Millisecond + a.AddProcess(lu.Process{Shutdown: func(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() + }}) + }, + expErr: context.DeadlineExceeded, + }, + { + name: "dependents", + setupApp: func(a *lu.App) { + ch := make(chan struct{}) + p1 := lu.Process{Shutdown: func(ctx context.Context) error { <-ch; return nil }} + p2 := lu.Process{Shutdown: func(ctx context.Context) error { close(ch); return nil }} + p3 := lu.Process{Shutdown: func(ctx context.Context) error { 
<-ch; return nil }} + a.AddProcess(p1, p2, p3) + }, + }, + { + name: "blocked", + setupApp: func(a *lu.App) { + a.ShutdownTimeout = 100 * time.Millisecond + ch := make(chan struct{}) + a.AddProcess(lu.Process{Shutdown: func(ctx context.Context) error { <-ch; return nil }}) + }, + expErr: context.DeadlineExceeded, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var a lu.App + if tc.setupApp != nil { + tc.setupApp(&a) + } + + err := a.Launch(context.Background()) + jtest.RequireNil(t, err) + + err = a.Shutdown() + jtest.Assert(t, tc.expErr, err) + }) + } +} + +func TestRunningProcesses(t *testing.T) { + testCases := []struct { + name string + processes []lu.Process + expShutdownError error + expRunning []string + }{ + {name: "nil"}, + { + name: "blocker", + processes: []lu.Process{ + {Name: "blocker", Run: func(ctx context.Context) error { + var c chan struct{} + <-c + return nil + }}, + }, + expShutdownError: context.DeadlineExceeded, + expRunning: []string{"blocker"}, + }, + { + name: "non-blocker", + processes: []lu.Process{ + {Name: "gogo", Run: func(ctx context.Context) error { + <-ctx.Done() + return nil + }}, + }, + }, + { + name: "one blocker among others", + processes: []lu.Process{ + {Name: "gogo", Run: func(ctx context.Context) error { + <-ctx.Done() + return nil + }}, + {Name: "blocker", Run: func(ctx context.Context) error { + var c chan struct{} + <-c + return nil + }}, + }, + expShutdownError: context.DeadlineExceeded, + expRunning: []string{"blocker"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + a := lu.App{ShutdownTimeout: 100 * time.Millisecond} + a.AddProcess(tc.processes...) 
+ + jtest.RequireNil(t, a.Launch(context.Background())) + jtest.Assert(t, tc.expShutdownError, a.Shutdown()) + require.Equal(t, tc.expRunning, a.RunningProcesses()) + }) + } +} diff --git a/checkfmt.sh b/checkfmt.sh new file mode 100755 index 0000000..f6b2bd5 --- /dev/null +++ b/checkfmt.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +test -z "$(gofumpt -d -e . | tee /dev/stderr)" diff --git a/event.go b/event.go new file mode 100644 index 0000000..353935e --- /dev/null +++ b/event.go @@ -0,0 +1,28 @@ +package lu + +import "context" + +//go:generate stringer -type=EventType + +type OnEvent func(context.Context, Event) + +type EventType int + +const ( + Unknown EventType = iota + AppStartup // First event, emitted right at the start + PreHookStart // Emitted just before running each Hook.Start + PostHookStart // Emitted just after completing a Hook.Start + AppRunning // Emitted after running every startup Hook + ProcessStart // Emitted before starting to run a Process + ProcessEnd // Emitted when a Process terminates + AppTerminating // Emitted when the application starts termination + PreHookStop // Emitted before running each Hook.Stop + PostHookStop // Emitted after running each Hook.Stop + AppTerminated // Emitted before calling os.Exit +) + +type Event struct { + Type EventType + Name string +} diff --git a/eventtype_string.go b/eventtype_string.go new file mode 100644 index 0000000..e14bcf6 --- /dev/null +++ b/eventtype_string.go @@ -0,0 +1,33 @@ +// Code generated by "stringer -type=EventType"; DO NOT EDIT. + +package lu + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. 
+ var x [1]struct{} + _ = x[Unknown-0] + _ = x[AppStartup-1] + _ = x[PreHookStart-2] + _ = x[PostHookStart-3] + _ = x[AppRunning-4] + _ = x[ProcessStart-5] + _ = x[ProcessEnd-6] + _ = x[AppTerminating-7] + _ = x[PreHookStop-8] + _ = x[PostHookStop-9] + _ = x[AppTerminated-10] +} + +const _EventType_name = "UnknownAppStartupPreHookStartPostHookStartAppRunningProcessStartProcessEndAppTerminatingPreHookStopPostHookStopAppTerminated" + +var _EventType_index = [...]uint8{0, 7, 17, 29, 42, 52, 64, 74, 88, 99, 111, 124} + +func (i EventType) String() string { + if i < 0 || i >= EventType(len(_EventType_index)-1) { + return "EventType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _EventType_name[_EventType_index[i]:_EventType_index[i+1]] +} diff --git a/file.go b/file.go new file mode 100644 index 0000000..6163950 --- /dev/null +++ b/file.go @@ -0,0 +1,42 @@ +package lu + +import ( + "context" + "os" + "strconv" + + "github.com/luno/jettison/errors" + "github.com/luno/jettison/j" + "github.com/luno/jettison/log" +) + +const fileName = "/tmp/lu.pid" + +func createPIDFile() error { + f, err := os.OpenFile(fileName, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o666) + if errors.Is(err, os.ErrExist) { + kv := j.MKV{"my_pid": os.Getpid(), "file": fileName, "open_err": err.Error()} + contents, readErr := os.ReadFile(fileName) + if readErr != nil { + // NoReturnErr: Something up with the file, add the error to the original one + kv["read_err"] = readErr.Error() + } else { + kv["existing_pid"] = string(contents) + } + return errors.New("process already running", kv) + } + defer f.Close() + _, err = f.WriteString(strconv.Itoa(os.Getpid())) + if err != nil { + return errors.Wrap(err, "creating pid file", j.KV("file", fileName)) + } + return nil +} + +func removePIDFile(ctx context.Context) { + err := os.Remove(fileName) + if err != nil { + // NoReturnErr: We'll terminate after this so just log + log.Error(ctx, errors.Wrap(err, "remove pid file", j.KV("file", fileName))) + } +} diff 
--git a/file_test.go b/file_test.go new file mode 100644 index 0000000..31e223d --- /dev/null +++ b/file_test.go @@ -0,0 +1,24 @@ +package lu + +import ( + "context" + "os" + "testing" + + "github.com/luno/jettison/jtest" + "github.com/stretchr/testify/assert" +) + +func TestPidFile(t *testing.T) { + err := createPIDFile() + jtest.RequireNil(t, err) + + contents, err := os.ReadFile(fileName) + jtest.RequireNil(t, err) + assert.NotEmpty(t, string(contents)) + + removePIDFile(context.Background()) + + _, err = os.ReadFile(fileName) + jtest.Assert(t, os.ErrNotExist, err) +} diff --git a/go.mod b/go.mod index fbb3e1d..0e44d57 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,36 @@ module github.com/luno/lu go 1.22.3 + +require ( + github.com/go-stack/stack v1.8.1 + github.com/luno/jettison v0.0.0-20240625085333-8727b580c646 + github.com/luno/reflex v0.0.0-20240628090425-1f5700278387 + github.com/prometheus/client_golang v1.19.1 + github.com/robfig/cron/v3 v3.0.1 + github.com/stretchr/testify v1.9.0 + golang.org/x/sync v0.7.0 + k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0 +) + +require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.5.0 // indirect + github.com/prometheus/common v0.48.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect + go.opentelemetry.io/otel v1.14.0 // indirect + go.opentelemetry.io/otel/trace v1.14.0 // indirect + golang.org/x/net v0.24.0 // indirect + golang.org/x/sys v0.19.0 // indirect + golang.org/x/text v0.14.0 // indirect + golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect + google.golang.org/grpc v1.63.2 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect 
+ gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ecf4a67 --- /dev/null +++ b/go.sum @@ -0,0 +1,76 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-stack/stack v1.8.1 h1:ntEHSVwIt7PNXNpgPmVfMrNhLtgjlmnZha2kOpuRiDw= +github.com/go-stack/stack v1.8.1/go.mod h1:dcoOX6HbPZSZptuspn9bctJ+N/CnF5gGygcUP3XYfe4= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/luno/jettison v0.0.0-20240625085333-8727b580c646 h1:9U2KNY3ZeE5uC6d8Nc/Sn9ONUwLirAG1hYQg493i7kE= +github.com/luno/jettison v0.0.0-20240625085333-8727b580c646/go.mod h1:cV8KOstEDY+Su4dcN1dadoXC7xmyEqtXAw6Nywia/z8= 
+github.com/luno/reflex v0.0.0-20240628090425-1f5700278387 h1:wAl+iGthdDisVsg2URlo3BZtDrnWyO0Fv6xL/D8c1ro= +github.com/luno/reflex v0.0.0-20240628090425-1f5700278387/go.mod h1:FdFAF2wOACOnkxOb5OUx3lNxHW4hCAaulnoJzalphpY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= +github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= +github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= +github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= +github.com/sebdah/goldie/v2 v2.5.3 h1:9ES/mNN+HNUbNWpVAlrzuZ7jE+Nrczbj8uFRjM7624Y= +github.com/sebdah/goldie/v2 v2.5.3/go.mod h1:oZ9fp0+se1eapSRjfYbsV/0Hqhbuu3bJVvKI/NNtssI= +github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= +github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod 
h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/otel v1.14.0 h1:/79Huy8wbf5DnIPhemGB+zEPVwnN6fuQybr/SRXa6hM= +go.opentelemetry.io/otel v1.14.0/go.mod h1:o4buv+dJzx8rohcUeRmWUZhqupFvzWis188WlggnNeU= +go.opentelemetry.io/otel/sdk v1.14.0 h1:PDCppFRDq8A1jL9v6KMI6dYesaq+DFcDZvjsoGvxGzY= +go.opentelemetry.io/otel/sdk v1.14.0/go.mod h1:bwIC5TjrNG6QDCHNWvW4HLHtUQ4I+VQDsnjhvyZCALM= +go.opentelemetry.io/otel/trace v1.14.0 h1:wp2Mmvj41tDsyAJXiWDWpfNsOiIyd38fy85pyKcFq/M= +go.opentelemetry.io/otel/trace v1.14.0/go.mod h1:8avnQLK+CG77yNLUae4ea2JDQ6iT+gozhnZjy/rw9G8= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= +golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY= +google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= +google.golang.org/grpc v1.63.2/go.mod 
h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0 h1:jgGTlFYnhF1PM1Ax/lAlxUPE+KfCIXHaathvJg1C3ak= +k8s.io/utils v0.0.0-20240502163921-fe8a2dddb1d0/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= diff --git a/hook.go b/hook.go new file mode 100644 index 0000000..0b9386f --- /dev/null +++ b/hook.go @@ -0,0 +1,61 @@ +package lu + +import ( + "context" + "fmt" + "sort" +) + +type hook struct { + Name string + createOrder int + Priority HookPriority + // F is called either at the start or at the end of the application lifecycle + // ctx will be cancelled if the function takes too long + F func(ctx context.Context) error +} + +func sortHooks(h []hook) { + sort.Slice(h, func(i, j int) bool { + hi, hj := h[i], h[j] + if hi.Priority != hj.Priority { + return hi.Priority < hj.Priority + } + return hi.createOrder < hj.createOrder + }) +} + +type HookOption func(*hook) + +func applyHookOptions(h *hook, opts []HookOption) { + for _, o := range opts { + o(h) + } +} + +// WithHookName is used for logging so each Hook can be identified +func WithHookName(s string) HookOption { + return func(options *hook) { + options.Name = s + } +} + +type 
HookPriority int + +const ( + HookPriorityFirst HookPriority = -100 + HookPriorityDefault HookPriority = 0 + HookPriorityLast HookPriority = 100 +) + +// WithHookPriority controls the order in which hooks are run, the lower the value of p +// the earlier it will be run (compared to other hooks) +// The default priority is 0, negative priorities will be run before positive ones +func WithHookPriority(p HookPriority) HookOption { + if p < HookPriorityFirst || p > HookPriorityLast { + panic(fmt.Sprintln("invalid hook priority", p)) + } + return func(options *hook) { + options.Priority = p + } +} diff --git a/hook_test.go b/hook_test.go new file mode 100644 index 0000000..0ca00d7 --- /dev/null +++ b/hook_test.go @@ -0,0 +1,108 @@ +package lu + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSortHooks(t *testing.T) { + testCases := []struct { + name string + hooks []hook + expSorted []hook + }{ + { + name: "by order", + hooks: []hook{ + {Name: "c", createOrder: 3}, + {Name: "a", createOrder: 1}, + {Name: "b", createOrder: 2}, + }, + expSorted: []hook{ + {Name: "a", createOrder: 1}, + {Name: "b", createOrder: 2}, + {Name: "c", createOrder: 3}, + }, + }, + { + name: "by priority", + hooks: []hook{ + {Priority: 2}, + {Priority: 3}, + {Priority: 1}, + }, + expSorted: []hook{ + {Priority: 1}, + {Priority: 2}, + {Priority: 3}, + }, + }, + { + name: "by both", + hooks: []hook{ + {createOrder: 1, Priority: HookPriorityDefault}, + {createOrder: 2, Priority: HookPriorityDefault}, + {createOrder: 3, Priority: HookPriorityLast}, + {createOrder: 4, Priority: HookPriorityDefault}, + {createOrder: 5, Priority: HookPriorityFirst}, + }, + expSorted: []hook{ + {createOrder: 5, Priority: HookPriorityFirst}, + {createOrder: 1, Priority: HookPriorityDefault}, + {createOrder: 2, Priority: HookPriorityDefault}, + {createOrder: 4, Priority: HookPriorityDefault}, + {createOrder: 3, Priority: HookPriorityLast}, + }, + }, + } + + for _, tc := range testCases { + 
t.Run(tc.name, func(t *testing.T) { + sortedHooks := make([]hook, len(tc.hooks)) + copy(sortedHooks, tc.hooks) + sortHooks(sortedHooks) + assert.Equal(t, tc.expSorted, sortedHooks) + }) + } +} + +func TestOptions(t *testing.T) { + testCases := []struct { + name string + options []HookOption + expHook hook + }{ + { + name: "defaults", + expHook: hook{Priority: HookPriorityDefault}, + }, + { + name: "name", + options: []HookOption{WithHookName("test name")}, + expHook: hook{Name: "test name", Priority: HookPriorityDefault}, + }, + { + name: "priority", + options: []HookOption{WithHookPriority(23)}, + expHook: hook{Priority: 23}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var h hook + applyHookOptions(&h, tc.options) + assert.Equal(t, tc.expHook, h) + }) + } +} + +func TestPriorityPanic(t *testing.T) { + assert.Panics(t, func() { + WithHookPriority(-101) + }) + assert.Panics(t, func() { + WithHookPriority(101) + }) +} diff --git a/metrics.go b/metrics.go new file mode 100644 index 0000000..8435125 --- /dev/null +++ b/metrics.go @@ -0,0 +1,12 @@ +package lu + +import "github.com/prometheus/client_golang/prometheus" + +var luUp = prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "lu_up", + Help: "A boolean metric to signal that the application used the Lu package to start running", +}) + +func init() { + prometheus.MustRegister(luUp) +} diff --git a/process.go b/process.go new file mode 100644 index 0000000..cdaa6c8 --- /dev/null +++ b/process.go @@ -0,0 +1,27 @@ +package lu + +import ( + "context" +) + +// ProcessFunc is a core process. See Process.Run for more details +type ProcessFunc func(ctx context.Context) error + +// Process will be a long-running part of the application which, +// if/when it errors, should bring the application down with it. +// It takes a context, if that context is canceled then the Process +// should return as soon as possible. 
type Process struct {
	// app is the owning App.
	app *App // Will be set before the process is Run

	// Name is used for logging lifecycle events with the Process
	Name string
	// Run takes a context, if that context is canceled then the ProcessFunc
	// should return as soon as possible
	// If Run returns an error, the application will begin the shutdown procedure
	Run ProcessFunc
	// Shutdown will be called to terminate the Process
	// prior to cancelling the Run context.
	// This is for Processes where synchronous shutdown is necessary
	// NOTE(review): some constructors in this module leave Shutdown nil;
	// presumably the App treats nil as "nothing to do" - confirm in App.
	Shutdown func(ctx context.Context) error
}
:= server.ListenAndServe() + if errors.Is(err, http.ErrServerClosed) { + // NoReturnErr: Don't need to return this error + return nil + } + return err + }, + Shutdown: server.Shutdown, + } + return p +} + +// SecureHTTP integrates a secure http.Server as an App Process +func SecureHTTP(name string, server *http.Server, tlsCert, tlsKey string) lu.Process { + p := lu.Process{ + Name: "https " + name, + Run: func(ctx context.Context) error { + log.Info(ctx, "Listening for HTTPS requests", j.KS("address", server.Addr)) + err := server.ListenAndServeTLS(tlsCert, tlsKey) + if errors.Is(err, http.ErrServerClosed) { + // NoReturnErr: Don't need to return this error + return nil + } + return err + }, + Shutdown: server.Shutdown, + } + return p +} diff --git a/process/http_test.go b/process/http_test.go new file mode 100644 index 0000000..a592e9a --- /dev/null +++ b/process/http_test.go @@ -0,0 +1,39 @@ +package process + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/luno/jettison/jtest" + + "github.com/luno/lu" +) + +func TestProcess(t *testing.T) { + testCases := []struct { + name string + process lu.Process + }{ + { + name: "http server", + process: HTTP("test", &http.Server{Addr: "localhost:8080"}), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var a lu.App + a.AddProcess(tc.process) + + err := a.Launch(context.Background()) + jtest.AssertNil(t, err) + + time.Sleep(100 * time.Millisecond) + + err = a.Shutdown() + jtest.RequireNil(t, err) + }) + } +} diff --git a/process/loop.go b/process/loop.go new file mode 100644 index 0000000..6231f2f --- /dev/null +++ b/process/loop.go @@ -0,0 +1,158 @@ +package process + +import ( + "context" + "fmt" + "time" + + "github.com/go-stack/stack" + "github.com/luno/jettison/errors" + "github.com/luno/jettison/j" + "github.com/luno/jettison/log" + "github.com/luno/jettison/trace" + + "github.com/luno/lu" +) + +// ErrBreakContextLoop acts as a translation error between the 
// noOpContextFunc is the context provider used when no special context is
// needed: it hands back ctx unchanged, a cancel function that does nothing,
// and a nil error.
func noOpContextFunc(ctx context.Context) (context.Context, context.CancelFunc, error) {
	var doNothing context.CancelFunc = func() {}
	return ctx, doNothing, nil
}
// wrapContextLoop builds the Run function used by ContextLoop.
//
// The returned function keeps iterating while ctx has no error. Each
// iteration goes through runWithContext, which obtains a context from getCtx
// and calls f with it, then:
//   - if the loop is breakable and f returned ErrBreakContextLoop, the loop
//     ends and Run returns nil;
//   - on any other error that is not context.Canceled, the error is counted,
//     logged, and the iteration waits for opts.errorSleep; once maxErrors
//     (when > 0) consecutive errors are reached the error is returned;
//   - otherwise the error count is reset and the iteration waits for
//     opts.sleep.
//
// Run returns ctx.Err() once ctx is done.
func wrapContextLoop(getCtx ContextFunc, f lu.ProcessFunc, opts options) lu.ProcessFunc {
	return func(ctx context.Context) error {
		// errCount is the number of consecutive failing iterations; it is
		// shared by all iterations of one Run call.
		var errCount uint
		for ctx.Err() == nil {
			err := runWithContext(ctx, getCtx, func(ctx context.Context) error {
				err := f(ctx)
				// The success sleep is computed on every iteration, before
				// any error handling; it is replaced below on error.
				sleep := opts.sleep()
				if opts.isBreakableLoop && errors.Is(err, ErrBreakContextLoop) {
					// Break requested: return straight away, skipping both
					// the wait and the afterLoop callback.
					return err
				}
				if err != nil && !errors.IsAny(err, context.Canceled) {
					// NoReturnErr: Log critical errors and continue loop
					errCount += 1
					sleep = opts.errorSleep(errCount, err)
					opts.errCounter.Inc()
					log.Error(ctx, err)
					if opts.maxErrors > 0 && errCount >= opts.maxErrors {
						// Too many consecutive errors: give up without waiting.
						return err
					}
				} else {
					// Success or context.Canceled: the error streak is over.
					errCount = 0
				}
				// afterLoop runs whether or not the wait was interrupted.
				// NOTE(review): lu.Wait presumably blocks for `sleep` or until
				// ctx is done - confirm against its definition.
				if err = lu.Wait(ctx, opts.clock, sleep); err != nil {
					opts.afterLoop()
					return err
				}
				opts.afterLoop()
				return nil
			})
			// This check is not conditioned on isBreakableLoop: any
			// ErrBreakContextLoop that reaches this point (from getCtx, or
			// from f via the maxErrors path) also ends the loop with nil.
			if errors.Is(err, ErrBreakContextLoop) {
				log.Info(ctx, "context loop terminated", log.WithError(err))
				return nil
			}
			if err != nil && !errors.Is(err, context.Canceled) {
				// NOTE: Any error returned at this point will cause the entire App to terminate
				return err
			}
			// A context.Canceled error falls through: the loop condition
			// decides whether the outer ctx is still alive.
		}
		return ctx.Err()
	}
}
+func ContextRetry( + getCtx ContextFunc, + f lu.ProcessFunc, + callOpts ...Option, +) lu.Process { + opts := resolveOptions(defaultLoopOptions(), callOpts) + + var p lu.Process + p.Name = opts.name + p.Run = func(ctx context.Context) error { + var errCount uint + for ctx.Err() == nil { + err := runWithContext(ctx, getCtx, func(ctx context.Context) error { + err := f(ctx) + if err == nil { + return nil + } + + errCount += 1 + // NoReturnErr: Log critical errors and continue loop + if !errors.Is(err, context.Canceled) { + opts.errCounter.Inc() + log.Error(ctx, err) + } + sleep := opts.errorSleep(errCount, err) + if wErr := lu.Wait(ctx, opts.clock, sleep); wErr != nil { + return wErr + } + + return err + }) + if err == nil { + break + } + } + return ctx.Err() + } + return p +} + +func runWithContext(ctx context.Context, getCtx ContextFunc, f lu.ProcessFunc) error { + runCtx, cancel, err := getCtx(ctx) + if err != nil { + return err + } + defer cancel() + return f(runCtx) +} diff --git a/process/loop_break_test.go b/process/loop_break_test.go new file mode 100644 index 0000000..aa9d0fe --- /dev/null +++ b/process/loop_break_test.go @@ -0,0 +1,94 @@ +package process_test + +import ( + "context" + "testing" + "time" + + "github.com/luno/jettison/errors" + "github.com/luno/jettison/jtest" + "github.com/luno/jettison/log" + + "github.com/luno/lu" + "github.com/luno/lu/process" + "github.com/luno/lu/test" +) + +func TestLifecycle(t *testing.T) { + ev := make(test.EventLog, 100) + a := &lu.App{OnEvent: ev.Append} + + a.OnStartUp(func(ctx context.Context) error { + log.Info(ctx, "starting up") + return nil + }, lu.WithHookName("basic start hook")) + + a.OnShutdown(func(ctx context.Context) error { + log.Info(ctx, "stopping") + return nil + }, lu.WithHookName("basic stop hook")) + + a.AddProcess( + process.ContextLoop(noOpContextFunc(), noOpProcessFunc(), process.WithName("noop")), + process.ContextLoop(noOpContextFunc(), errProcessFunc(), process.WithName("error")), + 
// noOpProcessFunc returns a process function that ignores its context and
// always reports success.
func noOpProcessFunc() func(context.Context) error {
	succeed := func(context.Context) error { return nil }
	return succeed
}
// ctxRetry is a context provider for the tests: it derives a cancellable
// child of ctx and returns it together with its cancel function and a nil
// error.
func ctxRetry(ctx context.Context) (runCtx context.Context, cancel context.CancelFunc, err error) {
	runCtx, cancel = context.WithCancel(ctx)
	return runCtx, cancel, nil
}
process.WithErrorSleep(0)) + assert.Nil(t, p.Run(ctx)) +} + +func TestContextRetry_success(t *testing.T) { + ctx := context.Background() + p := process.ContextRetry(ctxRetry, alwaysSucceed()) + assert.Empty(t, p.Name) + assert.Nil(t, p.Shutdown) + assert.Nil(t, p.Run(ctx)) +} + +func TestContextRetry_retries(t *testing.T) { + ctx := context.Background() + fakeClock := &testClock{ + FakeClock: *clock_testing.NewFakeClock(time.Now()), + } + + errSleepTime := time.Second + p := process.ContextRetry(ctxRetry, failTimes(3), + process.WithName("retry-test"), + process.WithErrorSleep(errSleepTime), + process.WithClock(fakeClock), + process.WithSleep(time.Hour), // doen't get used in ContextRetry + ) + assert.Equal(t, p.Name, "retry-test") + assert.Nil(t, p.Shutdown) + assert.Nil(t, p.Run(ctx)) + assert.Equal(t, []time.Duration{errSleepTime, errSleepTime, errSleepTime}, + fakeClock.newTimerCalls, + "Expecting call to call clock.NewTimer 3 times, once for each failure") +} + +func TestContextRetry_cancelRoleContext(t *testing.T) { + ch := make(chan context.CancelFunc) + + fnGetRole := func(ctx context.Context) (context.Context, context.CancelFunc, error) { + ctx, cancel := context.WithCancel(ctx) + ch <- cancel + return ctx, cancel, nil + } + + p := process.ContextRetry( + fnGetRole, + func(ctx context.Context) error { + // Drop straight into sleep + return errors.New("some error") + }, + process.WithSleepFunc(func() time.Duration { + return time.Second * 10 + }), + ) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + go func() { + _ = p.Run(ctx) + }() + + // Start one loop and send it to sleep + cancel1 := <-ch + // Cancel it, sending to the getCtx again + cancel1() + + select { + case nextCancel := <-ch: + t.Cleanup(nextCancel) + case <-time.After(time.Second): + assert.Fail(t, "timeout waiting for next getCtx") + } +} + +func TestContextRetry_cancelLuContext(t *testing.T) { + chStart := make(chan struct{}) + chDone := make(chan struct{}) 
+ + fnGetRole := func(ctx context.Context) (context.Context, context.CancelFunc, error) { + ctx, cancel := context.WithCancel(ctx) + return ctx, cancel, nil + } + + p := process.ContextRetry( + fnGetRole, + func(ctx context.Context) error { + close(chStart) + // Drop straight into sleep + return errors.New("some error") + }, + process.WithSleepFunc(func() time.Duration { + return time.Second * 10 + }), + ) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + go func() { + _ = p.Run(ctx) + close(chDone) + }() + + select { + case <-chStart: + case <-time.After(time.Second): + assert.Fail(t, "timeout wiating for process to start") + } + + time.Sleep(time.Second * 5) + + // Lu cancelles process + cancel() + + select { + case <-chStart: + case <-time.After(time.Second): + assert.Fail(t, "timeout waiting for next getCtx") + } +} + +// testClock is a clock.Clock implementation that returns a fakeTimer and keeps +// track of each call to NewTimer. +type testClock struct { + clock_testing.FakeClock + newTimerCalls []time.Duration +} + +func (f *testClock) NewTimer(d time.Duration) clock.Timer { + f.newTimerCalls = append(f.newTimerCalls, d) + return newFakeTime() +} + +// fakeTimer is a clock.Timer implementation that doesn't block and never +// reports that is has triggered. 
+type fakeTimer struct { + c <-chan time.Time +} + +func newFakeTime() *fakeTimer { + c := make(chan time.Time) + close(c) // close so it doesn't block + return &fakeTimer{c: c} +} + +func (t *fakeTimer) C() <-chan time.Time { + return t.c +} + +func (t *fakeTimer) Stop() bool { + return false +} + +func (t *fakeTimer) Reset(_ time.Duration) bool { + return false +} + +func TestContextLoopMaxError(t *testing.T) { + fail := errors.New("failure") + + testCases := []struct { + name string + maxErrors uint + returnErr error + expectedErr error + }{ + { + name: "run forever", + maxErrors: 0, + returnErr: fail, + expectedErr: context.Canceled, + }, + { + name: "fail after one", + maxErrors: 1, + returnErr: fail, + expectedErr: fail, + }, + + { + name: "fail after two", + maxErrors: 2, + returnErr: fail, + expectedErr: fail, + }, + + { + name: "fail after break", + maxErrors: 0, + returnErr: process.ErrBreakContextLoop, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var iteration int + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p := process.ContextLoop( + func(ctx context.Context) (context.Context, context.CancelFunc, error) { return ctx, func() {}, nil }, + func(ctx context.Context) error { + iteration = iteration + 1 + if iteration > 10 { + cancel() + } + + return tc.returnErr + }, + process.WithErrorSleep(0), + process.WithMaxErrors(tc.maxErrors), + process.WithBreakableLoop(), + ) + + err := p.Run(ctx) + jtest.Require(t, tc.expectedErr, err) + }) + } +} + +func TestErrorSleepConfig(t *testing.T) { + testCases := []struct { + name string + sleepFunc process.ErrorSleepFunc + expSleeps []time.Duration + }{ + { + name: "constant", + sleepFunc: process.MakeErrorSleepFunc(0, time.Second, nil), + expSleeps: []time.Duration{time.Second, time.Second, time.Second, time.Second, time.Second, time.Second}, + }, + { + name: "quick retries then falls back to sleep", + sleepFunc: process.MakeErrorSleepFunc(3, 
time.Second, nil), + expSleeps: []time.Duration{0, 0, 0, time.Second}, + }, + { + name: "default backoff", + sleepFunc: process.MakeErrorSleepFunc(0, 100*time.Millisecond, process.DefaultBackoff), + expSleeps: []time.Duration{ + 100 * time.Millisecond, + 200 * time.Millisecond, + 500 * time.Millisecond, + time.Second, + 2 * time.Second, + 5 * time.Second, + 10 * time.Second, + 10 * time.Second, + }, + }, + { + name: "backoff with retries", + sleepFunc: process.MakeErrorSleepFunc(2, time.Second, []uint{1, 2, 3}), + expSleeps: []time.Duration{ + 0, 0, + time.Second, 2 * time.Second, 3 * time.Second, + 3 * time.Second, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for idx, expDur := range tc.expSleeps { + dur := tc.sleepFunc(uint(idx+1), context.DeadlineExceeded) + assert.Equal(t, expDur, dur) + } + }) + } +} + +func TestSleepContextCancelled(t *testing.T) { + ch := make(chan context.CancelFunc) + + fnGetRole := func(ctx context.Context) (context.Context, context.CancelFunc, error) { + ctx, cancel := context.WithCancel(ctx) + ch <- cancel + return ctx, cancel, nil + } + + p := process.ContextLoop( + fnGetRole, + func(ctx context.Context) error { + // Drop straight into sleep + return nil + }, + process.WithSleepFunc(func() time.Duration { + return time.Second * 10 + }), + ) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + go func() { + _ = p.Run(ctx) + }() + + // Start one loop and send it to sleep + cancel1 := <-ch + // Cancel it, sending to the getCtx again + cancel1() + + select { + case nextCancel := <-ch: + t.Cleanup(nextCancel) + case <-time.After(time.Second): + assert.Fail(t, "timeout waiting for next getCtx") + } +} diff --git a/process/metrics.go b/process/metrics.go new file mode 100644 index 0000000..9904d2f --- /dev/null +++ b/process/metrics.go @@ -0,0 +1,28 @@ +package process + +import "github.com/prometheus/client_golang/prometheus" + +const processLabel = "process_name" + 
+// label returns the prometheus labels for the process +func label(name string) prometheus.Labels { + return prometheus.Labels{processLabel: name} +} + +// processErrors is the number of errors from processing events +var processErrors = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "lu_process_error_count", + Help: "Number of errors from running a process", +}, []string{processLabel}) + +var scheduleCursorLag = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: "lu_process_schedule_cursor_lag_seconds", + Help: "Number of seconds since the last successful run of a scheduled process when its cursor is lagging.", +}, []string{processLabel}) + +func init() { + prometheus.MustRegister( + processErrors, + scheduleCursorLag, + ) +} diff --git a/process/noop.go b/process/noop.go new file mode 100644 index 0000000..a1c4298 --- /dev/null +++ b/process/noop.go @@ -0,0 +1,18 @@ +package process + +import ( + "context" + + "github.com/luno/lu" +) + +// NoOp is a Process which doesn't do anything but runs until the app is terminated. +func NoOp() lu.Process { + return lu.Process{ + Name: "noop", + Run: func(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() + }, + } +} diff --git a/process/options.go b/process/options.go new file mode 100644 index 0000000..70d66b3 --- /dev/null +++ b/process/options.go @@ -0,0 +1,185 @@ +package process + +import ( + "time" + + "github.com/prometheus/client_golang/prometheus" + "k8s.io/utils/clock" +) + +type options struct { + name string + // Any override role to await on instead of just waiting on the name of the service + role string + // Returns the time to sleep if no error occurs. Default 0. + sleep SleepFunc + // Config for the time to sleep if an error occurs. Defaults to a constant 10s. + errorSleep ErrorSleepFunc + maxErrors uint + clock clock.Clock + // Callback function that's called after a loop iteration but before the next iteration. 
+ // It's for internal use only, and shouldn't be exposed outside this package. + // Default is a no-op. + afterLoop func() + + // Counts the errors for a specific process, the default increments the error counter metric in metrics.go with the process name as a label. + errCounter prometheus.Counter + + // EXPERIMENTAL: Added for the purposes of production testing isolated cases with the new breakable behaviour + // Flag to determine if we allow loops to break when an ErrBreakContextLoop is returned from the process function. + isBreakableLoop bool +} + +// SleepFunc returns how long to sleep between loops when there was no error. +type SleepFunc func() time.Duration + +// SleepFor returns a SleepFunc that returns a fixed sleep duration. +func SleepFor(dur time.Duration) SleepFunc { + return func() time.Duration { + return dur + } +} + +// ErrorSleepFunc returns how long to sleep when we encounter an error +// `errCount` is how many times we've had an error, always > 0 +// `err` is the latest error +// +// The function should not call time.Sleep itself, instead it +// should return the amount of time that will be used with lu.Wait +type ErrorSleepFunc func(errCount uint, err error) time.Duration + +// ErrorSleepFor will return the same amount of time for every error +func ErrorSleepFor(dur time.Duration) ErrorSleepFunc { + return func(uint, error) time.Duration { + return dur + } +} + +// MakeErrorSleepFunc specifies behaviour for how long to sleep when a function errors repeatedly. +// When error count is between 1 and r (1 <= c <= r) we will retry immediately. +// Then when error count is more than r, we will sleep for d. +// The backoff array is used as multipliers on d to determine the amount of sleep. 
+func MakeErrorSleepFunc(r uint, d time.Duration, backoff []uint) ErrorSleepFunc { + return func(errCount uint, err error) time.Duration { + if errCount <= r { + return 0 + } + if len(backoff) == 0 { + return d + } + backoffIdx := int(errCount) - 1 + if r > 0 { + backoffIdx -= int(r) + } + if backoffIdx >= len(backoff) { + backoffIdx = len(backoff) - 1 + } + return d * time.Duration(backoff[backoffIdx]) + } +} + +var DefaultBackoff = []uint{1, 2, 5, 10, 20, 50, 100} + +type Option func(*options) + +// resolveOptions applies the supplied LoopOptions to the defaults +func resolveOptions(defaults options, opts []Option) options { + res := defaults + for _, opt := range opts { + opt(&res) + } + if res.sleep == nil { + res.sleep = SleepFor(0) + } + if res.clock == nil { + res.clock = clock.RealClock{} + } + if res.errorSleep == nil { + res.errorSleep = ErrorSleepFor(10 * time.Second) + } + if res.afterLoop == nil { + res.afterLoop = func() {} + } + if res.errCounter == nil { + res.errCounter = processErrors.With(label(res.name)) + } + + return res +} + +func WithName(name string) Option { + return func(o *options) { + o.name = name + } +} + +// WithRole allows you to specify a custom role to await on when coordinating services which may be picked up by +// supporting lu Process builder like ReflexConsumer. +func WithRole(role string) Option { + return func(o *options) { + o.role = role + } +} + +// WithSleep is a shortcut for WithSleepFunc + SleepFor. +// The process will sleep for `d` on every successful loop. +func WithSleep(d time.Duration) Option { + return func(o *options) { + o.sleep = SleepFor(d) + } +} + +// WithSleepFunc sets the handler for determining how long ot sleep between loops when there was no error. +func WithSleepFunc(f SleepFunc) Option { + return func(o *options) { + o.sleep = f + } +} + +// WithErrorSleep is a shortcut for WithErrorSleepFunc + ErrorSleepFor +// The process will sleep for `d` on every error. 
+func WithErrorSleep(d time.Duration) Option { + return WithErrorSleepFunc(ErrorSleepFor(d)) +} + +// WithErrorSleepFunc sets the handler for determining how long to sleep for +// after an error. You can use ErrorSleepFor to sleep for a fixed amount of time: +// +// p := Loop(f, WithErrorSleepFunc(ErrorSleepFor(time.Minute))) +// +// or you can use MakeErrorSleepFunc to get some more complex behaviour +// +// p := Loop(f, WithErrorSleepFunc(MakeErrorSleepFunc(5, time.Minute, []uint{1,2,5,10}))) +func WithErrorSleepFunc(f ErrorSleepFunc) Option { + return func(o *options) { + o.errorSleep = f + } +} + +// WithClock overwrites the clock field with the value provided. +// Mainly used during testing. +func WithClock(clock clock.Clock) Option { + return func(o *options) { + o.clock = clock + } +} + +// WithMaxErrors sets the number errors that will cause us to give up +// on the currently running process. +// A value of 0 (the default) means we will never give up. +// A value of 1 means we give up after the first error, 2 the second and +// so on. +func WithMaxErrors(v uint) Option { + return func(o *options) { + o.maxErrors = v + } +} + +// WithBreakableLoop sets a flag that determines if when an ErrBreakContextLoop is returned +// from a process function if that context loop itself can be allowed to terminate as well. 
+// EXPERIMENTAL: Added for the purposes of production testing isolated cases with the new breakable behaviour +func WithBreakableLoop() Option { + return func(o *options) { + o.isBreakableLoop = true + } +} diff --git a/process/options_test.go b/process/options_test.go new file mode 100644 index 0000000..bd98c02 --- /dev/null +++ b/process/options_test.go @@ -0,0 +1,141 @@ +package process + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "k8s.io/utils/clock" + clocktesting "k8s.io/utils/clock/testing" +) + +func Test_ResolveOptions(t *testing.T) { + cc := []struct { + name string + defaults options + opts []Option + want options + }{ + { + name: "no options", + want: options{ + clock: clock.RealClock{}, + sleep: SleepFor(0), + errorSleep: ErrorSleepFor(10 * time.Second), + errCounter: processErrors.With(label("")), + }, + }, + { + name: "name", + opts: []Option{WithName("test-name")}, + want: options{ + name: "test-name", + clock: clock.RealClock{}, + sleep: SleepFor(0), + errorSleep: ErrorSleepFor(10 * time.Second), + errCounter: processErrors.With(label("test-name")), + }, + }, + { + name: "sleep", + opts: []Option{WithSleep(time.Hour)}, + want: options{ + clock: clock.RealClock{}, + sleep: SleepFor(time.Hour), + errorSleep: ErrorSleepFor(10 * time.Second), + errCounter: processErrors.With(label("")), + }, + }, + { + name: "error sleep", + opts: []Option{WithErrorSleepFunc(ErrorSleepFor(3 * time.Hour))}, + want: options{ + clock: clock.RealClock{}, + sleep: SleepFor(0), + errorSleep: ErrorSleepFor(3 * time.Hour), + errCounter: processErrors.With(label("")), + }, + }, + { + name: "clock is defaulted to valid value", + opts: []Option{WithClock(nil)}, + want: options{ + clock: clock.RealClock{}, + sleep: SleepFor(0), + errorSleep: ErrorSleepFor(10 * time.Second), + errCounter: processErrors.With(label("")), + }, + }, + { + name: "default is used", + defaults: options{errorSleep: ErrorSleepFor(time.Minute)}, + want: options{ + clock: 
clock.RealClock{}, + sleep: SleepFor(0), + errorSleep: ErrorSleepFor(time.Minute), + errCounter: processErrors.With(label("")), + }, + }, + { + name: "overriding default with bad value falls back to safe", + defaults: options{ + clock: &clocktesting.FakeClock{}, + sleep: SleepFor(0), + errorSleep: ErrorSleepFor(time.Minute), + errCounter: processErrors.With(label("")), + }, + opts: []Option{ + WithClock(nil), + WithErrorSleepFunc(nil), + }, + want: options{ + clock: clock.RealClock{}, + sleep: SleepFor(0), + errorSleep: ErrorSleepFor(10 * time.Second), + errCounter: processErrors.With(label("")), + }, + }, + { + name: "negative sleep values", + opts: []Option{WithSleep(-time.Nanosecond), WithErrorSleepFunc(ErrorSleepFor(-time.Nanosecond))}, + want: options{ + clock: clock.RealClock{}, + sleep: SleepFor(-time.Nanosecond), + errorSleep: ErrorSleepFor(-time.Nanosecond), + errCounter: processErrors.With(label("")), + }, + }, + { + name: "sleep func", + opts: []Option{WithSleepFunc(func() time.Duration { return time.Hour })}, + want: options{ + clock: clock.RealClock{}, + sleep: SleepFor(time.Hour), + errorSleep: ErrorSleepFor(10 * time.Second), + errCounter: processErrors.With(label("")), + }, + }, + } + + for _, c := range cc { + t.Run(c.name, func(t *testing.T) { + o := resolveOptions(c.defaults, c.opts) + if c.want.sleep == nil { + assert.Nil(t, o.sleep) + } else { + assert.Equal(t, c.want.sleep(), o.sleep()) + c.want.sleep = nil + o.sleep = nil + } + if c.want.errorSleep == nil { + assert.Nil(t, o.errorSleep) + } else { + assert.Equal(t, c.want.errorSleep(1, nil), o.errorSleep(1, nil)) + c.want.errorSleep = nil + o.errorSleep = nil + } + o.afterLoop = nil + assert.Equal(t, c.want, o) + }) + } +} diff --git a/process/reflex.go b/process/reflex.go new file mode 100644 index 0000000..117cce3 --- /dev/null +++ b/process/reflex.go @@ -0,0 +1,119 @@ +package process + +import ( + "cmp" + "context" + "time" + + "github.com/luno/jettison/errors" + "github.com/luno/reflex" 
+ "github.com/luno/reflex/rpatterns" + "k8s.io/utils/clock" + + "github.com/luno/lu" +) + +var defaultReflexOptions = options{ + sleep: SleepFor(100 * time.Millisecond), + errorSleep: ErrorSleepFor(time.Minute), + clock: clock.RealClock{}, +} + +type RunFunc func(in context.Context, s reflex.Spec) error + +// ReflexConsumer is the most standard function for generating a lu.Process that wraps a +// reflex consumer/stream loop. Unless you need the Particular temporary behaviour of a +// ReflexLiveConsumer or the multiplexing capability from a ManyReflexConsumer this should +// be your default choice to wait for a role, on a given consumer Spec, with any options +// that need to be defined. +func ReflexConsumer(awaitFunc AwaitRoleFunc, s reflex.Spec, ol ...Option) lu.Process { + return makeReflexProcess(awaitFunc, s, resolveOptions(defaultReflexOptions, ol)) +} + +// ManyReflexConsumers allows you to take a number of (probably related) specs and ensure that +// they all run on the same service instance against a given role (and all with the same set of options). +// Unlike the other ReflexConsumer generating functions it returns a slice of lu.Process with a +// cardinality directly related the size of the supplied specs parameter. +func ManyReflexConsumers(awaitFunc AwaitRoleFunc, specs []reflex.Spec, ol ...Option) []lu.Process { + opts := resolveOptions(defaultReflexOptions, ol) + ret := make([]lu.Process, 0, len(specs)) + for _, s := range specs { + ret = append(ret, makeReflexProcess(awaitFunc, s, opts)) + } + return ret +} + +// ReflexLiveConsumer will run a consumer on every instance of the service +// The stream will start from the latest event and the position is not restored on service restart. 
+func ReflexLiveConsumer(stream reflex.StreamFunc, consumer reflex.Consumer) lu.Process { + s := rpatterns.NewBootstrapSpec( + stream, + rpatterns.MemCursorStore(), + consumer, + reflex.WithStreamFromHead(), + ) + opts := resolveOptions(defaultReflexOptions, []Option{WithName(s.Name())}) + return makeContextProcess(noOpContextFunc, makeProcessFunc(s, reflex.Run), s, opts) +} + +// These two lu.Process generating functions handle the standard case with makeReflexProcess +// of generating breakable Reflex Consumer processes and makeContextProcess where we can provide +// an alternative lu ProcessFunc at the core of the process loop. In particular makeContextProcess +// allows us to handle the special case of a ReflexLiveConsumer but since its code is still wrapped +// by makeReflexProcess it is still the same core code that is run for all reflex Consumer processes. +// NOTE: This separation also exposed the internals to allow for simpler and better test coverage. + +// makeReflexProcess creates a looping lu.Process that will correctly handle breaking out of the loop +// configured with a reflex.WithStreamToHead() option which will return an error to show that the stream +// head has been reached and thus to consumer/stream can terminate. At its core it wraps the +// makeContextProcess but defines that the code can only execute if it can obtain a role and also +// ensures that the loop is always potentially breakable. 
+func makeReflexProcess(awaitFunc AwaitRoleFunc, s reflex.Spec, opts options) lu.Process { + rl := cmp.Or(opts.role, s.Name()) + return makeContextProcess(awaitFunc(rl), makeBreakableProcessFunc(s, reflex.Run), s, opts) +} + +// makeContextProcess is the core lu.Process generating function, it allows you to supply a +// ContextFunc that may or may require you to obtain a role (not in the case of a ReflexLiveConsumer) +// and an lu.ProcessFunc which allows you to supply a breakable or non-breakable instance (again +// none breakable in the case of a ReflexLiveConsumer) +func makeContextProcess(contextFunc ContextFunc, processFunc lu.ProcessFunc, s reflex.Spec, opts options) lu.Process { + opts.afterLoop = func() { _ = s.Stop() } + p := wrapContextLoop(contextFunc, processFunc, opts) + return lu.Process{Name: s.Name(), Run: p} +} + +// These two process functions handle the cases where we may wish to break out +// of a process loop (makeBreakableProcessFunc) or we can't break as for example +// we are only starting running from the cursor head. +// NOTE: This separation also exposed the internals to allow for simpler and better test coverage. + +// makeBreakableProcessFunc wraps makeProcessFunc to handle the special case of +// translating a reflex head reached error into an lu.Process ErrBreakContextLoop +// error so that for consumers configured with the reflex option reflex.WithStreamToHead() +// they can correctly terminate when the cursor head has been reached. +func makeBreakableProcessFunc(s reflex.Spec, run RunFunc) lu.ProcessFunc { + pf := makeProcessFunc(s, run) + return func(ctx context.Context) error { + err := pf(ctx) + if reflex.IsHeadReachedErr(err) { + return errors.Wrap(ErrBreakContextLoop, err.Error()) + } + return err + } +} + +// makeProcessFunc executes the given run function for the given spec and handles +// any expected reflex errors such as contexts being cancelled. 
However, it should not +// be used as the basis for process loops that may need to terminate early such as those +// configured with the reflex option reflex.WithStreamToHead() as unlike makeBreakableProcessFunc +// they will not return the correct error to let the stream/loop terminate. +func makeProcessFunc(s reflex.Spec, run RunFunc) lu.ProcessFunc { + return func(ctx context.Context) error { + err := run(ctx, s) + if reflex.IsExpected(err) { + return nil + } + return err + } +} diff --git a/process/reflex_test.go b/process/reflex_test.go new file mode 100644 index 0000000..e62b753 --- /dev/null +++ b/process/reflex_test.go @@ -0,0 +1,117 @@ +package process + +import ( + "context" + "testing" + + "github.com/luno/jettison/errors" + "github.com/luno/jettison/jtest" + "github.com/luno/reflex" + "github.com/luno/reflex/rpatterns" +) + +type stream struct{} + +func (s *stream) Recv() (*reflex.Event, error) { + return &reflex.Event{}, nil +} + +type headStream struct{} + +func (s *headStream) Recv() (*reflex.Event, error) { + return nil, reflex.ErrHeadReached +} + +type consumer struct { + cancel context.CancelFunc +} + +func (c *consumer) Name() string { return "test" } + +func (c *consumer) Consume(ctx context.Context, event *reflex.Event) error { + return errors.New("foo") +} + +func (c *consumer) Stop() error { + c.cancel() + return nil +} + +// Test_ReflexConsumer_afterLoop tests that afterLoop is called at the end of a +// context loop iteration. 
+func Test_ReflexConsumer_afterLoop(t *testing.T) { + awaitFunc := func(role string) func(ctx context.Context) (context.Context, context.CancelFunc, error) { + return func(ctx context.Context) (context.Context, context.CancelFunc, error) { + return ctx, func() {}, ctx.Err() + } + } + makeStream := func(ctx context.Context, after string, opts ...reflex.StreamOption) (reflex.StreamClient, error) { + return new(stream), nil + } + cstore := rpatterns.MemCursorStore() + c := new(consumer) + spec := reflex.NewSpec(makeStream, cstore, c) + process := ReflexConsumer(awaitFunc, spec, WithErrorSleep(0)) + ctx, cancel := context.WithCancel(context.Background()) + c.cancel = cancel // When afterLoop is called, the context will be cancelled and the context loop will end. + err := process.Run(ctx) + jtest.Require(t, context.Canceled, err) +} + +// Test_ReflexConsumer_breakLoop tests that the process run exits with a stream returns an ErrBreakContextLoop error +// when the stream Recv method returns reflex.ErrHeadReached i.e. a stream configured with the WithStreamToHead option. +func Test_ReflexConsumer_breakLoop(t *testing.T) { + awaitFunc := func(role string) func(ctx context.Context) (context.Context, context.CancelFunc, error) { + return func(ctx context.Context) (context.Context, context.CancelFunc, error) { + return ctx, func() {}, ctx.Err() + } + } + makeStream := func(ctx context.Context, after string, opts ...reflex.StreamOption) (reflex.StreamClient, error) { + return new(headStream), nil + } + cstore := rpatterns.MemCursorStore() + c := new(consumer) + spec := reflex.NewSpec(makeStream, cstore, c) + process := ReflexConsumer(awaitFunc, spec, WithErrorSleep(0), WithBreakableLoop()) + ctx, cancel := context.WithCancel(context.Background()) + c.cancel = cancel // When afterLoop is called, the context will be cancelled and the context loop will end. 
+ err := process.Run(ctx) + jtest.RequireNil(t, err) +} + +func Test_makeBreakableProcessFunc(t *testing.T) { + ctx := context.Background() + processingErr := errors.New("Some Error") + testcases := []struct { + name string + run RunFunc + err error + }{ + { + name: "None: Nil", + run: func(_ context.Context, _ reflex.Spec) error { return nil }, + }, + { + name: "None: Expected: Stopped", + run: func(_ context.Context, _ reflex.Spec) error { return reflex.ErrStopped }, + }, + { + name: "Break Loop Error: ToHeadStream: Head Reached", + run: func(_ context.Context, _ reflex.Spec) error { return reflex.ErrHeadReached }, + err: ErrBreakContextLoop, + }, + { + name: "Error: Processing Error", + run: func(_ context.Context, _ reflex.Spec) error { return processingErr }, + err: processingErr, + }, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + var s reflex.Spec + p := makeBreakableProcessFunc(s, tc.run) + err := p(ctx) + jtest.Require(t, tc.err, err) + }) + } +} diff --git a/process/schedule.go b/process/schedule.go new file mode 100644 index 0000000..a8a5258 --- /dev/null +++ b/process/schedule.go @@ -0,0 +1,349 @@ +package process + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/luno/jettison/errors" + "github.com/luno/jettison/log" + "github.com/robfig/cron/v3" + + "github.com/luno/lu" +) + +func defaultScheduleOptions() options { + return options{ + errorSleep: ErrorSleepFor(10 * time.Minute), + } +} + +// previousAware if a Schedule object implements the previousAware, it will use this method to determine +// when the last expected run was. This can be used to determine if there were missed intervals between +// the actual last run and the expected last run. 
+type previousAware interface { + Previous(now time.Time) time.Time +} + +// Schedule must return a time in the same time.Location as given to it in Next +type Schedule cron.Schedule + +type cronWithPrevious struct { + cron.Schedule +} + +const maxLookBack = 1000 * 24 * time.Hour + +func (c cronWithPrevious) Previous(now time.Time) time.Time { + lookBack := 10 * time.Minute + next := c.Next(now) + prev := next + for prev.Equal(next) { + if lookBack > maxLookBack { + return now + } + t := next.Add(-lookBack) + lookBack = lookBack * 2 + prev = c.Next(t) + } + t := prev + for !t.Equal(next) { + prev, t = t, c.Next(prev) + } + return prev +} + +func ParseCron(cronStr string) (Schedule, error) { + s, err := cron.ParseStandard(cronStr) + if err != nil { + return nil, err + } + return cronWithPrevious{Schedule: s}, nil +} + +type waitSchedule struct { + // Wait is the (minimum) duration between successful firings of this Schedule + Wait time.Duration +} + +func (r waitSchedule) Next(t time.Time) time.Time { + return t.Add(r.Wait) +} + +// Poll returns a schedule which runs on a given minimum delay (wait) between successful runs. +func Poll(wait time.Duration) Schedule { + return waitSchedule{Wait: wait} +} + +type EveryOption func(s *intervalSchedule) + +func WithOffset(offset time.Duration) EveryOption { + return func(s *intervalSchedule) { + s.Offset = offset + } +} + +func WithDescription(desc string) EveryOption { + return func(s *intervalSchedule) { + s.Description = desc + } +} + +// Every returns a schedule which returns a time equally spaced with a period. +// e.g. if period is time.Hour and Offset is 5*time.Minute then this schedule will return +// 12:05, 13:05, 14:05, etc... +// The time is truncated to the period based on unix time (see time.Truncate for details) +func Every(period time.Duration, opts ...EveryOption) Schedule { + return newIntervalSchedule(period, opts...) 
+} + +func newIntervalSchedule(period time.Duration, opts ...EveryOption) intervalSchedule { + s := intervalSchedule{Period: period} + for _, o := range opts { + o(&s) + } + return s +} + +type intervalSchedule struct { + // Description is a meaningful explanation of the particular IntervalSchedule + Description string + // Period is the duration between firings of this Interval + Period time.Duration + // Offset is the lag within the period before the first (and subsequent) firing of the Interval + Offset time.Duration +} + +func (r intervalSchedule) Next(t time.Time) time.Time { + next := t.Truncate(r.Period).Add(r.Offset) + if !next.After(t) { + next = next.Add(r.Period) + } + return next +} + +// FixedInterval unlike Every will execute on a specific interval only...regardless if the cursor has +// fallen behind. For example, if you specify a duration of 5 min...but your process stops running for 2 hours, the +// process will only execute at the next 5-min interval once, where process.Every will execute for all the missed 5-min intervals +// during the 2-hour outage. +func FixedInterval(period time.Duration, opts ...EveryOption) Schedule { + s := fixedIntervalSchedule{ + intervalSchedule: newIntervalSchedule(period, opts...), + } + + return s +} + +type fixedIntervalSchedule struct { + intervalSchedule +} + +// Previous this method returns the expected last run time. It uses this to compare with the +// actual last run time and ensure that the process only runs once for all the intervals in between the +// last run time and "now". +func (r fixedIntervalSchedule) Previous(now time.Time) time.Time { + prev := now.Truncate(r.Period).Add(r.Offset) + if prev.After(now) { + prev = prev.Add(-1 * r.Period) + } + + return prev +} + +// TimeOfDay returns a Schedule that will trigger at the same time every day +// hour is based on the 24-hour clock. 
+func TimeOfDay(hour, minute int) Schedule { + return timeOfDaySchedule{Hour: hour, Minute: minute} +} + +type timeOfDaySchedule struct { + Hour, Minute int +} + +func (s timeOfDaySchedule) Next(t time.Time) time.Time { + ti := time.Date( + t.Year(), t.Month(), t.Day(), + s.Hour, s.Minute, 0, 0, + t.Location(), + ) + if ti.After(t) { + return ti + } + return time.Date( + t.Year(), t.Month(), t.Day()+1, + s.Hour, s.Minute, 0, 0, + t.Location(), + ) +} + +// ToTimezone can be used when a schedule is to be run in a particular timezone. +// When using this with zones that observe daylight savings, it's important to be aware of the caveats around +// the boundaries of daylight savings - unit tests demonstrate times being skipped in some cases. +func ToTimezone(s cron.Schedule, tz *time.Location) cron.Schedule { + return tzSchedule{s: s, tz: tz} +} + +type tzSchedule struct { + s Schedule + tz *time.Location +} + +func (s tzSchedule) Next(t time.Time) time.Time { + nxt := s.s.Next(t.In(s.tz)) + return nxt.In(t.Location()) +} + +type ( + ContextFunc = func(ctx context.Context) (context.Context, context.CancelFunc, error) + AwaitRoleFunc = func(role string) ContextFunc + ScheduledFunc func(ctx context.Context, lastRunTime, runTime time.Time, runID string) error +) + +type Cursor interface { + Get(ctx context.Context, name string) (string, error) + Set(ctx context.Context, name string, value string) error +} + +// Scheduled will create a lu.Process which executes according to a Schedule +func Scheduled(awaitFunc AwaitRoleFunc, curs Cursor, + name string, when Schedule, f ScheduledFunc, + ol ...Option, +) lu.Process { + opts := resolveOptions(defaultScheduleOptions(), append(ol, WithName(name))) + + if opts.role == "" { + opts.role = opts.name + } + + runner := scheduleRunner{cursor: curs, o: opts, when: when, f: f} + process := func(ctx context.Context) time.Duration { return processOnce(ctx, awaitFunc, opts, &runner) } + wait := func(ctx context.Context, sleep 
time.Duration) error { return lu.Wait(ctx, opts.clock, sleep) } + loop := func(ctx context.Context) error { return processLoop(ctx, process, wait) } + + return lu.Process{ + Name: opts.name, + Run: loop, + } +} + +type ( + processFunc func(context.Context) time.Duration + waitFunc func(context.Context, time.Duration) error +) + +// processLoop may panic if processOnce or wait is nil. +func processLoop(ctx context.Context, process processFunc, wait waitFunc) error { + for ctx.Err() == nil { + sleep := process(ctx) + if err := wait(ctx, sleep); err != nil { + return err + } + } + return ctx.Err() +} + +// processOnce may panic if awaitRole is nil or if when calling it returns a nil role.ContextFunc, and +// it may also panic if opts.sleep or opts.errSleep are nil as well; which can be avoided by +// calling resolveOptions on the opts parameter before passing it into this function; it my also panic if +// runner.f is nil as well. +func processOnce(ctx context.Context, awaitRole AwaitRoleFunc, opts options, runner *scheduleRunner) time.Duration { + err := runWithContext(ctx, awaitRole(opts.role), runner.doNext) + sleep := opts.sleep() + if err != nil && !errors.Is(err, context.Canceled) { + // NoReturnErr: Log critical errors and continue loop + runner.ErrCount++ + sleep = opts.errorSleep(runner.ErrCount, err) + opts.errCounter.Inc() + log.Error(ctx, err) + } else { + runner.ErrCount = 0 + } + return sleep +} + +type scheduleRunner struct { + cursor Cursor + o options + when Schedule + f ScheduledFunc + + ErrCount uint +} + +// doNext executes the next iteration of the schedule. +// We use a cursor to keep track of the last completed run. +// If we miss running multiple runs of the cron then we will only attempt to run the latest one. 
+func (r scheduleRunner) doNext(ctx context.Context) error { + lastDone, err := getLastRun(ctx, r.cursor, r.o.name) + if err != nil { + return err + } + next := nextExecution(r.o.clock.Now(), lastDone, r.when, r.o.name) + + if r.o.maxErrors > 0 && r.ErrCount >= r.o.maxErrors { + return setRunDone(ctx, next, r.cursor, r.o.name) + } + + if err := lu.WaitUntil(ctx, r.o.clock, next); err != nil { + return err + } + + runID := fmt.Sprintf("%s_%d", r.o.name, next.Unix()) + + if err := r.f(ctx, lastDone, next, runID); err != nil { + return err + } + + return setRunDone(ctx, next, r.cursor, r.o.name) +} + +func nextExecution(now, last time.Time, s Schedule, name string) time.Time { + fromNow := s.Next(now) + if last.IsZero() { + return fromNow + } + + // If the expected last run does not match the actual last run, we will + // favour the expected last run if the schedule implements the right interface. + prev, ok := s.(previousAware) + if ok { + expectedLastRun := prev.Previous(now) + if !last.Equal(expectedLastRun) { + return expectedLastRun + } + } + + fromLast := s.Next(last) + if fromLast.Before(fromNow) { + scheduleCursorLag.WithLabelValues(name).Set(fromNow.Sub(fromLast).Seconds()) + return fromLast.In(now.Location()) + } + return fromNow +} + +// getLastRun returns the last successful run timestamp. +// Returns a zero time if no run is found. +func getLastRun(ctx context.Context, curs Cursor, name string) (time.Time, error) { + val, err := curs.Get(ctx, name) + if err != nil { + return time.Time{}, err + } + + if val == "" { + // Return zero time if no cursor. 
+ return time.Time{}, nil + } + + unixSec, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return time.Time{}, err + } + + return time.Unix(unixSec, 0), nil +} + +func setRunDone(ctx context.Context, t time.Time, curs Cursor, name string) error { + unixSec := strconv.FormatInt(t.Unix(), 10) + return curs.Set(ctx, name, unixSec) +} diff --git a/process/schedule_test.go b/process/schedule_test.go new file mode 100644 index 0000000..6dd9f48 --- /dev/null +++ b/process/schedule_test.go @@ -0,0 +1,824 @@ +package process + +import ( + "context" + "testing" + "time" + + "github.com/luno/jettison/errors" + "github.com/luno/jettison/j" + "github.com/luno/jettison/jtest" + "github.com/robfig/cron/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + clocktesting "k8s.io/utils/clock/testing" +) + +type run struct { + runID string + lastRun time.Time + retError error +} + +var noRun run + +type expectRun struct { + t *testing.T + expRun run + alreadyRan bool +} + +type memCursor map[string]string + +func (m memCursor) Get(_ context.Context, name string) (string, error) { + return m[name], nil +} + +func (m memCursor) Set(_ context.Context, name string, value string) error { + m[name] = value + return nil +} + +//goland:noinspection GoExportedFuncWithUnexportedType +func ExpectRun(t *testing.T, run run) *expectRun { + return &expectRun{ + t: t, + expRun: run, + } +} + +func (r *expectRun) Run(_ context.Context, _, _ time.Time, runID string) error { + if r.alreadyRan { + r.t.Fatal("duplicate run") + } + assert.Equal(r.t, r.expRun.runID, runID) + r.alreadyRan = true + return r.expRun.retError +} + +func (r *expectRun) AssertUsed() { + assert.Equal(r.t, r.alreadyRan, r.expRun != noRun) +} + +func TestSchedule(t *testing.T) { + const ( + ts20220121Midnight = "1642723200" + ts20220122Midnight = "1642809600" + ts20220123Midnight = "1642896000" + ts20220123Exact = "1642944241" + ) + const cursorName = "test_schedule" + + testCases := []struct { 
+ name string + + startTime time.Time + startCursor string + + when cron.Schedule + + setClockTo time.Time + + expRun run + expErr error + expCursor string + }{ + { + name: "run next with new cursor", + startTime: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + + when: Every(24 * time.Hour), + + setClockTo: must(time.Parse(time.RFC3339, "2022-01-23T00:00:00Z")), + + expRun: run{runID: cursorName + "_" + ts20220123Midnight}, + expCursor: ts20220123Midnight, + }, + { + name: "poll next with new cursor", + startTime: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + + when: Poll(24 * time.Hour), + + setClockTo: must(time.Parse(time.RFC3339, "2022-01-23T13:24:01Z")), + + expRun: run{runID: cursorName + "_" + ts20220123Exact}, + expCursor: ts20220123Exact, + }, + { + name: "run before the previous, runs immediately", + startTime: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + startCursor: ts20220121Midnight, + + when: Every(24 * time.Hour), + + expRun: run{runID: cursorName + "_" + ts20220122Midnight}, + expCursor: ts20220122Midnight, + }, + { + name: "poll before the previous, runs immediately", + startTime: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + startCursor: ts20220121Midnight, + + when: Poll(24 * time.Hour), + + expRun: run{runID: cursorName + "_" + ts20220122Midnight}, + expCursor: ts20220122Midnight, + }, + { + name: "run in the future blocks until cancelled", + startTime: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + startCursor: ts20220122Midnight, + + when: Every(24 * time.Hour), + + expRun: noRun, + expErr: context.DeadlineExceeded, + expCursor: ts20220122Midnight, + }, + { + name: "poll in the future blocks until cancelled", + startTime: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + startCursor: ts20220122Midnight, + + when: Poll(24 * time.Hour), + + expRun: noRun, + expErr: context.DeadlineExceeded, + expCursor: ts20220122Midnight, + }, + } + + for _, tc := range testCases { + 
t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + t.Cleanup(cancel) + + cc := make(memCursor) + cl := clocktesting.NewFakeClock(tc.startTime) + + err := cc.Set(ctx, cursorName, tc.startCursor) + jtest.RequireNil(t, err) + + runs := ExpectRun(t, tc.expRun) + defer runs.AssertUsed() + + if !tc.setClockTo.IsZero() { + go func() { + for !cl.HasWaiters() { + time.Sleep(time.Millisecond) + } + cl.SetTime(tc.setClockTo) + }() + } + + r := scheduleRunner{ + cursor: cc, + o: options{name: cursorName, clock: cl}, + when: tc.when, + f: runs.Run, + } + jtest.Require(t, tc.expErr, r.doNext(ctx)) + + v, err := cc.Get(ctx, cursorName) + jtest.RequireNil(t, err) + assert.Equal(t, tc.expCursor, v) + }) + } +} + +func TestNextExecution(t *testing.T) { + testCases := []struct { + name string + + now time.Time + last time.Time + spec cron.Schedule + + expNext time.Time + }{ + { + name: "never ran, returns next one", + now: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + spec: Every(time.Hour), + expNext: must(time.Parse(time.RFC3339, "2022-01-22T14:00:00Z")), + }, + { + name: "missed previous one, returns previous", + now: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + last: must(time.Parse(time.RFC3339, "2022-01-22T12:00:00Z")), + spec: Every(time.Hour), + expNext: must(time.Parse(time.RFC3339, "2022-01-22T13:00:00Z")), + }, + { + name: "last equal to previous returns next", + now: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + last: must(time.Parse(time.RFC3339, "2022-01-22T13:00:00Z")), + spec: Every(time.Hour), + expNext: must(time.Parse(time.RFC3339, "2022-01-22T14:00:00Z")), + }, + { + name: "last in the future still returns next", + now: must(time.Parse(time.RFC3339, "2022-01-22T13:24:01Z")), + last: must(time.Parse(time.RFC3339, "2022-01-22T13:44:00Z")), + spec: Every(time.Hour), + expNext: must(time.Parse(time.RFC3339, "2022-01-22T14:00:00Z")), + }, + { + name: "offset 
handled", + now: must(time.Parse(time.RFC3339, "2022-01-22T15:04:53Z")), + last: must(time.Parse(time.RFC3339, "2022-01-22T14:10:00Z")), + spec: Every(time.Hour, WithOffset(10*time.Minute)), + expNext: must(time.Parse(time.RFC3339, "2022-01-22T15:10:00Z")), + }, + { + name: "mixed timezones handles next run", + now: must(time.Parse(time.RFC3339, "2022-01-22T15:04:53+07:00")), + last: must(time.Parse(time.RFC3339, "2022-01-22T07:10:00Z")), + spec: Every(time.Hour, WithOffset(10*time.Minute)), + expNext: must(time.Parse(time.RFC3339, "2022-01-22T15:10:00+07:00")), + }, + { + name: "mixed timezones handles previous run in now timezone", + now: must(time.Parse(time.RFC3339, "2022-01-22T15:04:53+07:00")), + last: must(time.Parse(time.RFC3339, "2022-01-22T06:10:00Z")), + spec: Every(time.Hour, WithOffset(10*time.Minute)), + expNext: must(time.Parse(time.RFC3339, "2022-01-22T14:10:00+07:00")), + }, + { + name: "handle cron schedules", + now: must(time.Parse(time.RFC3339, "2022-01-21T15:04:53Z")), + last: must(time.Parse(time.RFC3339, "2022-01-21T14:00:00Z")), + spec: must(cron.ParseStandard("0 7,10,14 * * 1-5")), + expNext: must(time.Parse(time.RFC3339, "2022-01-24T07:00:00Z")), // 21st was a Friday, so should skip to Monday 24th + }, + { + name: "tod handles current run", + now: must(time.Parse(time.RFC3339, "2022-01-21T15:00:00Z")), + last: must(time.Parse(time.RFC3339, "2022-01-21T15:00:00Z")), + spec: TimeOfDay(15, 0), + expNext: must(time.Parse(time.RFC3339, "2022-01-22T15:00:00Z")), + }, + { + name: "fixed interval with cursor far in the past", + now: must(time.Parse(time.RFC3339, "2022-01-21T12:15:00Z")), + last: must(time.Parse(time.RFC3339, "2022-01-21T10:00:00Z")), + spec: FixedInterval(time.Hour), + expNext: must(time.Parse(time.RFC3339, "2022-01-21T12:00:00Z")), + }, + { + name: "fixed interval with cursor in the past but now the same as expected run time", + now: must(time.Parse(time.RFC3339, "2022-01-21T12:00:00Z")), + last: must(time.Parse(time.RFC3339, 
"2022-01-21T10:00:00Z")), + spec: FixedInterval(time.Hour), + expNext: must(time.Parse(time.RFC3339, "2022-01-21T12:00:00Z")), + }, + { + name: "fixed interval with cursor updated", + now: must(time.Parse(time.RFC3339, "2022-01-21T12:00:00Z")), + last: must(time.Parse(time.RFC3339, "2022-01-21T12:00:00Z")), + spec: FixedInterval(time.Hour), + expNext: must(time.Parse(time.RFC3339, "2022-01-21T13:00:00Z")), + }, + { + name: "fixed interval with offset", + now: must(time.Parse(time.RFC3339, "2022-01-21T12:20:00Z")), + last: must(time.Parse(time.RFC3339, "2022-01-21T08:00:00Z")), + spec: FixedInterval(time.Hour, WithOffset(time.Minute)), + expNext: must(time.Parse(time.RFC3339, "2022-01-21T12:01:00Z")), + }, + { + name: "fixed interval with historic cursor and offset run time and now value", + now: must(time.Parse(time.RFC3339, "2022-01-21T12:15:00Z")), + last: must(time.Parse(time.RFC3339, "2022-01-21T08:00:00Z")), + spec: FixedInterval(time.Hour, WithOffset(20*time.Minute)), + expNext: must(time.Parse(time.RFC3339, "2022-01-21T11:20:00Z")), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + next := nextExecution(tc.now, tc.last, tc.spec, "") + assert.Equal(t, tc.expNext, next) + }) + } +} + +func must[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} + +func TestNextExecutionMany(t *testing.T) { + timezoneAmericaNewYork, err := time.LoadLocation("America/New_York") + if err != nil { + require.Fail(t, "could not load location") + } + testCases := []struct { + name string + schedule cron.Schedule + start time.Time + end time.Time + expRuns []time.Time + }{ + { + name: "timezones", + schedule: ToTimezone(TimeOfDay(0, 30), timezoneAmericaNewYork), + start: time.Date(2022, 3, 10, 0, 0, 0, 0, time.UTC), + end: time.Date(2022, 3, 15, 0, 0, 0, 0, time.UTC), + expRuns: []time.Time{ + time.Date(2022, 3, 10, 5, 30, 0, 0, time.UTC), + time.Date(2022, 3, 11, 5, 30, 0, 0, time.UTC), + time.Date(2022, 3, 12, 5, 30, 0, 0, 
time.UTC), + time.Date(2022, 3, 13, 5, 30, 0, 0, time.UTC), + time.Date(2022, 3, 14, 4, 30, 0, 0, time.UTC), + }, + }, + { + name: "timezones over DST switchover boundary (2AM 13 March, 2022)", + schedule: ToTimezone(TimeOfDay(2, 30), timezoneAmericaNewYork), + start: time.Date(2022, 3, 10, 0, 0, 0, 0, time.UTC), + end: time.Date(2022, 3, 15, 0, 0, 0, 0, time.UTC), + expRuns: []time.Time{ + time.Date(2022, 3, 10, 7, 30, 0, 0, time.UTC), + time.Date(2022, 3, 11, 7, 30, 0, 0, time.UTC), + time.Date(2022, 3, 12, 7, 30, 0, 0, time.UTC), + time.Date(2022, 3, 13, 6, 30, 0, 0, time.UTC), + time.Date(2022, 3, 14, 6, 30, 0, 0, time.UTC), + }, + }, + { + name: "timezones with cron", + schedule: ToTimezone( + must(cron.ParseStandard("0 12,14 * * 1-5")), + timezoneAmericaNewYork, + ), + start: time.Date(2022, 3, 10, 0, 0, 0, 0, time.UTC), + end: time.Date(2022, 3, 15, 0, 0, 0, 0, time.UTC), + expRuns: []time.Time{ + time.Date(2022, 3, 10, 17, 0, 0, 0, time.UTC), + time.Date(2022, 3, 10, 19, 0, 0, 0, time.UTC), + time.Date(2022, 3, 11, 17, 0, 0, 0, time.UTC), + time.Date(2022, 3, 11, 19, 0, 0, 0, time.UTC), + time.Date(2022, 3, 14, 16, 0, 0, 0, time.UTC), + time.Date(2022, 3, 14, 18, 0, 0, 0, time.UTC), + }, + }, + { + name: "timezones with cron over DST switchover boundary (2AM 13 March, 2022)", + schedule: ToTimezone( + must(cron.ParseStandard("30 2 * * *")), + timezoneAmericaNewYork, + ), + start: time.Date(2022, 3, 10, 0, 0, 0, 0, time.UTC), + end: time.Date(2022, 3, 15, 0, 0, 0, 0, time.UTC), + expRuns: []time.Time{ + time.Date(2022, 3, 10, 7, 30, 0, 0, time.UTC), + time.Date(2022, 3, 11, 7, 30, 0, 0, time.UTC), + time.Date(2022, 3, 12, 7, 30, 0, 0, time.UTC), + // 2AM becomes 3AM on 13th at 2AM, so 2AM never happens on this day and the next run only happens the following day + time.Date(2022, 3, 14, 6, 30, 0, 0, time.UTC), + }, + }, + { + name: "timezones with cron running every afternoon over switchover into DST", + schedule: ToTimezone( + must(cron.ParseStandard("0 15 * 
* *")), + timezoneAmericaNewYork, + ), + start: time.Date(2022, 3, 10, 0, 0, 0, 0, time.UTC), + end: time.Date(2022, 3, 15, 0, 0, 0, 0, time.UTC), + expRuns: []time.Time{ + time.Date(2022, 3, 10, 20, 0, 0, 0, time.UTC), + time.Date(2022, 3, 11, 20, 0, 0, 0, time.UTC), + time.Date(2022, 3, 12, 20, 0, 0, 0, time.UTC), + time.Date(2022, 3, 13, 19, 0, 0, 0, time.UTC), + time.Date(2022, 3, 14, 19, 0, 0, 0, time.UTC), + }, + }, + { + name: "timezones with short interval cron running every 15mins over switchover into DST", + schedule: ToTimezone( + must(cron.ParseStandard("45 * * * *")), + timezoneAmericaNewYork, + ), + start: time.Date(2022, 3, 13, 0, 0, 0, 0, time.UTC), + end: time.Date(2022, 3, 14, 0, 0, 0, 0, time.UTC), + expRuns: []time.Time{ + time.Date(2022, time.March, 13, 0, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 1, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 2, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 3, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 4, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 5, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 6, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 7, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 8, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 9, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 10, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 11, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 12, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 13, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 14, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 15, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 16, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 17, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 18, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 19, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 20, 45, 0, 0, time.UTC), + time.Date(2022, 
time.March, 13, 21, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 22, 45, 0, 0, time.UTC), + time.Date(2022, time.March, 13, 23, 45, 0, 0, time.UTC), + }, + }, + { + name: "new years day in foreign timezone, produces expected UTC execution time", + schedule: ToTimezone(TimeOfDay(0, 0), timezoneAmericaNewYork), + start: time.Date(2021, 12, 31, 0, 0, 0, 0, time.UTC), + end: time.Date(2022, 1, 2, 0, 0, 0, 0, time.UTC), + expRuns: []time.Time{ + time.Date(2021, 12, 31, 5, 0, 0, 0, time.UTC), + time.Date(2022, 1, 1, 5, 0, 0, 0, time.UTC), + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ti := tc.start + var runs []time.Time + for { + ti = tc.schedule.Next(ti) + if !ti.Before(tc.end) { + break + } + runs = append(runs, ti) + } + assert.Equal(t, tc.expRuns, runs) + }) + } +} + +func TestRetries(t *testing.T) { + errRun := errors.New("run error") + + testCases := []struct { + name string + maxErrors uint + errCount uint + + expWait bool + expErr error + expCursor string + }{ + { + name: "error on initial run", + maxErrors: 0, + errCount: 0, + expWait: true, + expErr: errRun, + }, + { + name: "error is retried", + maxErrors: 0, + errCount: 1, + expWait: true, + expErr: errRun, + }, + { + name: "error is not retried if max errors is set", + maxErrors: 1, + errCount: 1, + expCursor: "10020", + }, + { + name: "error on retry", + maxErrors: 5, + errCount: 4, + expWait: true, + expErr: errRun, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + clock := clocktesting.NewFakeClock(time.Unix(10_000, 0)) + cursor := make(memCursor) + o := options{ + name: "test_retry", + errorSleep: ErrorSleepFor(0), + maxErrors: tc.maxErrors, + clock: clock, + } + + r := scheduleRunner{ + cursor: cursor, + o: o, + when: Every(time.Minute), + f: func(_ context.Context, _, _ time.Time, _ string) error { + return errRun + }, + ErrCount: tc.errCount, + } + + if tc.expWait { + go step(clock, time.Minute) + } + + 
jtest.Assert(t, tc.expErr, r.doNext(context.Background())) + + v, err := cursor.Get(context.Background(), o.name) + jtest.RequireNil(t, err) + assert.Equal(t, tc.expCursor, v) + }) + } +} + +func step(clock *clocktesting.FakeClock, d time.Duration) { + for !clock.HasWaiters() { + time.Sleep(time.Millisecond) + } + clock.Step(d) +} + +type testContext struct { + errCalled int + err []error + t *testing.T +} + +func (*testContext) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*testContext) Done() <-chan struct{} { + return nil +} + +func (tc *testContext) Err() error { + l := len(tc.err) + if l <= tc.errCalled { + require.Fail(tc.t, "Insufficient context errors") + return nil + } + err := tc.err[tc.errCalled] + tc.errCalled++ + return err +} + +func (*testContext) Value(_ any) any { + return nil +} + +func (*testContext) String() string { + return "test error Context" +} + +func Test_processLoop(t *testing.T) { + process := func(context.Context) time.Duration { return time.Minute } + + tests := []struct { + name string + ctx context.Context + wait waitFunc + waitCalled int + err error + }{ + { + name: "bad context", + ctx: &testContext{err: []error{errors.New("ctx.Err()1!", j.C("err_1")), errors.New("ctx.Err()2!", j.C("err_2"))}}, + err: errors.New("ctx.Err()2!", j.C("err_2")), + }, + { + name: "wait function errors", + ctx: &testContext{err: []error{nil}}, + wait: func(_ context.Context, _ time.Duration) error { + return errors.New("Wait Error!", j.C("err_1")) + }, + waitCalled: 1, + err: errors.New("Wait Error!", j.C("err_1")), + }, + { + name: "Loop is cancelled after one process", + ctx: &testContext{err: []error{nil, context.Canceled, errors.New("Context Was Cancelled", j.C("err_1"))}}, + waitCalled: 1, + err: errors.New("Context Was Cancelled", j.C("err_1")), + }, + { + name: "Loop is cancelled after two process runs", + ctx: &testContext{err: []error{nil, nil, context.Canceled, errors.New("Context Was Cancelled", j.C("err_1"))}}, + 
waitCalled: 2, + err: errors.New("Context Was Cancelled", j.C("err_1")), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + waitCalled := 0 + if tt.wait == nil { + tt.wait = func(_ context.Context, _ time.Duration) error { + return nil + } + } + wait := func(ctx context.Context, s time.Duration) error { + waitCalled++ + return tt.wait(ctx, s) + } + err := processLoop(tt.ctx, process, wait) + require.Equal(t, tt.waitCalled, waitCalled) + if tt.err == nil { + jtest.RequireNil(t, err) + } else { + jtest.Require(t, tt.err, err) + } + }) + } +} + +func Test_processOnce(t *testing.T) { + stdSleep := time.Minute * 10 + errSleep := time.Minute * 5 + tests := []struct { + name string + awaitRole AwaitRoleFunc + f ScheduledFunc + errCount uint + sleep time.Duration + }{ + { + name: "awaitRole returns context.Cancelled error", + awaitRole: func(_ string) ContextFunc { + return func(ctx context.Context) (context.Context, context.CancelFunc, error) { + return ctx, nil, context.Canceled + } + }, + sleep: stdSleep, + }, + { + name: "awaitRole returns non context.Cancelled error", + awaitRole: func(_ string) ContextFunc { + return func(ctx context.Context) (context.Context, context.CancelFunc, error) { + return ctx, nil, errors.New("Bang1!") + } + }, + sleep: errSleep, + errCount: 1, + }, + { + name: "runner.doNext returns context.Cancelled error", + f: func(_ context.Context, _, _ time.Time, _ string) error { + return context.Canceled + }, + sleep: stdSleep, + }, + { + name: "runner.doNext returns non context.Cancelled error", + f: func(_ context.Context, _, _ time.Time, _ string) error { + return errors.New("Bang1!") + }, + sleep: errSleep, + errCount: 1, + }, + { + name: "no errors returned", + sleep: stdSleep, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.awaitRole == nil { + tt.awaitRole = func(role string) ContextFunc { + return func(ctx context.Context) (context.Context, context.CancelFunc, error) { + return 
ctx, func() {}, nil + } + } + } + if tt.f == nil { + tt.f = func(_ context.Context, _, _ time.Time, _ string) error { + return nil + } + } + r := scheduleRunner{ + cursor: make(memCursor), + o: options{name: "test_processFunc", clock: clocktesting.NewFakeClock(time.Unix(10_000, 0))}, + when: Poll(0), + f: tt.f, + } + var errCount uint + opts := options{ + sleep: SleepFor(stdSleep), + errorSleep: func(ec uint, err error) time.Duration { + errCount = ec + return errSleep + }, + } + opts = resolveOptions(opts, nil) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + t.Cleanup(cancel) + sleep := processOnce(ctx, tt.awaitRole, opts, &r) + require.Equal(t, tt.sleep, sleep) + require.Equal(t, tt.errCount, errCount) + _ = processOnce(ctx, tt.awaitRole, opts, &r) + // If there was no error, we still expect errCount=0. + // If there was an error, we expect another so errCount=2. + require.Equal(t, tt.errCount*2, errCount) + }) + } +} + +func TestLastScheduled(t *testing.T) { + tests := []struct { + name string + f ScheduledFunc + panics bool + }{ + { + name: "No panic: f LastScheduledFunc not nil", + f: func(_ context.Context, _, _ time.Time, _ string) error { return nil }, + }, + { + name: "Panic: f LastScheduledFunc is nil", + panics: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + t.Cleanup(cancel) + awaitRole := func(role string) ContextFunc { + return func(ctx context.Context) (context.Context, context.CancelFunc, error) { + return ctx, func() {}, nil + } + } + process := Scheduled(awaitRole, make(memCursor), "TestLastScheduled", Poll(1), tt.f) + tf := func() { _ = process.Run(ctx) } + if tt.panics { + require.Panics(t, tf) + } else { + require.NotPanics(t, tf) + } + }) + } +} + +func TestCronWithPrevious(t *testing.T) { + testCases := []struct { + name string + cron string + now time.Time + expPrevious time.Time + expNext 
time.Time + }{ + { + name: "cron that never runs, gives up", + cron: "0 0 31 2 *", + now: time.Date(2024, 1, 1, 0, 0, 59, 0, time.UTC), + expPrevious: time.Date(2024, 1, 1, 0, 0, 59, 0, time.UTC), + expNext: time.Time{}, + }, + { + name: "look back over a year", + cron: "1 1 1 1 *", + now: time.Date(2024, 1, 1, 0, 0, 59, 0, time.UTC), + expPrevious: time.Date(2023, 1, 1, 1, 1, 0, 0, time.UTC), + expNext: time.Date(2024, 1, 1, 1, 1, 0, 0, time.UTC), + }, + { + name: "daily at 9am", + cron: "0 9 * * *", + now: time.Date(2024, 10, 3, 8, 0, 0, 0, time.UTC), + expPrevious: time.Date(2024, 10, 2, 9, 0, 0, 0, time.UTC), + expNext: time.Date(2024, 10, 3, 9, 0, 0, 0, time.UTC), + }, + { + name: "every minute of every day", + cron: "* * * * *", + now: time.Date(2024, 10, 3, 8, 14, 45, 0, time.UTC), + expPrevious: time.Date(2024, 10, 3, 8, 14, 0, 0, time.UTC), + expNext: time.Date(2024, 10, 3, 8, 15, 0, 0, time.UTC), + }, + { + name: "every minute of every day, now is on schedule", + cron: "* * * * *", + now: time.Date(2024, 10, 3, 8, 14, 0, 0, time.UTC), + expPrevious: time.Date(2024, 10, 3, 8, 14, 0, 0, time.UTC), + expNext: time.Date(2024, 10, 3, 8, 15, 0, 0, time.UTC), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cr, err := cron.ParseStandard(tc.cron) + jtest.RequireNil(t, err) + s := cronWithPrevious{Schedule: cr} + + prev := s.Previous(tc.now) + assert.Equal(t, tc.expPrevious, prev) + + next := s.Next(tc.now) + assert.Equal(t, tc.expNext, next) + }) + } +} diff --git a/processable.go b/processable.go new file mode 100644 index 0000000..26f3414 --- /dev/null +++ b/processable.go @@ -0,0 +1,30 @@ +package lu + +import "context" + +type Processable interface { + func() | func(ctx context.Context) | func() error +} + +// WrapProcessFunc helper method to generate a ProcessFunc from +// an interface implementing the Processable methods +func WrapProcessFunc[P Processable](p P) ProcessFunc { + var x any = p + switch f := x.(type) { + case 
func(): + return func(ctx context.Context) error { + f() + return nil + } + case func(ctx context.Context): + return func(ctx context.Context) error { + f(ctx) + return nil + } + case func() error: + return func(ctx context.Context) error { + return f() + } + } + panic("unreachable") // Should never be reached due to constraint +} diff --git a/processable_test.go b/processable_test.go new file mode 100644 index 0000000..f18909b --- /dev/null +++ b/processable_test.go @@ -0,0 +1,44 @@ +package lu + +import ( + "context" + "testing" + + "github.com/luno/jettison/errors" + "github.com/stretchr/testify/require" +) + +func TestFuncMakeProcessFunc(t *testing.T) { + ctx := context.Background() + called := false + f := func() { called = true } + p := WrapProcessFunc(f) + require.NotNil(t, p) + err := p(ctx) + require.Nil(t, err) + require.True(t, called) +} + +func TestCtxFuncMakeProcessFunc(t *testing.T) { + ctx := context.Background() + called := false + call := &called + f := func(_ context.Context) { *call = true } + p := WrapProcessFunc(f) + require.NotNil(t, p) + err := p(ctx) + require.Nil(t, err) + require.True(t, called) +} + +func TestErrorFuncMakeProcessFunc(t *testing.T) { + ctx := context.Background() + called := false + call := &called + f := func() error { *call = true; return errors.New("dummy") } + p := WrapProcessFunc(f) + require.NotNil(t, p) + err := p(ctx) + require.NotNil(t, err) + require.True(t, called) +} diff --git a/signals.go b/signals.go new file mode 100644 index 0000000..8acef10 --- /dev/null +++ b/signals.go @@ -0,0 +1,84 @@ +package lu + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/luno/jettison/j" + "github.com/luno/jettison/log" +) + +// AppContext manages two contexts for running an app. It responds to different signals by +// cancelling one or both of these contexts. This behaviour allows us to do graceful shutdown +// in kubernetes using a stop script. 
If the application terminates before the stop script finishes +// then we get an error event from Kubernetes, so we need to be able to shut the application down +// using one signal, then exit the stop script and let Kubernetes send another signal to do the final +// termination. See this for more details on the hook behaviour +// https://kubernetes.io/docs/concepts/containers/container-lifecycle-hooks/ +// +// For SIGINT and SIGTERM, we will cancel both contexts, the application should +// finish all processes and call os.Exit +// +// For SIGQUIT, we cancel just the AppContext, the application should shut down all +// processes and wait for termination. +type AppContext struct { + signals chan os.Signal + + // AppContext should be used for running the application. + // When it's cancelled, the application should stop running all processes. + AppContext context.Context + appCancel context.CancelFunc + + // TerminationContext should be used for the execution of application. + // When it's cancelled the application binary should terminate. + // AppContext will be cancelled with this context as well. 
+ TerminationContext context.Context + termCancel context.CancelFunc +} + +func NewAppContext(ctx context.Context) AppContext { + c := AppContext{ + signals: make(chan os.Signal, 1), + } + + c.TerminationContext, c.termCancel = context.WithCancel(ctx) + c.AppContext, c.appCancel = context.WithCancel(c.TerminationContext) + + signal.Notify(c.signals, syscall.SIGQUIT, syscall.SIGINT, syscall.SIGTERM) + + go c.monitor(ctx) + + return c +} + +func (c AppContext) Stop() { + signal.Stop(c.signals) + close(c.signals) +} + +func (c AppContext) monitor(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case s, ok := <-c.signals: + if !ok { + return + } + call, ok := s.(syscall.Signal) + if !ok { + log.Info(ctx, "received unknown OS signal", j.KV("signal", s)) + continue + } + log.Info(ctx, "received OS signal", j.KV("signal", call)) + switch call { + case syscall.SIGQUIT: + c.appCancel() + case syscall.SIGINT, syscall.SIGTERM: + c.termCancel() + } + } + } +} diff --git a/signals_test.go b/signals_test.go new file mode 100644 index 0000000..6c567f0 --- /dev/null +++ b/signals_test.go @@ -0,0 +1,88 @@ +package lu + +import ( + "context" + "syscall" + "testing" + "time" + + "github.com/luno/jettison/errors" + "github.com/luno/jettison/jtest" + "github.com/stretchr/testify/assert" +) + +func TestAppContext_QuitOnlyEndsTheAppContext(t *testing.T) { + ac := NewAppContext(context.Background()) + t.Cleanup(ac.Stop) + + ac.signals <- syscall.SIGQUIT + + assert.Eventually(t, func() bool { + return errors.Is(ac.AppContext.Err(), context.Canceled) + }, time.Second, time.Millisecond) + + assert.Never(t, func() bool { + return errors.Is(ac.TerminationContext.Err(), context.Canceled) + }, time.Second, time.Millisecond) +} + +func TestAppContext_IntEndsBothContexts(t *testing.T) { + ac := NewAppContext(context.Background()) + t.Cleanup(ac.Stop) + + ac.signals <- syscall.SIGINT + + assert.Eventually(t, func() bool { + return errors.Is(ac.AppContext.Err(), 
context.Canceled) + }, time.Second, time.Millisecond) + + assert.Eventually(t, func() bool { + return errors.Is(ac.TerminationContext.Err(), context.Canceled) + }, time.Second, time.Millisecond) +} + +func TestAppContext_QuitThenTerminate(t *testing.T) { + // This is the sequence of signals we will receive in kubernetes (when using the stop script) + ac := NewAppContext(context.Background()) + t.Cleanup(ac.Stop) + + ac.signals <- syscall.SIGQUIT + + assert.Eventually(t, func() bool { + return errors.Is(ac.AppContext.Err(), context.Canceled) + }, time.Second, time.Millisecond) + + jtest.AssertNil(t, ac.TerminationContext.Err()) + + ac.signals <- syscall.SIGTERM + + assert.Eventually(t, func() bool { + return errors.Is(ac.TerminationContext.Err(), context.Canceled) + }, time.Second, time.Millisecond) +} + +func TestAppContext_TerminateEndsEverything(t *testing.T) { + ac := NewAppContext(context.Background()) + t.Cleanup(ac.Stop) + + ac.signals <- syscall.SIGTERM + + assert.Eventually(t, func() bool { + return errors.Is(ac.AppContext.Err(), context.Canceled) + }, time.Second, time.Millisecond) + + assert.Eventually(t, func() bool { + return errors.Is(ac.TerminationContext.Err(), context.Canceled) + }, time.Second, time.Millisecond) +} + +func TestAppContext_CancelledContext(t *testing.T) { + ac := NewAppContext(context.Background()) + t.Cleanup(ac.Stop) + + ac.appCancel() + + assert.Eventually(t, func() bool { + return errors.Is(ac.AppContext.Err(), context.Canceled) + }, time.Second, time.Millisecond) +} diff --git a/test/test_checks.go b/test/test_checks.go new file mode 100644 index 0000000..3daade8 --- /dev/null +++ b/test/test_checks.go @@ -0,0 +1,49 @@ +package test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/luno/lu" +) + +// Only for testing purposes - do not import into main code builds + +func AnyOrder(events ...Event) EventConstraint { + left := make(map[lu.Event]int) + for _, ev 
:= range events { + left[lu.Event(ev)]++ + } + return ConstraintFunc(func(t *testing.T, e lu.Event) bool { + l, ok := left[e] + require.True(t, ok, "unexpected event %+v", e) + assert.Greater(t, l, 0, "already got %+v", e) + left[e]-- + for _, v := range left { + if v > 0 { + return true + } + } + return false + }) +} + +func AssertEvents(t *testing.T, events chan lu.Event, constraints ...EventConstraint) { + var cIdx int + count := len(events) + for ev := range events { + t.Log("checking event", ev) + require.Less(t, cIdx, len(constraints), "additional unexpected event") + more := constraints[cIdx].CheckMore(t, ev) + if !more { + cIdx++ + } + count-- + if count == 0 { + break + } + } + assert.Equal(t, len(constraints), cIdx, "expected more events") +} diff --git a/test/test_types.go b/test/test_types.go new file mode 100644 index 0000000..c5f7a20 --- /dev/null +++ b/test/test_types.go @@ -0,0 +1,35 @@ +package test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/luno/lu" +) + +// Only for testing purposes - do not import into main code builds + +type EventLog chan lu.Event + +func (l EventLog) Append(_ context.Context, e lu.Event) { + l <- e +} + +type EventConstraint interface { + CheckMore(t *testing.T, e lu.Event) bool +} + +type Event lu.Event + +func (e Event) CheckMore(t *testing.T, got lu.Event) bool { + assert.Equal(t, lu.Event(e), got) + return false +} + +type ConstraintFunc func(t *testing.T, e lu.Event) bool + +func (f ConstraintFunc) CheckMore(t *testing.T, got lu.Event) bool { + return f(t, got) +}