From aec037f1dafaed2d46ca72148caf0c0fd242351f Mon Sep 17 00:00:00 2001 From: Thorrester Date: Tue, 29 Oct 2024 16:07:39 -0400 Subject: [PATCH] Update _run_manager.py --- opsml/projects/_run_manager.py | 53 ++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/opsml/projects/_run_manager.py b/opsml/projects/_run_manager.py index b20e3b8e4..141c0e767 100644 --- a/opsml/projects/_run_manager.py +++ b/opsml/projects/_run_manager.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import concurrent import subprocess import tempfile import threading @@ -25,6 +26,16 @@ logger = ArtifactLogger.get_logger() +class DaemonThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor): + def _adjust_thread_count(self) -> None: + if len(self._threads) < self._max_workers: + thread_name = f"ThreadPoolExecutor-{len(self._threads)}" + _thread = threading.Thread(name=thread_name, target=self._worker) # type: ignore + _thread.daemon = True + _thread.start() + self._threads.add(_thread) + + def put_hw_metrics( interval: int, run: "ActiveRun", @@ -108,7 +119,7 @@ def __init__(self, project_info: ProjectInfo, registries: CardRegistries): self._project_info = project_info self.active_run: Optional[ActiveRun] = None self.registries = registries - self._hardware_futures: List[threading.Thread] = [] + self._hardware_futures: List[Any] = [] run_id = project_info.run_id if run_id is not None: @@ -120,6 +131,16 @@ def __init__(self, project_info: ProjectInfo, registries: CardRegistries): self.run_id = None self._run_exists = False + self._thread_executor: Optional[DaemonThreadPoolExecutor] = None + + @property + def thread_executor(self) -> Optional[concurrent.futures.ThreadPoolExecutor]: + return self._thread_executor + + @thread_executor.setter + def thread_executor(self, value: concurrent.futures.ThreadPoolExecutor) -> None: + self._thread_executor = value + @property def project_id(self) -> int: assert self._project_info.project_id is not None, "project_id should not be None" @@ -200,24 +221,17 @@ def _log_hardware_metrics(self, interval: int) -> None: # run hardware logger in background thread queue: Queue[Dict[str, Union[str, datetime, Dict[str, Any]]]] = Queue() + self.thread_executor = DaemonThreadPoolExecutor(max_workers=2) + assert self.thread_executor is not None, "thread_executor should not be None" + # submit futures for hardware logging self._hardware_futures.append( - threading.Thread( - target=put_hw_metrics, - kwargs={"interval": interval, "run": self.active_run, "queue": queue}, - ) + self.thread_executor.submit(put_hw_metrics, interval, self.active_run, queue), ) self._hardware_futures.append( - threading.Thread( - target=get_hw_metrics, - kwargs={"run": self.active_run, "queue": queue}, - ) + self.thread_executor.submit(get_hw_metrics, self.active_run, queue), ) - for future in self._hardware_futures: - future.daemon = True - future.start() - def _extract_code( self, filename: Path, @@ -350,8 +364,15 @@ def end_run(self) -> None: self._run_exists = False # check if thread executor is still running - if self._hardware_futures: + if self.thread_executor is not None: + # cancel futures for future in self._hardware_futures: - future.join() - + future.cancel() + try: + future.result(timeout=1) + except Exception: # pylint: disable=broad-except + pass + + self.thread_executor.shutdown(wait=False) + self.thread_executor = None self._hardware_futures = []