From f456296f3300e249cc876c638f17dd2a828cd407 Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Fri, 20 Sep 2024 15:55:07 +0200 Subject: [PATCH] Refactor `_ert/forward_model_runner/job.py` run method --- src/_ert/forward_model_runner/job.py | 329 +++++++++++------- .../forward_model_runner/reporting/message.py | 12 +- src/_ert/forward_model_runner/runner.py | 2 - 3 files changed, 204 insertions(+), 139 deletions(-) diff --git a/src/_ert/forward_model_runner/job.py b/src/_ert/forward_model_runner/job.py index 5738d0c337f..9a3b664a13b 100644 --- a/src/_ert/forward_model_runner/job.py +++ b/src/_ert/forward_model_runner/job.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import io import json import logging import os @@ -10,7 +11,7 @@ from datetime import datetime as dt from pathlib import Path from subprocess import Popen, run -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple, cast from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess @@ -22,10 +23,13 @@ Start, ) +if TYPE_CHECKING: + from ert.config.forward_model_step import ForwardModelStepJSON + logger = logging.getLogger(__name__) -def killed_by_oom(pids: set[int]) -> bool: +def killed_by_oom(pids: Sequence[int]) -> bool: """Will try to detect if a process (or any of its descendants) was killed by the Linux OOM-killer. @@ -75,42 +79,48 @@ def killed_by_oom(pids: set[int]) -> bool: class Job: MEMORY_POLL_PERIOD = 5 # Seconds between memory polls - def __init__(self, job_data, index, sleep_interval=1): + def __init__( + self, job_data: ForwardModelStepJSON, index: int, sleep_interval: int = 1 + ) -> None: self.sleep_interval = sleep_interval self.job_data = job_data self.index = index self.std_err = job_data.get("stderr") self.std_out = job_data.get("stdout") - def run(self): + def run(self) -> Generator[Start | Exited | Running | None]: try: for msg in self._run(): yield msg except Exception as e: yield Exited(self, exit_code=1).with_error(str(e)) - def _run(self): + def create_start_message_and_check_job_files(self) -> Start: start_message = Start(self) errors = self._check_job_files() - errors.extend(self._assert_arg_list()) - self._dump_exec_env() if errors: - yield start_message.with_error("\n".join(errors)) - return - - yield start_message - - arg_list = [self.job_data.get("executable")] - if self.job_data.get("argList"): - arg_list += self.job_data["argList"] - - # stdin/stdout/stderr are closed at the end of this function + start_message = start_message.with_error("\n".join(errors)) + return start_message + + def _build_arg_list(self) -> List[str]: + executable = self.job_data.get("executable") + # assert executable is not None + combined_arg_list = [executable] + if arg_list := self.job_data.get("argList"): + combined_arg_list += arg_list + return combined_arg_list + + def _open_file_handles( + self, + ) -> Tuple[ + io.TextIOWrapper | None, io.TextIOWrapper | None, io.TextIOWrapper | None + ]: if self.job_data.get("stdin"): - stdin = open(self.job_data.get("stdin"), encoding="utf-8") # noqa + stdin = open(cast(Path, self.job_data.get("stdin")), encoding="utf-8") # noqa else: stdin = None @@ -132,25 +142,28 @@ def _run(self): else: stdout = None - target_file = self.job_data.get("target_file") - target_file_mtime: int = 0 - if target_file and os.path.exists(target_file): - stat = os.stat(target_file) - target_file_mtime = stat.st_mtime_ns + return (stdin, stdout, stderr) - max_running_minutes = self.job_data.get("max_running_minutes") - run_start_time = dt.now() - environment = self.job_data.get("environment") - if environment is not None: - environment = {**os.environ, **environment} - - def ensure_file_handles_closed(): - if stdin is not None: - stdin.close() - if stdout is not None: - stdout.close() - if stderr is not None: - stderr.close() + def _create_environment(self) -> Optional[Dict[str, str]]: + combined_environment = None + if environment := self.job_data.get("environment"): + combined_environment = {**os.environ, **environment} + return combined_environment + + def _run(self) -> Generator[Start | Exited | Running | None]: + start_message = self.create_start_message_and_check_job_files() + + yield start_message + if not start_message.success(): + return + + arg_list = self._build_arg_list() + + (stdin, stdout, stderr) = self._open_file_handles() + # stdin/stdout/stderr are closed at the end of this function + + target_file = self.job_data.get("target_file") + target_file_mtime: Optional[int] = _get_target_file_ntime(target_file) try: proc = Popen( @@ -158,30 +171,21 @@ def ensure_file_handles_closed(): stdin=stdin, stdout=stdout, stderr=stderr, - env=environment, + env=self._create_environment(), ) process = Process(proc.pid) except OSError as e: - msg = f"{e.strerror} {e.filename}" - if e.strerror == "Exec format error" and e.errno == 8: - msg = ( - f"Missing execution format information in file: {e.filename!r}." - f"Most likely you are missing and should add " - f"'#!/usr/bin/env python' to the top of the file: " - ) - if stderr: - stderr.write(msg) - ensure_file_handles_closed() - yield Exited(self, e.errno).with_error(msg) + exited_message = self._handle_process_io_error_and_create_exited_message( + e, stderr + ) + yield exited_message + ensure_file_handles_closed([stdin, stdout, stderr]) return exit_code = None - # All child pids for the forward model step. Need to track these in order to be able - # to detect OOM kills in case of failure. - fm_step_pids = {process.pid} - max_memory_usage = 0 + fm_step_pids = {int(process.pid)} while exit_code is None: (memory_rss, cpu_seconds, oom_score) = _get_processtree_data(process) max_memory_usage = max(memory_rss, max_memory_usage) @@ -200,106 +204,128 @@ def ensure_file_handles_closed(): try: exit_code = process.wait(timeout=self.MEMORY_POLL_PERIOD) except TimeoutExpired: + potential_exited_msg = ( + self.handle_process_timeout_and_create_exited_msg(exit_code, proc) + ) + if isinstance(potential_exited_msg, Exited): + yield potential_exited_msg + + return fm_step_pids |= { int(child.pid) for child in process.children(recursive=True) } - run_time = dt.now() - run_start_time - if ( - max_running_minutes is not None - and run_time.seconds > max_running_minutes * 60 - ): - # If the spawned process is not in the same process group as - # the callee (job_dispatch), we will kill the process group - # explicitly. - # - # Propagating the unsuccessful Exited message will kill the - # callee group. See job_dispatch.py. - process_group_id = os.getpgid(proc.pid) - this_group_id = os.getpgid(os.getpid()) - if process_group_id != this_group_id: - os.killpg(process_group_id, signal.SIGKILL) - - yield Exited(self, exit_code).with_error( - ( - f"Job:{self.name()} has been running " - f"for more than {max_running_minutes} " - "minutes - explicitly killed." - ) - ) - return - exited_message = Exited(self, exit_code) + ensure_file_handles_closed([stdin, stdout, stderr]) + exited_message = self._create_exited_message_based_on_exit_code( + max_memory_usage, target_file_mtime, exit_code, fm_step_pids + ) + yield exited_message + def _create_exited_message_based_on_exit_code( + self, + max_memory_usage: int, + target_file_mtime: Optional[int], + exit_code: int, + fm_step_pids: Sequence[int], + ) -> Exited: if exit_code != 0: - if killed_by_oom(fm_step_pids): - yield exited_message.with_error( - f"Forward model step {self.job_data.get('name')} " - "was killed due to out-of-memory. " - "Max memory usage recorded by Ert for the " - f"realization was {max_memory_usage//1024//1024} MB" - ) - else: - yield exited_message.with_error( - f"Process exited with status code {exit_code}" - ) - return - - # exit_code is 0 + exited_message = self._create_exited_msg_for_non_zero_exit_code( + max_memory_usage, exit_code, fm_step_pids + ) + return exited_message + exited_message = Exited(self, exit_code) if self.job_data.get("error_file") and os.path.exists( self.job_data["error_file"] ): - yield exited_message.with_error( + return exited_message.with_error( f'Found the error file:{self.job_data["error_file"]} - job failed.' ) - return - if target_file: + if target_file_mtime: target_file_error = self._check_target_file_is_written(target_file_mtime) if target_file_error: - yield exited_message.with_error(target_file_error) - return - ensure_file_handles_closed() - yield exited_message + return exited_message.with_error(target_file_error) - def _assert_arg_list(self): - errors = [] - if "arg_types" in self.job_data: - arg_types = self.job_data["arg_types"] - arg_list = self.job_data.get("argList") - for index, arg_type in enumerate(arg_types): - if arg_type == "RUNTIME_FILE": - file_path = os.path.join(os.getcwd(), arg_list[index]) - if not os.path.isfile(file_path): - errors.append( - f"In job {self.name()}: RUNTIME_FILE {arg_list[index]} " - "does not exist." - ) - if arg_type == "RUNTIME_INT": - try: - int(arg_list[index]) - except ValueError: - errors.append( - ( - f"In job {self.name()}: argument with index {index} " - "is of incorrect type, should be integer." - ) - ) - return errors + return exited_message + + def _create_exited_msg_for_non_zero_exit_code( + self, + max_memory_usage: int, + exit_code: int, + fm_step_pids: Sequence[int], + ) -> Exited: + # All child pids for the forward model step. Need to track these in order to be able + # to detect OOM kills in case of failure. + exited_message = Exited(self, exit_code) - def name(self): + if killed_by_oom(fm_step_pids): + return exited_message.with_error( + f"Forward model step {self.job_data.get('name')} " + "was killed due to out-of-memory. " + "Max memory usage recorded by Ert for the " + f"realization was {max_memory_usage//1024//1024} MB" + ) + return exited_message.with_error( + f"Process exited with status code {exited_message.exit_code}" + ) + + def handle_process_timeout_and_create_exited_msg( + self, exit_code: Optional[int], proc: Popen[Process] + ) -> Exited | None: + max_running_minutes = self.job_data.get("max_running_minutes") + run_start_time = dt.now() + + run_time = dt.now() - run_start_time + if max_running_minutes is None or run_time.seconds > max_running_minutes * 60: + return None + + # If the spawned process is not in the same process group as + # the callee (job_dispatch), we will kill the process group + # explicitly. + # + # Propagating the unsuccessful Exited message will kill the + # callee group. See job_dispatch.py. + process_group_id = os.getpgid(proc.pid) + this_group_id = os.getpgid(os.getpid()) + if process_group_id != this_group_id: + os.killpg(process_group_id, signal.SIGKILL) + + return Exited(self, exit_code).with_error( + ( + f"Job:{self.name()} has been running " + f"for more than {max_running_minutes} " + "minutes - explicitly killed." + ) + ) + + def _handle_process_io_error_and_create_exited_message( + self, e: OSError, stderr: io.TextIOWrapper | None + ) -> Exited: + msg = f"{e.strerror} {e.filename}" + if e.strerror == "Exec format error" and e.errno == 8: + msg = ( + f"Missing execution format information in file: {e.filename!r}." + f"Most likely you are missing and should add " + f"'#!/usr/bin/env python' to the top of the file: " + ) + if stderr: + stderr.write(msg) + return Exited(self, e.errno).with_error(msg) + + def name(self) -> str: return self.job_data["name"] - def _dump_exec_env(self): + def _dump_exec_env(self) -> None: exec_env = self.job_data.get("exec_env") if exec_env: exec_name, _ = os.path.splitext( - os.path.basename(self.job_data.get("executable")) + os.path.basename(cast(Path, self.job_data.get("executable"))) ) with open(f"{exec_name}_exec_env.json", "w", encoding="utf-8") as f_handle: f_handle.write(json.dumps(exec_env, indent=4)) - def _check_job_files(self): + def _check_job_files(self) -> list[str]: """ Returns the empty list if no failed checks, or a list of errors in case of failed checks. @@ -309,21 +335,23 @@ def _check_job_files(self): errors.append(f'Could not locate stdin file: {self.job_data["stdin"]}') if self.job_data.get("start_file") and not os.path.exists( - self.job_data["start_file"] + cast(Path, self.job_data["start_file"]) ): errors.append(f'Could not locate start_file:{self.job_data["start_file"]}') if self.job_data.get("error_file") and os.path.exists( - self.job_data.get("error_file") + cast(Path, self.job_data.get("error_file")) ): - os.unlink(self.job_data.get("error_file")) + os.unlink(cast(Path, self.job_data.get("error_file"))) if executable_error := check_executable(self.job_data.get("executable")): - errors.append(str(executable_error)) + errors.append(executable_error) return errors - def _check_target_file_is_written(self, target_file_mtime: int, timeout=5): + def _check_target_file_is_written( + self, target_file_mtime: int, timeout: int = 5 + ) -> None | str: """ Check whether or not a target_file eventually appear. Returns None in case of success, an error message in the case of failure. @@ -356,6 +384,45 @@ def _check_target_file_is_written(self, target_file_mtime: int, timeout=5): ) return f"Could not find target_file:{target_file}" + def _assert_arg_list(self): + errors = [] + if "arg_types" in self.job_data: + arg_types = self.job_data["arg_types"] + arg_list = self.job_data.get("argList") + for index, arg_type in enumerate(arg_types): + if arg_type == "RUNTIME_FILE": + file_path = os.path.join(os.getcwd(), arg_list[index]) + if not os.path.isfile(file_path): + errors.append( + f"In job {self.name()}: RUNTIME_FILE {arg_list[index]} " + "does not exist." + ) + if arg_type == "RUNTIME_INT": + try: + int(arg_list[index]) + except ValueError: + errors.append( + ( + f"In job {self.name()}: argument with index {index} " + "is of incorrect type, should be integer." + ) + ) + return errors + + +def _get_target_file_ntime(file: Optional[str]) -> Optional[int]: + mtime = None + if file and os.path.exists(file): + stat = os.stat(file) + mtime = stat.st_mtime_ns + return mtime + + +def ensure_file_handles_closed(file_handles: Sequence[io.TextIOWrapper | None]) -> None: + for file_handle in file_handles: + if file_handle is not None: + file_handle.close() + def _get_processtree_data( process: Process, diff --git a/src/_ert/forward_model_runner/reporting/message.py b/src/_ert/forward_model_runner/reporting/message.py index ace3862a02b..2811488da29 100644 --- a/src/_ert/forward_model_runner/reporting/message.py +++ b/src/_ert/forward_model_runner/reporting/message.py @@ -71,17 +71,17 @@ def __repr__(cls): class Message(metaclass=_MetaMessage): def __init__(self, job=None): self.timestamp = dt.now() - self.job = job - self.error_message = None + self.job: Optional[Job] = job + self.error_message: Optional[str] = None def __repr__(self): return type(self).__name__ - def with_error(self, message): + def with_error(self, message: str): self.error_message = message return self - def success(self): + def success(self) -> bool: return self.error_message is None @@ -116,7 +116,7 @@ def __init__(self): class Start(Message): - def __init__(self, job): + def __init__(self, job: "Job"): super().__init__(job) @@ -127,7 +127,7 @@ def __init__(self, job: "Job", memory_status: ProcessTreeStatus): class Exited(Message): - def __init__(self, job, exit_code): + def __init__(self, job, exit_code: int): super().__init__(job) self.exit_code = exit_code diff --git a/src/_ert/forward_model_runner/runner.py b/src/_ert/forward_model_runner/runner.py index 4b6b1e948aa..0892f4c02cf 100644 --- a/src/_ert/forward_model_runner/runner.py +++ b/src/_ert/forward_model_runner/runner.py @@ -18,7 +18,6 @@ def __init__(self, jobs_data): self.ert_pid = jobs_data.get("ert_pid") self.global_environment = jobs_data.get("global_environment") job_data_list = jobs_data["jobList"] - if self.simulation_id is not None: os.environ["ERT_RUN_ID"] = self.simulation_id @@ -79,7 +78,6 @@ def run(self, names_of_jobs_to_run): for job in job_queue: for status_update in job.run(): yield status_update - if not status_update.success(): yield Checksum(checksum_dict={}, run_path=os.getcwd()) yield Finish().with_error("Not all jobs completed successfully.")