Skip to content

Commit

Permalink
Update _run_manager.py
Browse files Browse the repository at this point in the history
  • Loading branch information
thorrester committed Oct 29, 2024
1 parent 23af934 commit aec037f
Showing 1 changed file with 37 additions and 16 deletions.
53 changes: 37 additions & 16 deletions opsml/projects/_run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import concurrent
import concurrent.futures
import subprocess
import tempfile
import threading
Expand All @@ -25,6 +26,16 @@
logger = ArtifactLogger.get_logger()


class DaemonThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """Thread pool executor whose worker threads are daemon threads.

    Workers started here are flagged ``daemon=True`` at construction, so a
    task that never returns does not keep the interpreter alive at exit.

    NOTE: this relies on private CPython internals of ``ThreadPoolExecutor``
    (``_threads``, ``_max_workers``, ``_work_queue`` and the queued work
    items). The stdlib keeps its worker loop in a module-level function, not
    in a ``_worker`` method, so this class supplies its own loop.
    """

    def _adjust_thread_count(self) -> None:
        """Start one more daemon worker while the pool is below ``max_workers``."""
        if len(self._threads) >= self._max_workers:
            return

        thread_name = f"ThreadPoolExecutor-{len(self._threads)}"
        _thread = threading.Thread(name=thread_name, target=self._worker, daemon=True)
        _thread.start()
        self._threads.add(_thread)

    def _worker(self) -> None:
        """Run queued work items until the shutdown sentinel (``None``) arrives."""
        # NOTE(review): newer CPython releases (3.14+, to be confirmed) run
        # work items through a per-thread worker context; older releases call
        # ``work_item.run()`` with no arguments. Support both shapes.
        make_context = getattr(self, "_create_worker_context", None)
        context = make_context() if make_context is not None else None

        try:
            if context is not None:
                context.initialize()
            elif getattr(self, "_initializer", None) is not None:
                self._initializer(*self._initargs)
        except BaseException:  # pylint: disable=broad-except
            # Mark the pool as broken so that pending futures fail instead of
            # waiting forever on a worker that never started.
            self._initializer_failed()
            return

        try:
            while True:
                work_item = self._work_queue.get()
                if work_item is None:
                    # shutdown() enqueues a single sentinel; pass it on so
                    # that sibling workers also wake up and exit.
                    self._work_queue.put(None)
                    return

                # The work item stores any exception raised by the task on
                # its future, so ``run`` itself does not raise task errors.
                if context is None:
                    work_item.run()
                else:
                    work_item.run(context)

                # Drop the reference so the finished task can be collected
                # while this thread blocks on the queue.
                del work_item
        finally:
            if context is not None:
                context.finalize()


def put_hw_metrics(
interval: int,
run: "ActiveRun",
Expand Down Expand Up @@ -108,7 +119,7 @@ def __init__(self, project_info: ProjectInfo, registries: CardRegistries):
self._project_info = project_info
self.active_run: Optional[ActiveRun] = None
self.registries = registries
self._hardware_futures: List[threading.Thread] = []
self._hardware_futures: List[Any] = []

run_id = project_info.run_id
if run_id is not None:
Expand All @@ -120,6 +131,16 @@ def __init__(self, project_info: ProjectInfo, registries: CardRegistries):
self.run_id = None
self._run_exists = False

self._thread_executor: Optional[DaemonThreadPoolExecutor] = None

@property
def thread_executor(self) -> Optional[concurrent.futures.ThreadPoolExecutor]:
    """Executor currently stored on this instance, or ``None`` if unset."""
    return self._thread_executor

@thread_executor.setter
def thread_executor(self, value: Optional[concurrent.futures.ThreadPoolExecutor]) -> None:
    """Store ``value`` as the executor; ``None`` clears it."""
    # ``None`` is a legitimate value: the run teardown resets the executor
    # by assigning ``None`` through this setter.
    self._thread_executor = value

@property
def project_id(self) -> int:
assert self._project_info.project_id is not None, "project_id should not be None"
Expand Down Expand Up @@ -200,24 +221,17 @@ def _log_hardware_metrics(self, interval: int) -> None:

# run hardware logger in background thread
queue: Queue[Dict[str, Union[str, datetime, Dict[str, Any]]]] = Queue()
self.thread_executor = DaemonThreadPoolExecutor(max_workers=2)
assert self.thread_executor is not None, "thread_executor should not be None"

# submit futures for hardware logging
self._hardware_futures.append(
threading.Thread(
target=put_hw_metrics,
kwargs={"interval": interval, "run": self.active_run, "queue": queue},
)
self.thread_executor.submit(put_hw_metrics, interval, self.active_run, queue),
)
self._hardware_futures.append(
threading.Thread(
target=get_hw_metrics,
kwargs={"run": self.active_run, "queue": queue},
)
self.thread_executor.submit(get_hw_metrics, self.active_run, queue),
)

for future in self._hardware_futures:
future.daemon = True
future.start()

def _extract_code(
self,
filename: Path,
Expand Down Expand Up @@ -350,8 +364,15 @@ def end_run(self) -> None:
self._run_exists = False

# check if thread executor is still running
if self._hardware_futures:
if self.thread_executor is not None:
# cancel futures
for future in self._hardware_futures:
future.join()

future.cancel()
try:
future.result(timeout=1)
except Exception: # pylint: disable=broad-except
pass

self.thread_executor.shutdown(wait=False)
self.thread_executor = None
self._hardware_futures = []

0 comments on commit aec037f

Please sign in to comment.