Skip to content

Commit

Permalink
Simplify Task. Add support for different tickers
Browse files Browse the repository at this point in the history
  • Loading branch information
swift1337 committed Dec 20, 2024
1 parent 2d0241e commit 60c166d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 79 deletions.
20 changes: 10 additions & 10 deletions pkg/scheduler/opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,40 @@ import (
)

// Opt Task option
type Opt func(task *Task)
type Opt func(task *Task, taskOpts *taskOpts)

// Name sets task name.
func Name(name string) Opt {
return func(d *Task) { d.name = name }
return func(t *Task, _ *taskOpts) { t.name = name }
}

// GroupName sets task group. Otherwise, defaults to DefaultGroup.
func GroupName(group Group) Opt {
return func(d *Task) { d.group = group }
return func(t *Task, _ *taskOpts) { t.group = group }
}

// LogFields augments Task's logger with some fields.
func LogFields(fields map[string]any) Opt {
return func(d *Task) { d.logFields = fields }
return func(_ *Task, opts *taskOpts) { opts.logFields = fields }
}

// Interval sets initial task interval.
func Interval(interval time.Duration) Opt {
return func(d *Task) { d.interval = interval }
return func(_ *Task, opts *taskOpts) { opts.interval = interval }
}

// Skipper sets task skipper function
func Skipper(skipper func() bool) Opt {
return func(d *Task) { d.skipper = skipper }
return func(t *Task, _ *taskOpts) { t.skipper = skipper }
}

// IntervalUpdater sets interval updater function.
func IntervalUpdater(intervalUpdater func() time.Duration) Opt {
return func(d *Task) { d.intervalUpdater = intervalUpdater }
return func(_ *Task, opts *taskOpts) { opts.intervalUpdater = intervalUpdater }
}

// BlockTicker makes Definition to listen for new zeta blocks instead of using interval ticker.
// IntervalUpdater is ignored.
// BlockTicker makes Task to listen for new zeta blocks
// instead of using interval ticker. IntervalUpdater is ignored.
func BlockTicker(blocks <-chan cometbft.EventDataNewBlock) Opt {
return func(d *Task) { d.blockChan = blocks }
return func(_ *Task, opts *taskOpts) { opts.blockChan = blocks }
}
122 changes: 53 additions & 69 deletions pkg/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/rs/zerolog"

"github.com/zeta-chain/node/pkg/bg"
"github.com/zeta-chain/node/pkg/ticker"
)

// Scheduler represents background task scheduler.
Expand All @@ -35,31 +34,37 @@ type Group string
// DefaultGroup is the default task group.
const DefaultGroup = Group("default")

// tickable ticker abstraction to support different implementations
type tickable interface {
Start(ctx context.Context) error
Stop()
}

// Task represents scheduler's task.
type Task struct {
// ref to the Scheduler is required
scheduler *Scheduler

// naming stuff
id uuid.UUID
group Group
name string

exec Executable

// represents interval ticker and its options
ticker *ticker.Ticker
// ticker abstraction to support different implementations
ticker tickable
skipper func() bool

logger zerolog.Logger
}

type taskOpts struct {
interval time.Duration
intervalUpdater func() time.Duration
skipper func() bool

// zeta block ticker (also supports skipper)
blockChan <-chan cometbft.EventDataNewBlock
blockChanTicker *blockTicker
blockChan <-chan cometbft.EventDataNewBlock

// logging
logFields map[string]any
logger zerolog.Logger
}

// New Scheduler instance.
Expand All @@ -79,15 +84,21 @@ func (s *Scheduler) Register(ctx context.Context, exec Executable, opts ...Opt)
group: DefaultGroup,
name: id.String(),
exec: exec,
interval: time.Second,
}

config := &taskOpts{
interval: time.Second,
}

for _, opt := range opts {
opt(task)
opt(task, config)
}

task.logger = newTaskLogger(task, s.logger)
task.logger = newTaskLogger(task, config, s.logger)
task.ticker = newTickable(task, config)

task.startTicker(ctx)
task.logger.Info().Msg("Starting scheduler task")
bg.Work(ctx, task.ticker.Start, bg.WithLogger(task.logger))

