diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index a5c9a23d9a1..e98d5d30e9b 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -337,6 +337,10 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str] """ return None + def local_execution_mode(self) -> ExecutionState.Mode: + """ """ + return ExecutionState.Mode.LOCAL_TASK_EXECUTION + def sandbox_execute( self, ctx: FlyteContext, @@ -602,7 +606,7 @@ def dispatch_execute( for k, v in native_outputs_as_map.items(): output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v))) - if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + if ctx.execution_state and ctx.execution_state.is_local_execution(): # When we run the workflow remotely, flytekit outputs decks at the end of _dispatch_execute _output_deck(self.name.split(".")[-1], new_user_params) diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index 76553db702e..37c4afc88f1 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -4,7 +4,7 @@ import typing from typing import Optional, Tuple, Union, cast -from flytekit.core.context_manager import ExecutionState, FlyteContextManager +from flytekit.core.context_manager import FlyteContextManager from flytekit.core.node import Node from flytekit.core.promise import ( ComparisonExpression, @@ -488,7 +488,7 @@ def conditional(name: str) -> ConditionalSection: if ctx.compilation_state: return ConditionalSection(name) elif ctx.execution_state: - if ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + if ctx.execution_state.is_local_execution(): # In case of Local workflow execution, we will actually evaluate the expression and based on the result # make the branch to be active using `take_branch` method from flytekit.core.context_manager import BranchEvalMode diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 2f40a8aa497..aa6b0e3e4d8 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -541,6 +541,12 @@ def with_params( user_space_params=user_space_params if user_space_params else self.user_space_params, ) + def is_local_execution(self): + return ( + self.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION + or self.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION + ) + @dataclass(frozen=True) class FlyteContext(object): @@ -690,7 +696,7 @@ def enter_conditional_section(self) -> FlyteContext.Builder: self.compilation_state = self.compilation_state.with_params(prefix=self.compilation_state.prefix) if self.execution_state: - if self.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + if self.execution_state.is_local_execution(): if self.in_a_condition: if self.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED: self.execution_state = self.execution_state.with_params() diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py index bc3ab1d3fd1..a09a8d82ceb 100644 --- a/flytekit/core/gate.py +++ b/flytekit/core/gate.py @@ -7,7 +7,7 @@ import click from flytekit.core import interface as flyte_interface -from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.promise import Promise, VoidPromise, flyte_entity_call_handler from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.user import FlyteDisapprovalException @@ -116,6 +116,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr else: raise FlyteDisapprovalException(f"User did not approve the transaction for gate node {self.name}") + def local_execution_mode(self): + return ExecutionState.Mode.LOCAL_TASK_EXECUTION + def wait_for_input(name: str, timeout: datetime.timedelta, expected_type: typing.Type): """Create a Gate object that waits for user input of the specified type. diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 522bd3f82c5..5a544bc316b 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -217,7 +217,7 @@ def _outputs_interface(self) -> Dict[Any, Variable]: """ ctx = FlyteContextManager.current_context() - if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + if ctx.execution_state and ctx.execution_state.is_local_execution(): # In workflow execution mode we actually need to use the parent (mapper) task output interface. return self.interface.outputs return self._run_task.interface.outputs @@ -230,7 +230,7 @@ def get_type_for_output_var(self, k: str, v: Any) -> type: from these individual outputs as the final output value. """ ctx = FlyteContextManager.current_context() - if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + if ctx.execution_state and ctx.execution_state.is_local_execution(): # In workflow execution mode we actually need to use the parent (mapper) task output interface. return self._python_interface.outputs[k] return self._run_task._python_interface.outputs[k] diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 62065f68696..c2de88599e7 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Union from flytekit.core.base_task import PythonTask -from flytekit.core.context_manager import BranchEvalMode, ExecutionState, FlyteContext +from flytekit.core.context_manager import BranchEvalMode, FlyteContext from flytekit.core.launch_plan import LaunchPlan from flytekit.core.node import Node from flytekit.core.promise import VoidPromise @@ -144,10 +144,7 @@ def sub_wf(): # Handling local execution # Note: execution state is set to TASK_EXECUTION when running dynamic task locally # https://github.com/flyteorg/flytekit/blob/0815345faf0fae5dc26746a43d4bda4cc2cdf830/flytekit/core/python_function_task.py#L262 - elif ctx.execution_state is not None and ( - ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION - or ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION - ): + elif ctx.execution_state and ctx.execution_state.is_local_execution(): if isinstance(entity, RemoteEntity): raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 5dfd1a6b408..35b72b5a560 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -963,6 +963,9 @@ class LocallyExecutable(Protocol): def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: ... + def local_execution_mode(self) -> ExecutionState.Mode: + ... + def flyte_entity_call_handler( entity: SupportsNodeCreation, *args, **kwargs @@ -996,27 +999,38 @@ def flyte_entity_call_handler( ) ctx = FlyteContextManager.current_context() + if ctx.execution_state and ( + ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION + or ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION + ): + logger.error("You are not supposed to nest @Task/@Workflow inside a @Task!") if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: return create_and_link_node(ctx, entity=entity, **kwargs) - elif ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: - if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED: + if ctx.execution_state and ctx.execution_state.is_local_execution(): + mode = cast(LocallyExecutable, entity).local_execution_mode() + with FlyteContextManager.with_context( + ctx.with_execution_state(ctx.execution_state.with_params(mode=mode)) + ) as child_ctx: if ( - len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0 - or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0 + child_ctx.execution_state + and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED ): - output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys()) - if len(output_names) == 0: - return VoidPromise(entity.name) - vals = [Promise(var, None) for var in output_names] - return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface) - else: - return None - return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) + if ( + len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0 + or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0 + ): + output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys()) + if len(output_names) == 0: + return VoidPromise(entity.name) + vals = [Promise(var, None) for var in output_names] + return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface) + else: + return None + return cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) else: + mode = cast(LocallyExecutable, entity).local_execution_mode() with FlyteContextManager.with_context( - ctx.with_execution_state( - ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION) - ) + ctx.with_execution_state(ctx.new_execution_state().with_params(mode=mode)) ) as child_ctx: cast(ExecutionParameters, child_ctx.user_space_params)._decks = [] result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 90b10cbc365..f1318941fa1 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -258,7 +258,7 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: representing that newly generated workflow, instead of executing it. """ ctx = FlyteContextManager.current_context() - if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + if ctx.execution_state and ctx.execution_state.is_local_execution(): # The rest of this function mimics the local_execute of the workflow. We can't use the workflow # local_execute directly though since that converts inputs into Promises. logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index de386fa1595..0d861db513c 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -182,6 +182,9 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Pro vals = [Promise(var, outputs_literals[var]) for var in output_names] return create_task_output(vals, self.python_interface) + def local_execution_mode(self): + return ExecutionState.Mode.LOCAL_TASK_EXECUTION + def construct_node_metadata(self) -> _workflow_model.NodeMetadata: return _workflow_model.NodeMetadata(name=extract_obj_name(self.name)) @@ -207,9 +210,7 @@ def __call__(self, *args, **kwargs): ctx = FlyteContext.current_context() if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: return self.compile(ctx, *args, **kwargs) - elif ( - ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION - ): + elif ctx.execution_state and ctx.execution_state.is_local_execution(): if ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED: return return self.local_execute(ctx, **kwargs) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 5689b4efbf1..7ced5940fc6 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -10,7 +10,13 @@ from flytekit.core.base_task import PythonTask from flytekit.core.class_based_resolver import ClassStorageTaskResolver from flytekit.core.condition import ConditionalSection -from flytekit.core.context_manager import CompilationState, FlyteContext, FlyteContextManager, FlyteEntities +from flytekit.core.context_manager import ( + CompilationState, + ExecutionState, + FlyteContext, + FlyteContextManager, + FlyteEntities, +) from flytekit.core.docstring import Docstring from flytekit.core.interface import ( Interface, @@ -334,6 +340,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr return create_task_output(new_promises, self.python_interface) + def local_execution_mode(self) -> ExecutionState.Mode: + """ """ + return ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION + class ImperativeWorkflow(WorkflowBase): """ diff --git a/flytekit/remote/remote_callable.py b/flytekit/remote/remote_callable.py index 9adfd4846fd..967a1ed49f9 100644 --- a/flytekit/remote/remote_callable.py +++ b/flytekit/remote/remote_callable.py @@ -65,6 +65,9 @@ def __call__(self, *args, **kwargs): def local_execute(self, ctx: FlyteContext, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: return self.execute(**kwargs) + def local_execution_mode(self) -> ExecutionState.Mode: + return ExecutionState.Mode.LOCAL_TASK_EXECUTION + def execute(self, **kwargs) -> Any: raise AssertionError(f"Remotely fetched entities cannot be run locally. Please mock the {self.name}.execute.")