Skip to content

Commit

Permalink
Throw warning for nested @task functions (#1727)
Browse files Browse the repository at this point in the history
* Throw warning for nested @task functions

Signed-off-by: oliverhu <[email protected]>

* Update flytekit/remote/remote_callable.py

Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: oliverhu <[email protected]>

* also update execution state for default case

Signed-off-by: oliverhu <[email protected]>

* fix linting

Signed-off-by: oliverhu <[email protected]>

---------

Signed-off-by: oliverhu <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
  • Loading branch information
2 people authored and Fabio Grätz committed Aug 14, 2023
1 parent 09175ad commit a56f050
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 32 deletions.
6 changes: 5 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
"""
return None

def local_execution_mode(self) -> ExecutionState.Mode:
""" """
return ExecutionState.Mode.LOCAL_TASK_EXECUTION

def sandbox_execute(
self,
ctx: FlyteContext,
Expand Down Expand Up @@ -602,7 +606,7 @@ def dispatch_execute(
for k, v in native_outputs_as_map.items():
output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v)))

if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# When we run the workflow remotely, flytekit outputs decks at the end of _dispatch_execute
_output_deck(self.name.split(".")[-1], new_user_params)

Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing
from typing import Optional, Tuple, Union, cast

from flytekit.core.context_manager import ExecutionState, FlyteContextManager
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.node import Node
from flytekit.core.promise import (
ComparisonExpression,
Expand Down Expand Up @@ -488,7 +488,7 @@ def conditional(name: str) -> ConditionalSection:
if ctx.compilation_state:
return ConditionalSection(name)
elif ctx.execution_state:
if ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state.is_local_execution():
# In case of Local workflow execution, we will actually evaluate the expression and based on the result
# make the branch to be active using `take_branch` method
from flytekit.core.context_manager import BranchEvalMode
Expand Down
8 changes: 7 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,12 @@ def with_params(
user_space_params=user_space_params if user_space_params else self.user_space_params,
)

def is_local_execution(self):
return (
self.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION
or self.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
)


@dataclass(frozen=True)
class FlyteContext(object):
Expand Down Expand Up @@ -690,7 +696,7 @@ def enter_conditional_section(self) -> FlyteContext.Builder:
self.compilation_state = self.compilation_state.with_params(prefix=self.compilation_state.prefix)

if self.execution_state:
if self.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if self.execution_state.is_local_execution():
if self.in_a_condition:
if self.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
self.execution_state = self.execution_state.with_params()
Expand Down
5 changes: 4 additions & 1 deletion flytekit/core/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import click

from flytekit.core import interface as flyte_interface
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.promise import Promise, VoidPromise, flyte_entity_call_handler
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.user import FlyteDisapprovalException
Expand Down Expand Up @@ -116,6 +116,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
else:
raise FlyteDisapprovalException(f"User did not approve the transaction for gate node {self.name}")

def local_execution_mode(self):
return ExecutionState.Mode.LOCAL_TASK_EXECUTION


def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing.Type):
"""Create a Gate object that waits for user input of the specified type.
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]:
"""

ctx = FlyteContextManager.current_context()
if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# In workflow execution mode we actually need to use the parent (mapper) task output interface.
return self.interface.outputs
return self._run_task.interface.outputs
Expand All @@ -230,7 +230,7 @@ def get_type_for_output_var(self, k: str, v: Any) -> type:
from these individual outputs as the final output value.
"""
ctx = FlyteContextManager.current_context()
if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# In workflow execution mode we actually need to use the parent (mapper) task output interface.
return self._python_interface.outputs[k]
return self._run_task._python_interface.outputs[k]
Expand Down
7 changes: 2 additions & 5 deletions flytekit/core/node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Union

from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext
from flytekit.core.context_manager import BranchEvalMode, FlyteContext
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import VoidPromise
Expand Down Expand Up @@ -144,10 +144,7 @@ def sub_wf():
# Handling local execution
# Note: execution state is set to TASK_EXECUTION when running dynamic task locally
# https://github.com/flyteorg/flytekit/blob/0815345faf0fae5dc26746a43d4bda4cc2cdf830/flytekit/core/python_function_task.py#L262
elif ctx.execution_state is not None and (
ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
or ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION
):
elif ctx.execution_state and ctx.execution_state.is_local_execution():
if isinstance(entity, RemoteEntity):
raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}")

