From 43959cf2dcc081f7dbda7058579aba771b2dbab0 Mon Sep 17 00:00:00 2001 From: Marsette Vona Date: Thu, 8 Feb 2024 09:37:22 -0500 Subject: [PATCH] add tail call optimization for tasks --- .../driver/engine/SimulationEngine.java | 21 ++++++++++++------- .../jpl/aerie/merlin/framework/Context.java | 4 ++++ .../framework/InitializationContext.java | 5 +++++ .../aerie/merlin/framework/ModelActions.java | 12 +++++++++++ .../aerie/merlin/framework/QueryContext.java | 5 +++++ .../framework/ReplayingReactionContext.java | 8 +++++++ .../aerie/merlin/framework/ReplayingTask.java | 5 +++++ .../aerie/merlin/framework/TaskHandle.java | 2 ++ .../framework/ThreadedReactionContext.java | 6 ++++++ .../aerie/merlin/framework/ThreadedTask.java | 5 +++++ .../merlin/protocol/types/TaskStatus.java | 9 ++++++-- 11 files changed, 72 insertions(+), 10 deletions(-) diff --git a/merlin-driver/src/main/java/gov/nasa/jpl/aerie/merlin/driver/engine/SimulationEngine.java b/merlin-driver/src/main/java/gov/nasa/jpl/aerie/merlin/driver/engine/SimulationEngine.java index 2a5eb170c7..bfc2a87ec1 100644 --- a/merlin-driver/src/main/java/gov/nasa/jpl/aerie/merlin/driver/engine/SimulationEngine.java +++ b/merlin-driver/src/main/java/gov/nasa/jpl/aerie/merlin/driver/engine/SimulationEngine.java @@ -219,14 +219,19 @@ private void stepEffectModel( this.tasks.put(task, progress.continueWith(s.continuation())); this.scheduledJobs.schedule(JobId.forTask(task), SubInstant.Tasks.at(currentTime.plus(s.delay()))); } else if (status instanceof TaskStatus.CallingTask s) { - final var target = TaskId.generate(); - SimulationEngine.this.tasks.put(target, new ExecutionState.InProgress<>(currentTime, s.child().create(this.executor))); - SimulationEngine.this.taskParent.put(target, task); - SimulationEngine.this.taskChildren.computeIfAbsent(task, $ -> new HashSet<>()).add(target); - frame.signal(JobId.forTask(target)); - - this.tasks.put(task, progress.continueWith(s.continuation())); - this.waitingTasks.subscribeQuery(task, Set.of(SignalId.forTask(target))); + if (s.tailCall()) { + this.tasks.put(task, new ExecutionState.InProgress<>(progress.startOffset, s.child().create(this.executor))); + this.scheduledJobs.schedule(JobId.forTask(task), SubInstant.Tasks.at(currentTime)); + } else { + final var target = TaskId.generate(); + this.tasks.put(target, new ExecutionState.InProgress<>(currentTime, s.child().create(this.executor))); + this.taskParent.put(target, task); + this.taskChildren.computeIfAbsent(task, $ -> new HashSet<>()).add(target); + frame.signal(JobId.forTask(target)); + + this.tasks.put(task, progress.continueWith(s.continuation())); + this.waitingTasks.subscribeQuery(task, Set.of(SignalId.forTask(target))); + } } else if (status instanceof TaskStatus.AwaitingCondition s) { final var condition = ConditionId.generate(); this.conditions.put(condition, s.condition()); diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/Context.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/Context.java index 6738237a37..8d90e41e8b 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/Context.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/Context.java @@ -30,8 +30,12 @@ enum ContextType { Initializing, Reacting, Querying } void emit(Event event, Topic topic); void spawn(TaskFactory task); + void call(TaskFactory task); + void tailCall(TaskFactory task); + void delay(Duration duration); + void waitUntil(Condition condition); } diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/InitializationContext.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/InitializationContext.java index 9615830a69..adccdfc2c3 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/InitializationContext.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/InitializationContext.java @@ -61,6 +61,11 @@ public void call(final TaskFactory task) { throw new IllegalStateException("Cannot yield during initialization"); } + @Override + public void tailCall(final TaskFactory task) { + throw new IllegalStateException("Cannot yield during initialization"); + } + @Override public void delay(final Duration duration) { throw new IllegalStateException("Cannot yield during initialization"); diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ModelActions.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ModelActions.java index 1c9019f328..5372c20fa4 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ModelActions.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ModelActions.java @@ -69,6 +69,18 @@ public static void call(final TaskFactory task) { context.get().call(task); } + public static void tailCall(final Runnable task) { + tailCall(threaded(task)); + } + + public static void tailCall(final Supplier task) { + tailCall(threaded(task)); + } + + public static void tailCall(final TaskFactory task) { + context.get().tailCall(task); + } + public static void defer(final Duration duration, final Runnable task) { spawn(replaying(() -> { delay(duration); spawn(task); })); } diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/QueryContext.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/QueryContext.java index 2425a9ab77..f20cb197e6 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/QueryContext.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/QueryContext.java @@ -52,6 +52,11 @@ public void call(final TaskFactory task) { throw new IllegalStateException("Cannot schedule tasks in a query-only context"); } + @Override + public void tailCall(final TaskFactory task) { + throw new IllegalStateException("Cannot schedule tasks in a query-only context"); + } + @Override public void delay(final Duration duration) { throw new IllegalStateException("Cannot yield in a query-only context"); diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ReplayingReactionContext.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ReplayingReactionContext.java index e9998ff256..f093dba2a9 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ReplayingReactionContext.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ReplayingReactionContext.java @@ -78,6 +78,14 @@ public void call(final TaskFactory task) { }); } + @Override + public void tailCall(final TaskFactory task) { + this.memory.doOnce(() -> { + this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown. + this.scheduler = this.handle.tailCall(task); + }); + } + @Override public void delay(final Duration duration) { this.memory.doOnce(() -> { diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ReplayingTask.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ReplayingTask.java index b18f38f6ca..0ad2ad3dbd 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ReplayingTask.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ReplayingTask.java @@ -56,6 +56,11 @@ public Scheduler call(final TaskFactory child) { return this.yield(TaskStatus.calling(child, ReplayingTask.this)); } + @Override + public Scheduler tailCall(final TaskFactory child) { + return this.yield(TaskStatus.tailCalling(child, ReplayingTask.this)); + } + @Override public Scheduler await(final gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition) { return this.yield(TaskStatus.awaiting(condition, ReplayingTask.this)); diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/TaskHandle.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/TaskHandle.java index 609453f0c9..4afe2da452 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/TaskHandle.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/TaskHandle.java @@ -9,5 +9,7 @@ public interface TaskHandle { Scheduler call(TaskFactory child); + Scheduler tailCall(TaskFactory child); + Scheduler await(gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition); } diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ThreadedReactionContext.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ThreadedReactionContext.java index 3edd568d9e..3f72af4edd 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ThreadedReactionContext.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ThreadedReactionContext.java @@ -63,6 +63,12 @@ public void call(final TaskFactory task) { this.scheduler = this.handle.call(task); } + @Override + public void tailCall(final TaskFactory task) { + this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown. + this.scheduler = this.handle.tailCall(task); + } + @Override public void delay(final Duration duration) { this.scheduler = null; // Relinquish the current scheduler before yielding, in case an exception is thrown. diff --git a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ThreadedTask.java b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ThreadedTask.java index 58e3e3b626..84c54aaade 100644 --- a/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ThreadedTask.java +++ b/merlin-framework/src/main/java/gov/nasa/jpl/aerie/merlin/framework/ThreadedTask.java @@ -207,6 +207,11 @@ public Scheduler call(final TaskFactory child) { return this.yield(TaskStatus.calling(child, ThreadedTask.this)); } + @Override + public Scheduler tailCall(final TaskFactory child) { + return this.yield(TaskStatus.tailCalling(child, ThreadedTask.this)); + } + @Override public Scheduler await(final gov.nasa.jpl.aerie.merlin.protocol.model.Condition condition) { return this.yield(TaskStatus.awaiting(condition, ThreadedTask.this)); diff --git a/merlin-sdk/src/main/java/gov/nasa/jpl/aerie/merlin/protocol/types/TaskStatus.java b/merlin-sdk/src/main/java/gov/nasa/jpl/aerie/merlin/protocol/types/TaskStatus.java index 178b320ba4..2975aba538 100644 --- a/merlin-sdk/src/main/java/gov/nasa/jpl/aerie/merlin/protocol/types/TaskStatus.java +++ b/merlin-sdk/src/main/java/gov/nasa/jpl/aerie/merlin/protocol/types/TaskStatus.java @@ -9,7 +9,8 @@ record Completed(Return returnValue) implements TaskStatus {} record Delayed(Duration delay, Task continuation) implements TaskStatus {} - record CallingTask(TaskFactory child, Task continuation) implements TaskStatus {} + record CallingTask(TaskFactory child, Task continuation, boolean tailCall) + implements TaskStatus {} record AwaitingCondition(Condition condition, Task continuation) implements TaskStatus {} @@ -23,7 +24,11 @@ static Delayed delayed(final Duration delay, final Task } static CallingTask calling(final TaskFactory child, final Task continuation) { - return new CallingTask<>(child, continuation); + return new CallingTask<>(child, continuation, false); + } + + static CallingTask tailCalling(final TaskFactory child, final Task continuation) { + return new CallingTask<>(child, continuation, true); } static AwaitingCondition awaiting(final Condition condition, final Task continuation) {