Skip to content

Commit

Permalink
Support testing dynamic workflow task
Browse files Browse the repository at this point in the history
Signed-off-by: Hongxin Liang <[email protected]>
  • Loading branch information
honnix committed Sep 28, 2023
1 parent 4263190 commit ac416d9
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ public SumWorkflow.Output expand(SdkWorkflowBuilder builder, Input input) {
.result();
SdkBindingData<Long> abcd =
builder.apply("post-sum", new SumTask(), SumTask.SumInput.create(abc, d)).getOutputs();
return SumWorkflow.Output.create(abcd);
SdkBindingData<Long> result =
builder
.apply(
"fibonacci",
new DynamicFibonacciWorkflowTask(),
DynamicFibonacciWorkflowTask.Input.create(abcd))
.getOutputs()
.output();
return SumWorkflow.Output.create(result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,13 @@ public void testMockTasks() {
new SumTask(),
SumTask.SumInput.create(SdkBindingDataFactory.of(0L), SdkBindingDataFactory.of(4L)),
SdkBindingDataFactory.of(42L))
.withTaskOutput(
new DynamicFibonacciWorkflowTask(),
DynamicFibonacciWorkflowTask.Input.create(SdkBindingDataFactory.of(42L)),
DynamicFibonacciWorkflowTask.Output.create(SdkBindingDataFactory.of(123L)))
.execute();

assertEquals(42L, result.getIntegerOutput("result"));
assertEquals(123L, result.getIntegerOutput("result"));
}

@Test
Expand Down Expand Up @@ -87,9 +91,12 @@ public void testMockSubWorkflow() {
new SumTask(),
SumInput.create(SdkBindingDataFactory.of(10L), SdkBindingDataFactory.of(4L)),
SdkBindingDataFactory.of(15L))
.withTask(
new DynamicFibonacciWorkflowTask(),
input -> DynamicFibonacciWorkflowTask.Output.create(SdkBindingDataFactory.of(42L)))
.execute();

assertEquals(15L, result.getIntegerOutput("result"));
assertEquals(42L, result.getIntegerOutput("result"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
import org.flyte.api.v1.WorkflowNode;
import org.flyte.api.v1.WorkflowNode.Reference;
import org.flyte.api.v1.WorkflowTemplate;
import org.flyte.flytekit.SdkDynamicWorkflowTask;
import org.flyte.flytekit.SdkRemoteLaunchPlan;
import org.flyte.flytekit.SdkRemoteTask;
import org.flyte.flytekit.SdkRunnableTask;
import org.flyte.flytekit.SdkTransform;
import org.flyte.flytekit.SdkType;
import org.flyte.flytekit.SdkWorkflow;
import org.flyte.localengine.ExecutionContext;
Expand Down Expand Up @@ -321,20 +323,27 @@ public <T> SdkTestingExecutor withFixedInputs(SdkType<T> type, T value) {

public <InputT, OutputT> SdkTestingExecutor withTaskOutput(
SdkRunnableTask<InputT, OutputT> task, InputT input, OutputT output) {
TestingRunnableTask<InputT, OutputT> fixedTask =
getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType());

return toBuilder()
.putFixedTask(task.getName(), fixedTask.withFixedOutput(input, output))
.build();
return withTaskOutput0(task, input, output);
}

public <InputT, OutputT> SdkTestingExecutor withTaskOutput(
SdkRemoteTask<InputT, OutputT> task, InputT input, OutputT output) {
return withTaskOutput0(task, input, output);
}

public <InputT, OutputT> SdkTestingExecutor withTaskOutput(
SdkDynamicWorkflowTask<InputT, OutputT> task, InputT input, OutputT output) {
return withTaskOutput0(task, input, output);
}

private <InputT, OutputT> SdkTestingExecutor withTaskOutput0(
SdkTransform<InputT, OutputT> task, InputT input, OutputT output) {
TestingRunnableTask<InputT, OutputT> fixedTask =
getFixedTaskOrDefault(task.name(), task.inputs(), task.outputs());
getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType());

return toBuilder().putFixedTask(task.name(), fixedTask.withFixedOutput(input, output)).build();
return toBuilder()
.putFixedTask(task.getName(), fixedTask.withFixedOutput(input, output))
.build();
}

public <InputT, OutputT> SdkTestingExecutor withLaunchPlanOutput(
Expand All @@ -361,6 +370,16 @@ public <InputT, OutputT> SdkTestingExecutor withLaunchPlan(

public <InputT, OutputT> SdkTestingExecutor withTask(
SdkRunnableTask<InputT, OutputT> task, Function<InputT, OutputT> runFn) {
return withTask0(task, runFn);
}

public <InputT, OutputT> SdkTestingExecutor withTask(
SdkDynamicWorkflowTask<InputT, OutputT> task, Function<InputT, OutputT> runFn) {
return withTask0(task, runFn);
}

private <InputT, OutputT> SdkTestingExecutor withTask0(
SdkTransform<InputT, OutputT> task, Function<InputT, OutputT> runFn) {
TestingRunnableTask<InputT, OutputT> fixedTask =
getFixedTaskOrDefault(task.getName(), task.getInputType(), task.getOutputType());

Expand Down

0 comments on commit ac416d9

Please sign in to comment.