Expand Down
44 changes: 29 additions & 15 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,9 @@ class LocallyExecutable(Protocol):
def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]:
...

def local_execution_mode(self) -> ExecutionState.Mode:
...


def flyte_entity_call_handler(
entity: SupportsNodeCreation, *args, **kwargs
Expand Down Expand Up @@ -996,27 +999,38 @@ def flyte_entity_call_handler(
)

ctx = FlyteContextManager.current_context()
if ctx.execution_state and (
ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION
or ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION
):
logger.error("You are not supposed to nest @Task/@Workflow inside a @Task!")
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return create_and_link_node(ctx, entity=entity, **kwargs)
elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
if ctx.execution_state and ctx.execution_state.is_local_execution():
mode = cast(LocallyExecutable, entity).local_execution_mode()
with FlyteContextManager.with_context(
ctx.with_execution_state(ctx.execution_state.with_params(mode=mode))
) as child_ctx:
if (
len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0
or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0
child_ctx.execution_state
and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED
):
output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys())
if len(output_names) == 0:
return VoidPromise(entity.name)
vals = [Promise(var, None) for var in output_names]
return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
if (
len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0
or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0
):
output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys())
if len(output_names) == 0:
return VoidPromise(entity.name)
vals = [Promise(var, None) for var in output_names]
return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)
else:
mode = cast(LocallyExecutable, entity).local_execution_mode()
with FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION)
)
ctx.with_execution_state(ctx.new_execution_state().with_params(mode=mode))
) as child_ctx:
cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any:
representing that newly generated workflow, instead of executing it.
"""
ctx = FlyteContextManager.current_context()
if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION:
if ctx.execution_state and ctx.execution_state.is_local_execution():
# The rest of this function mimics the local_execute of the workflow. We can't use the workflow
# local_execute directly though since that converts inputs into Promises.
logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}")
Expand Down
7 changes: 4 additions & 3 deletions flytekit/core/reference_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Pro
vals = [Promise(var, outputs_literals[var]) for var in output_names]
return create_task_output(vals, self.python_interface)

def local_execution_mode(self):
return ExecutionState.Mode.LOCAL_TASK_EXECUTION

def construct_node_metadata(self) -> _workflow_model.NodeMetadata:
return _workflow_model.NodeMetadata(name=extract_obj_name(self.name))

Expand All @@ -207,9 +210,7 @@ def __call__(self, *args, **kwargs):
ctx = FlyteContext.current_context()
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
return self.compile(ctx, *args, **kwargs)
elif (
ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION
):
elif ctx.execution_state and ctx.execution_state.is_local_execution():
if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED:
return
return self.local_execute(ctx, **kwargs)
Expand Down
12 changes: 11 additions & 1 deletion flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from flytekit.core.base_task import PythonTask
from flytekit.core.class_based_resolver import ClassStorageTaskResolver
from flytekit.core.condition import ConditionalSection
from flytekit.core.context_manager import CompilationState, FlyteContext, FlyteContextManager, FlyteEntities
from flytekit.core.context_manager import (
CompilationState,
ExecutionState,
FlyteContext,
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.docstring import Docstring
from flytekit.core.interface import (
Interface,
Expand Down Expand Up @@ -334,6 +340,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr

return create_task_output(new_promises, self.python_interface)

def local_execution_mode(self) -> ExecutionState.Mode:
""" """
return ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION


class ImperativeWorkflow(WorkflowBase):
"""
Expand Down
3 changes: 3 additions & 0 deletions flytekit/remote/remote_callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __call__(self, *args, **kwargs):
def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]:
return self.execute(**kwargs)

def local_execution_mode(self) -> ExecutionState.Mode:
return ExecutionState.Mode.LOCAL_TASK_EXECUTION

def execute(self, **kwargs) -> Any:
raise AssertionError(f"Remotely fetched entities cannot be run locally. Please mock the {self.name}.execute.")

Expand Down

0 comments on commit a56f050

Please sign in to comment.