diff --git a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java index b46c4489..9c4c8729 100644 --- a/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java +++ b/jflyte/src/main/java/org/flyte/jflyte/ExecuteDynamicWorkflow.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -41,6 +42,7 @@ import org.flyte.api.v1.DynamicJobSpec; import org.flyte.api.v1.DynamicWorkflowTask; import org.flyte.api.v1.DynamicWorkflowTaskRegistrar; +import org.flyte.api.v1.IfBlock; import org.flyte.api.v1.Literal; import org.flyte.api.v1.NamedEntityIdentifier; import org.flyte.api.v1.Node; @@ -280,6 +282,28 @@ private static List collectAllUsedTaskTemplates( flyteAdminClient, cache); + // collect task templates used by conditionals + spec.nodes().stream() + .filter(node -> node.branchNode() != null) + .forEach( + node -> { + List nodes = new ArrayList<>(); + nodes.add(node.branchNode().ifElse().case_().thenNode()); + nodes.add(node.branchNode().ifElse().elseNode()); + nodes.addAll( + node.branchNode().ifElse().other().stream() + .map(IfBlock::thenNode) + .collect(toList())); + + collectTaskTemplates( + nodes, + nodesRewriter, + allUsedTaskTemplates, + allTaskTemplates, + flyteAdminClient, + cache); + }); + // collect task templates used by subworkflows allUsedSubWorkflows .values()