diff --git a/python/morpheus/morpheus/pipeline/pipeline.py b/python/morpheus/morpheus/pipeline/pipeline.py index 117a57bbe6..6f719e4d54 100644 --- a/python/morpheus/morpheus/pipeline/pipeline.py +++ b/python/morpheus/morpheus/pipeline/pipeline.py @@ -19,16 +19,18 @@ import sys import threading import typing -from collections import OrderedDict, defaultdict +from collections import OrderedDict +from collections import defaultdict from enum import Enum from functools import partial -import morpheus.pipeline as _pipeline # pylint: disable=cyclic-import import mrc import networkx +from tqdm import tqdm + +import morpheus.pipeline as _pipeline # pylint: disable=cyclic-import from morpheus.config import Config from morpheus.utils.type_utils import pretty_print_type_name -from tqdm import tqdm logger = logging.getLogger(__name__) @@ -89,8 +91,6 @@ def __init__(self, config: Config): # Future that allows post_start to propagate exceptions back to pipeline self._post_start_future: asyncio.Future = None - self.__exit_count = 0 - @property def state(self) -> PipelineState: return self._state @@ -302,12 +302,7 @@ def build(self): exec_options.topology.user_cpuset = f"0-{self._num_threads - 1}" exec_options.engine_factories.default_engine_type = mrc.core.options.EngineType.Thread - def state_change_handler(state: mrc.State): - logger.debug("MRC Executor State change: %s", state) - if ((state == mrc.State.Stop and self.__exit_count == 0) or state == mrc.State.Kill): - self._shutdown_handler("MRC Executor stopped. Stopping pipeline... Press Ctrl+C to kill.") - - self._mrc_executor = mrc.Executor(exec_options, state_change_handler) + self._mrc_executor = mrc.Executor(exec_options) mrc_pipeline = mrc.Pipeline() @@ -370,19 +365,32 @@ async def _start(self): self._loop = asyncio.get_running_loop() # Setup error handling and cancellation of the pipeline - def exception_handler(_, context: dict): + def error_handler(_, context: dict): msg = f"Unhandled exception in async loop! Exception: \n{context['message']}" exception = context.get("exception", Exception()) logger.critical(msg, exc_info=exception) - self._loop.set_exception_handler(exception_handler) + self._loop.set_exception_handler(error_handler) + + exit_count = 0 # Handles Ctrl+C for graceful shutdown - shutdown_msg="Stopping pipeline. Please wait... Press Ctrl+C again to kill." + def term_signal(): + + nonlocal exit_count + exit_count = exit_count + 1 + + if (exit_count == 1): + tqdm.write("Stopping pipeline. Please wait... Press Ctrl+C again to kill.") + self.stop() + else: + tqdm.write("Killing") + sys.exit(1) + for sig in [signal.SIGINT, signal.SIGTERM]: - self._loop.add_signal_handler(sig, partial(self._shutdown_handler, shutdown_msg=shutdown_msg)) + self._loop.add_signal_handler(sig, term_signal) logger.info("====Starting Pipeline====") @@ -432,16 +440,6 @@ def stop(self): logger.info("====Pipeline Stopped====") self._on_stop() - def _shutdown_handler(self, shutdown_msg: str): - self.__exit_count += 1 - - if (self.__exit_count == 1): - tqdm.write(shutdown_msg) - self.stop() - else: - tqdm.write("Killing") - sys.exit(1) - async def join(self): """ Wait until pipeline completes upon which join methods of sources and stages will be called.