diff --git a/parsl/config.py b/parsl/config.py index ecea149114..c3725eccf8 100644 --- a/parsl/config.py +++ b/parsl/config.py @@ -40,6 +40,15 @@ class Config(RepresentationMixin, UsageInformation): ``checkpoint_mode='periodic'``. dependency_resolver: plugin point for custom dependency resolvers. Default: only resolve Futures, using the `SHALLOW_DEPENDENCY_RESOLVER`. + exit_mode: str, optional + When Parsl is used as a context manager (using ``with parsl.load`` syntax) then this parameter + controls what will happen to running tasks and exceptions at exit. The options are: + + * ``cleanup``: cleanup the DFK on exit without waiting for any tasks + * ``skip``: skip all shutdown behaviour when exiting the context manager + * ``wait``: wait for all tasks to complete when exiting normally, but exit immediately when exiting due to an exception. + + Default is ``cleanup``. garbage_collect : bool. optional. Delete task records from DFK when tasks have completed. Default: True internal_tasks_max_threads : int, optional @@ -97,6 +106,7 @@ def __init__(self, Literal['manual']] = None, checkpoint_period: Optional[str] = None, dependency_resolver: Optional[DependencyResolver] = None, + exit_mode: Literal['cleanup', 'skip', 'wait'] = 'cleanup', garbage_collect: bool = True, internal_tasks_max_threads: int = 10, retries: int = 0, @@ -133,6 +143,7 @@ def __init__(self, checkpoint_period = "00:30:00" self.checkpoint_period = checkpoint_period self.dependency_resolver = dependency_resolver + self.exit_mode = exit_mode self.garbage_collect = garbage_collect self.internal_tasks_max_threads = internal_tasks_max_threads self.retries = retries diff --git a/parsl/dataflow/dflow.py b/parsl/dataflow/dflow.py index dffa7e52fd..86b429f3a0 100644 --- a/parsl/dataflow/dflow.py +++ b/parsl/dataflow/dflow.py @@ -217,9 +217,24 @@ def __init__(self, config: Config) -> None: def __enter__(self): return self - def __exit__(self, exc_type, exc_value, traceback): - logger.debug("Exiting the context manager, calling cleanup for DFK") - self.cleanup() + def __exit__(self, exc_type, exc_value, traceback) -> None: + mode = self.config.exit_mode + logger.debug("Exiting context manager, with exit mode '%s'", mode) + if mode == "cleanup": + logger.info("Calling cleanup for DFK") + self.cleanup() + elif mode == "skip": + logger.info("Skipping all cleanup handling") + elif mode == "wait": + if exc_type is None: + logger.info("Waiting for all tasks to complete") + self.wait_for_current_tasks() + self.cleanup() + else: + logger.info("There was an exception - cleaning up without waiting for task completion") + self.cleanup() + else: + raise InternalConsistencyError(f"Exit case for {mode} should be unreachable, validated by typeguard on Config()") def _send_task_log_info(self, task_record: TaskRecord) -> None: if self.monitoring: diff --git a/parsl/tests/test_python_apps/test_context_manager.py b/parsl/tests/test_python_apps/test_context_manager.py index a314c0d362..6d3b020b16 100644 --- a/parsl/tests/test_python_apps/test_context_manager.py +++ b/parsl/tests/test_python_apps/test_context_manager.py @@ -1,7 +1,11 @@ +from concurrent.futures import Future +from threading import Event + import pytest import parsl -from parsl.dataflow.dflow import DataFlowKernel +from parsl.config import Config +from parsl.dataflow.dflow import DataFlowKernel, DataFlowKernelLoader from parsl.errors import NoDataFlowKernelError from parsl.tests.configs.local_threads import fresh_config @@ -16,6 +20,16 @@ def foo(x, stdout='foo.stdout'): return f"echo {x + 1}" +@parsl.python_app +def wait_for_event(ev: Event): + ev.wait() + + +@parsl.python_app +def raise_app(): + raise RuntimeError("raise_app deliberate failure") + + @pytest.mark.local def test_within_context_manger(tmpd_cwd): config = fresh_config() @@ -31,3 +45,84 @@ def test_within_context_manger(tmpd_cwd): with pytest.raises(NoDataFlowKernelError) as excinfo: square(2).result() assert str(excinfo.value) == "Must first load config" + + +@pytest.mark.local +def test_exit_skip(): + config = fresh_config() + config.exit_mode = "skip" + + with parsl.load(config) as dfk: + ev = Event() + fut = wait_for_event(ev) + # deliberately don't wait for this to finish, so that the context + # manager can exit + + assert parsl.dfk() is dfk, "global dfk should be left in place by skip mode" + + assert not fut.done(), "wait_for_event should not be done yet" + ev.set() + + # now we can wait for that result... + fut.result() + assert fut.done(), "wait_for_event should complete outside of context manager in 'skip' mode" + + # now cleanup the DFK that the above `with` block + # deliberately avoided doing... + dfk.cleanup() + + +# 'wait' mode has two cases to test: +# 1. that we wait when there is no exception +# 2. that we do not wait when there is an exception +@pytest.mark.local +def test_exit_wait_no_exception(): + config = fresh_config() + config.exit_mode = "wait" + + with parsl.load(config) as dfk: + fut = square(1) + # deliberately don't wait for this to finish, so that the context + # manager can exit + + assert fut.done(), "This future should be marked as done before the context manager exits" + + assert dfk.cleanup_called, "The DFK should have been cleaned up by the context manager" + assert DataFlowKernelLoader._dfk is None, "The global DFK should have been removed" + + +@pytest.mark.local +def test_exit_wait_exception(): + config = fresh_config() + config.exit_mode = "wait" + + with pytest.raises(RuntimeError): + with parsl.load(config) as dfk: + # we'll never fire this future + fut_never = Future() + + fut_raise = raise_app() + + fut_depend = square(fut_never) + + # this should cause an exception, which should cause the context + # manager to exit, without waiting for fut_depend to finish. + fut_raise.result() + + assert dfk.cleanup_called, "The DFK should have been cleaned up by the context manager" + assert DataFlowKernelLoader._dfk is None, "The global DFK should have been removed" + assert fut_raise.exception() is not None, "fut_raise should contain an exception" + assert not fut_depend.done(), "fut_depend should have been left un-done (due to dependency failure)" + + +@pytest.mark.local +def test_exit_wrong_mode(): + + with pytest.raises(Exception) as ex: + Config(exit_mode="wrongmode") + + # with typeguard 4.x this is TypeCheckError, + # with typeguard 2.x this is TypeError + # we can't instantiate TypeCheckError if we're in typeguard 2.x environment + # because it does not exist... so check name using strings. + assert ex.type.__name__ == "TypeCheckError" or ex.type.__name__ == "TypeError"