Skip to content

Commit

Permalink
add tail call optimization for tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
martyvona committed Mar 2, 2024
1 parent 7029aeb commit 43959cf
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,19 @@ private <Return> void stepEffectModel(
this.tasks.put(task, progress.continueWith(s.continuation()));
this.scheduledJobs.schedule(JobId.forTask(task), SubInstant.Tasks.at(currentTime.plus(s.delay())));
} else if (status instanceof TaskStatus.CallingTask<Return> s) {
final var target = TaskId.generate();
SimulationEngine.this.tasks.put(target, new ExecutionState.InProgress<>(currentTime, s.child().create(this.executor)));
SimulationEngine.this.taskParent.put(target, task);
SimulationEngine.this.taskChildren.computeIfAbsent(task, $ -> new HashSet<>()).add(target);
frame.signal(JobId.forTask(target));

this.tasks.put(task, progress.continueWith(s.continuation()));
this.waitingTasks.subscribeQuery(task, Set.of(SignalId.forTask(target)));
if (s.tailCall()) {
this.tasks.put(task, new ExecutionState.InProgress<>(progress.startOffset, s.child().create(this.executor)));
this.scheduledJobs.schedule(JobId.forTask(task), SubInstant.Tasks.at(currentTime));
} else {
final var target = TaskId.generate();
this.tasks.put(target, new ExecutionState.InProgress<>(currentTime, s.child().create(this.executor)));
this.taskParent.put(target, task);
this.taskChildren.computeIfAbsent(task, $ -> new HashSet<>()).add(target);
frame.signal(JobId.forTask(target));

this.tasks.put(task, progress.continueWith(s.continuation()));
this.waitingTasks.subscribeQuery(task, Set.of(SignalId.forTask(target)));
}
} else if (status instanceof TaskStatus.AwaitingCondition<Return> s) {
final var condition = ConditionId.generate();
this.conditions.put(condition, s.condition());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ enum ContextType { Initializing, Reacting, Querying }
<Event> void emit(Event event, Topic<Event> topic);

void spawn(TaskFactory<?> task);

<Return> void call(TaskFactory<Return> task);

<Return> void tailCall(TaskFactory<Return> task);

void delay(Duration duration);

void waitUntil(Condition condition);
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ public <Return> void call(final TaskFactory<Return> task) {
throw new IllegalStateException("Cannot yield during initialization");
}

@Override
public <Return> void tailCall(final TaskFactory<Return> task) {
throw new IllegalStateException("Cannot yield during initialization");
}

@Override
public void delay(final Duration duration) {
throw new IllegalStateException("Cannot yield during initialization");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ public static <T> void call(final TaskFactory<T> task) {
context.get().call(task);
}

public static void tailCall(final Runnable task) {
tailCall(threaded(task));
}

public static <T> void tailCall(final Supplier<T> task) {
tailCall(threaded(task));
}

public static <T> void tailCall(final TaskFactory<T> task) {
context.get().tailCall(task);
}

public static void defer(final Duration duration, final Runnable task) {
spawn(replaying(() -> { delay(duration); spawn(task); }));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ public <Return> void call(final TaskFactory<Return> task) {
throw new IllegalStateException("Cannot schedule tasks in a query-only context");
}

@Override
public <Return> void tailCall(final TaskFactory<Return> task) {
throw new IllegalStateException("Cannot schedule tasks in a query-only context");
}

@Override
public void delay(final Duration duration) {
throw new IllegalStateException("Cannot yield in a query-only context");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ public <T> void call(final TaskFactory<T> task) {
});
}

@Override
public <T> void tailCall(final TaskFactory<T> task) {
this.memory.doOnce(() -> {
this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown.
this.scheduler = this.handle.tailCall(task);
});
}

@Override
public void delay(final Duration duration) {
this.memory.doOnce(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ public Scheduler call(final TaskFactory<?> child) {
return this.yield(TaskStatus.calling(child, ReplayingTask.this));
}

@Override
public Scheduler tailCall(final TaskFactory<?> child) {
return this.yield(TaskStatus.tailCalling(child, ReplayingTask.this));
}

@Override
public Scheduler await(final gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition) {
return this.yield(TaskStatus.awaiting(condition, ReplayingTask.this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ public interface TaskHandle {

Scheduler call(TaskFactory<?> child);

Scheduler tailCall(TaskFactory<?> child);

Scheduler await(gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition);
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ public <T> void call(final TaskFactory<T> task) {
this.scheduler = this.handle.call(task);
}

@Override
public <T> void tailCall(final TaskFactory<T> task) {
this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown.
this.scheduler = this.handle.tailCall(task);
}

@Override
public void delay(final Duration duration) {
this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ public Scheduler call(final TaskFactory<?> child) {
return this.yield(TaskStatus.calling(child, ThreadedTask.this));
}

@Override
public Scheduler tailCall(final TaskFactory<?> child) {
return this.yield(TaskStatus.tailCalling(child, ThreadedTask.this));
}

@Override
public Scheduler await(final gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition) {
return this.yield(TaskStatus.awaiting(condition, ThreadedTask.this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ record Completed<Return>(Return returnValue) implements TaskStatus<Return> {}

record Delayed<Return>(Duration delay, Task<Return> continuation) implements TaskStatus<Return> {}

record CallingTask<Return>(TaskFactory<?> child, Task<Return> continuation) implements TaskStatus<Return> {}
record CallingTask<Return>(TaskFactory<?> child, Task<Return> continuation, boolean tailCall)
implements TaskStatus<Return> {}

record AwaitingCondition<Return>(Condition condition, Task<Return> continuation) implements TaskStatus<Return> {}

Expand All @@ -23,7 +24,11 @@ static <Return> Delayed<Return> delayed(final Duration delay, final Task<Return>
}

static <Return> CallingTask<Return> calling(final TaskFactory<?> child, final Task<Return> continuation) {
return new CallingTask<>(child, continuation);
return new CallingTask<>(child, continuation, false);
}

static <Return> CallingTask<Return> tailCalling(final TaskFactory<?> child, final Task<Return> continuation) {
return new CallingTask<>(child, continuation, true);
}

static <Return> AwaitingCondition<Return> awaiting(final Condition condition, final Task<Return> continuation) {
Expand Down

0 comments on commit 43959cf

Please sign in to comment.