s.mu.Lock()
s.tasks[id] = task
Expand Down Expand Up @@ -137,63 +148,21 @@ func (s *Scheduler) StopGroup(group Group) {

// Stop stops the task and offloads it from the scheduler.
func (t *Task) Stop() {
start := time.Now()

// delete task from scheduler
defer func() {
t.scheduler.mu.Lock()
delete(t.scheduler.tasks, t.id)
t.scheduler.mu.Unlock()

timeTakenMS := time.Since(start).Milliseconds()
t.logger.Info().Int64("time_taken_ms", timeTakenMS).Msg("Stopped scheduler task")
}()

t.logger.Info().Msg("Stopping scheduler task")
start := time.Now()

if t.isIntervalTicker() {
t.ticker.StopBlocking()
return
}

t.blockChanTicker.Stop()
}

func (t *Task) isIntervalTicker() bool {
return t.blockChan == nil
}

func (t *Task) startTicker(ctx context.Context) {
t.logger.Info().Msg("Starting scheduler task")

if t.isIntervalTicker() {
t.ticker = ticker.New(t.interval, t.invokeByInterval, ticker.WithLogger(t.logger, t.name))
bg.Work(ctx, t.ticker.Start, bg.WithLogger(t.logger))

return
}

t.blockChanTicker = newBlockTicker(t.invoke, t.blockChan, t.logger)

bg.Work(ctx, t.blockChanTicker.Start, bg.WithLogger(t.logger))
}

// invokeByInterval a ticker.Task wrapper of invoke.
func (t *Task) invokeByInterval(ctx context.Context, tt *ticker.Ticker) error {
if err := t.invoke(ctx); err != nil {
t.logger.Error().Err(err).Msg("task failed")
}
t.ticker.Stop()

if t.intervalUpdater != nil {
// noop if interval is not changed
tt.SetInterval(t.intervalUpdater())
}
t.scheduler.mu.Lock()
delete(t.scheduler.tasks, t.id)
t.scheduler.mu.Unlock()

return nil
timeTakenMS := time.Since(start).Milliseconds()
t.logger.Info().Int64("time_taken_ms", timeTakenMS).Msg("Stopped scheduler task")
}

// invoke executes a given Task with logging & telemetry.
func (t *Task) invoke(ctx context.Context) error {
// execute executes Task with additional logging and metrics.
func (t *Task) execute(ctx context.Context) error {
// skip tick
if t.skipper != nil && t.skipper() {
return nil
Expand All @@ -214,19 +183,34 @@ func (t *Task) invoke(ctx context.Context) error {
return err
}

func newTaskLogger(task *Task, logger zerolog.Logger) zerolog.Logger {
func newTaskLogger(task *Task, opts *taskOpts, logger zerolog.Logger) zerolog.Logger {
logOpts := logger.With().
Str("task.name", task.name).
Str("task.group", string(task.group))

if len(task.logFields) > 0 {
logOpts = logOpts.Fields(task.logFields)
if len(opts.logFields) > 0 {
logOpts = logOpts.Fields(opts.logFields)
}

taskType := "interval_ticker"
if task.blockChanTicker != nil {
if opts.blockChan != nil {
taskType = "block_ticker"
}

return logOpts.Str("task.type", taskType).Logger()
}

func newTickable(task *Task, opts *taskOpts) tickable {
// Block-based ticker
if opts.blockChan != nil {
return newBlockTicker(task.execute, opts.blockChan, task.logger)
}

return newIntervalTicker(
task.execute,
opts.interval,
opts.intervalUpdater,
task.name,
task.logger,
)
}
1 change: 1 addition & 0 deletions pkg/scheduler/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ func TestScheduler(t *testing.T) {
// ASSERT
assert.Equal(t, int64(21), counter)
assert.Contains(t, ts.logBuffer.String(), "Stopped scheduler task")
assert.Contains(t, ts.logBuffer.String(), `"task.type":"block_ticker"`)
})

t.Run("Block tick: tick is slower than the block", func(t *testing.T) {
Expand Down

0 comments on commit 60c166d

Please sign in to comment.