Skip to content

Commit

Permalink
Refactor forward_model_runner/job.py run method
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Sep 25, 2024
1 parent c331fe9 commit 1aac6e7
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 98 deletions.
254 changes: 160 additions & 94 deletions src/_ert/forward_model_runner/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import io
import json
import logging
import os
Expand All @@ -10,7 +11,7 @@
from datetime import datetime as dt
from pathlib import Path
from subprocess import Popen, run
from typing import Optional, Tuple
from typing import Dict, Generator, Optional, Sequence, Tuple

from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess

Expand Down Expand Up @@ -75,21 +76,21 @@ def killed_by_oom(pids: set[int]) -> bool:
class Job:
MEMORY_POLL_PERIOD = 5 # Seconds between memory polls

def __init__(self, job_data, index, sleep_interval=1):
def __init__(self, job_data, index, sleep_interval=1) -> None:
self.sleep_interval = sleep_interval
self.job_data = job_data
self.job_data: Dict[str, str] = job_data
self.index = index
self.std_err = job_data.get("stderr")
self.std_out = job_data.get("stdout")

def run(self) -> Generator[Start | Exited | Running, None, None]:
    """Run the forward model step, yielding status messages.

    Delegates to _run() and converts any unexpected exception into a
    failed Exited message so consumers never see a raw traceback.

    Yields:
        Start, Running and Exited messages describing the step's progress.
    """
    try:
        # yield from also forwards send()/throw() to the inner generator,
        # and Generator needs explicit send/return types before Python 3.13.
        yield from self._run()
    except Exception as e:
        # Any unhandled error is reported as exit code 1 instead of raised.
        yield Exited(self, exit_code=1).with_error(str(e))

def _run(self):
def create_start_message_and_check_job_files(self) -> Start:
start_message = Start(self)

errors = self._check_job_files()
Expand All @@ -99,16 +100,22 @@ def _run(self):
self._dump_exec_env()

if errors:
yield start_message.with_error("\n".join(errors))
return

yield start_message

arg_list = [self.job_data.get("executable")]
if self.job_data.get("argList"):
arg_list += self.job_data["argList"]

# stdin/stdout/stderr are closed at the end of this function
start_message = start_message.with_error("\n".join(errors))
return start_message

def _build_arg_list(self):
executable = self.job_data.get("executable")

combined_arg_list = [executable]
if arg_list := self.job_data.get("argList"):
combined_arg_list += arg_list
return combined_arg_list

def _open_file_handles(
self,
) -> Tuple[
io.TextIOWrapper | None, io.TextIOWrapper | None, io.TextIOWrapper | None
]:
if self.job_data.get("stdin"):
stdin = open(self.job_data.get("stdin"), encoding="utf-8") # noqa
else:
Expand All @@ -132,56 +139,50 @@ def _run(self):
else:
stdout = None

target_file = self.job_data.get("target_file")
target_file_mtime: int = 0
if target_file and os.path.exists(target_file):
stat = os.stat(target_file)
target_file_mtime = stat.st_mtime_ns
return (stdin, stdout, stderr)

max_running_minutes = self.job_data.get("max_running_minutes")
run_start_time = dt.now()
def _create_environment(self) -> Dict:
environment = self.job_data.get("environment")
if environment is not None:
environment = {**os.environ, **environment}
return environment

def ensure_file_handles_closed():
if stdin is not None:
stdin.close()
if stdout is not None:
stdout.close()
if stderr is not None:
stderr.close()
def _run(self) -> contextlib.Generator[Start | Exited | Running]:
start_message = self.create_start_message_and_check_job_files()

yield start_message
if not start_message.success():
return

arg_list = self._build_arg_list()

(stdin, stdout, stderr) = self._open_file_handles()
# stdin/stdout/stderr are closed at the end of this function

target_file = self.job_data.get("target_file")
target_file_mtime: int = _get_target_file_ntime(target_file)

try:
proc = Popen(
arg_list,
stdin=stdin,
stdout=stdout,
stderr=stderr,
env=environment,
env=self._create_environment(),
)
process = Process(proc.pid)
except OSError as e:
msg = f"{e.strerror} {e.filename}"
if e.strerror == "Exec format error" and e.errno == 8:
msg = (
f"Missing execution format information in file: {e.filename!r}."
f"Most likely you are missing and should add "
f"'#!/usr/bin/env python' to the top of the file: "
)
if stderr:
stderr.write(msg)
ensure_file_handles_closed()
yield Exited(self, e.errno).with_error(msg)
exited_message = self._handle_process_io_error_and_create_exited_message(
e, stderr
)
ensure_file_handles_closed([stdin, stdout, stderr])
yield exited_message
return

exit_code = None

# All child pids for the forward model step. Need to track these in order to be able
# to detect OOM kills in case of failure.
fm_step_pids = {process.pid}

max_memory_usage = 0
fm_step_pids = {int(process.pid)}
while exit_code is None:
(memory_rss, cpu_seconds, oom_score) = _get_processtree_data(process)
max_memory_usage = max(memory_rss, max_memory_usage)
Expand All @@ -200,70 +201,121 @@ def ensure_file_handles_closed():
try:
exit_code = process.wait(timeout=self.MEMORY_POLL_PERIOD)
except TimeoutExpired:
exited_msg = self.handle_process_timeout_and_create_exited_msg(
process, proc
)
fm_step_pids |= {
int(child.pid) for child in process.children(recursive=True)
}
run_time = dt.now() - run_start_time
if (
max_running_minutes is not None
and run_time.seconds > max_running_minutes * 60
):
# If the spawned process is not in the same process group as
# the callee (job_dispatch), we will kill the process group
# explicitly.
#
# Propagating the unsuccessful Exited message will kill the
# callee group. See job_dispatch.py.
process_group_id = os.getpgid(proc.pid)
this_group_id = os.getpgid(os.getpid())
if process_group_id != this_group_id:
os.killpg(process_group_id, signal.SIGKILL)

yield Exited(self, exit_code).with_error(
(
f"Job:{self.name()} has been running "
f"for more than {max_running_minutes} "
"minutes - explicitly killed."
)
)
if isinstance(exited_msg, Exited):
yield exited_msg
return

exited_message = Exited(self, exit_code)
ensure_file_handles_closed([stdin, stdout, stderr])
exited_message = self._create_exited_message_based_on_exit_code(
max_memory_usage, target_file_mtime, exit_code, fm_step_pids
)
yield exited_message

def _create_exited_message_based_on_exit_code(
self,
max_memory_usage: int,
target_file_mtime: Optional[int],
exit_code: int,
fm_step_pids: Sequence[int],
) -> Exited:
# exit_code = proc.returncode

if exit_code != 0:
if killed_by_oom(fm_step_pids):
yield exited_message.with_error(
f"Forward model step {self.job_data.get('name')} "
"was killed due to out-of-memory. "
"Max memory usage recorded by Ert for the "
f"realization was {max_memory_usage//1024//1024} MB"
)
else:
yield exited_message.with_error(
f"Process exited with status code {exit_code}"
)
return
exited_message = self._create_exited_msg_for_non_zero_exit_code(
max_memory_usage, exit_code, fm_step_pids
)
return exited_message

# exit_code is 0

if self.job_data.get("error_file") and os.path.exists(
self.job_data["error_file"]
):
yield exited_message.with_error(
exited_message = Exited(self, exit_code)
return exited_message.with_error(
f'Found the error file:{self.job_data["error_file"]} - job failed.'
)
return

if target_file:
if target_file_mtime:
target_file_error = self._check_target_file_is_written(target_file_mtime)
if target_file_error:
yield exited_message.with_error(target_file_error)
return
ensure_file_handles_closed()
yield exited_message
return exited_message.with_error(target_file_error)

def _assert_arg_list(self):
errors = []
return Exited(self, exit_code)

def _create_exited_msg_for_non_zero_exit_code(
    self,
    max_memory_usage: int,
    exit_code: int,
    fm_step_pids: Sequence[int],
) -> Exited:
    """Build the Exited message for a failed step.

    If any of the step's pids were OOM-killed, the error text reports the
    peak memory usage instead of the plain exit-status message.
    """
    exited_message = Exited(self, exit_code)
    if not killed_by_oom(fm_step_pids):
        return exited_message.with_error(
            f"Process exited with status code {exited_message.exit_code}"
        )
    return exited_message.with_error(
        f"Forward model step {self.job_data.get('name')} "
        "was killed due to out-of-memory. "
        "Max memory usage recorded by Ert for the "
        f"realization was {max_memory_usage//1024//1024} MB"
    )

def handle_process_timeout_and_create_exited_msg(
    self, process: Process, proc: Popen, run_start_time: Optional[dt] = None
) -> Exited | None:
    """Enforce the step's max_running_minutes limit.

    Args:
        process: psutil handle for the spawned process (kept for interface
            compatibility; not read here).
        proc: the Popen object whose process group may be killed.
        run_start_time: when the process was actually started. Defaults to
            "now" for backward compatibility, but then run_time is ~0 and
            the limit can never trigger — callers must pass the real start
            time for the timeout to be enforced.

    Returns:
        An Exited message if the limit was exceeded (after killing the
        process group), otherwise None.
    """
    max_running_minutes = self.job_data.get("max_running_minutes")
    if run_start_time is None:
        # BUG(review): this fallback reproduces the refactored code's
        # behavior, where the start time was sampled here instead of
        # before the poll loop, disabling the timeout check entirely.
        run_start_time = dt.now()

    run_time = dt.now() - run_start_time
    # total_seconds() instead of .seconds: the latter wraps at 24h.
    if max_running_minutes is None or run_time.total_seconds() <= max_running_minutes * 60:
        return None

    # If the spawned process is not in the same process group as
    # the callee (job_dispatch), we will kill the process group
    # explicitly.
    #
    # Propagating the unsuccessful Exited message will kill the
    # callee group. See job_dispatch.py.
    process_group_id = os.getpgid(proc.pid)
    this_group_id = os.getpgid(os.getpid())
    if process_group_id != this_group_id:
        os.killpg(process_group_id, signal.SIGKILL)

    return Exited(self, proc.returncode).with_error(
        (
            f"Job:{self.name()} has been running "
            f"for more than {max_running_minutes} "
            "minutes - explicitly killed."
        )
    )

def _handle_process_io_error_and_create_exited_message(
    self, e: OSError, stderr: io.TextIOWrapper | None
) -> Exited:
    """Translate an OSError raised by Popen into an Exited message.

    Recognizes ENOEXEC ("Exec format error", errno 8) and substitutes a
    hint about a missing shebang; the message is also written to the
    step's stderr handle when one is open.
    """
    if e.errno == 8 and e.strerror == "Exec format error":
        msg = (
            f"Missing execution format information in file: {e.filename!r}."
            f"Most likely you are missing and should add "
            f"'#!/usr/bin/env python' to the top of the file: "
        )
    else:
        msg = f"{e.strerror} {e.filename}"
    if stderr is not None:
        stderr.write(msg)
    return Exited(self, e.errno).with_error(msg)

def _assert_arg_list(self) -> list[str]:
errors: list[str] = []
if "arg_types" in self.job_data:
arg_types = self.job_data["arg_types"]
arg_list = self.job_data.get("argList")
Expand All @@ -287,10 +339,10 @@ def _assert_arg_list(self):
)
return errors

def name(self):
def name(self) -> str:
    """Return the step's name from the job configuration (KeyError if absent)."""
    return self.job_data["name"]

def _dump_exec_env(self):
def _dump_exec_env(self) -> None:
exec_env = self.job_data.get("exec_env")
if exec_env:
exec_name, _ = os.path.splitext(
Expand All @@ -299,7 +351,7 @@ def _dump_exec_env(self):
with open(f"{exec_name}_exec_env.json", "w", encoding="utf-8") as f_handle:
f_handle.write(json.dumps(exec_env, indent=4))

def _check_job_files(self):
def _check_job_files(self)-> list[str]:
"""
Returns the empty list if no failed checks, or a list of errors in case
of failed checks.
Expand All @@ -323,7 +375,7 @@ def _check_job_files(self):

return errors

def _check_target_file_is_written(self, target_file_mtime: int, timeout=5):
def _check_target_file_is_written(self, target_file_mtime: int, timeout: int =5) -> None | str:
"""
Check whether or not a target_file eventually appear. Returns None in
case of success, an error message in the case of failure.
Expand Down Expand Up @@ -357,6 +409,20 @@ def _check_target_file_is_written(self, target_file_mtime: int, timeout=5):
return f"Could not find target_file:{target_file}"


def _get_target_file_ntime(file: Optional[str]) -> Optional[int]:
mtime = None
if file and os.path.exists(file):
stat = os.stat(file)
mtime = stat.st_mtime_ns
return mtime


def ensure_file_handles_closed(file_handles: Sequence[io.TextIOWrapper | None]) -> None:
    """Close every handle in file_handles, silently skipping None entries."""
    for handle in filter(None, file_handles):
        handle.close()


def _get_processtree_data(
process: Process,
) -> Tuple[int, float, Optional[int]]:
Expand Down
6 changes: 3 additions & 3 deletions src/_ert/forward_model_runner/reporting/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ def __repr__(cls):
class Message(metaclass=_MetaMessage):
def __init__(self, job=None):
self.timestamp = dt.now()
self.job = job
self.error_message = None
self.job: Optional[Job] = job
self.error_message: Optional[str] = None

def __repr__(self):
    """Represent the message by its class name alone (e.g. "Exited")."""
    return type(self).__name__

def with_error(self, message):
def with_error(self, message: str):
    """Attach an error message to this message and return self for chaining."""
    self.error_message = message
    return self

Expand Down
Loading

0 comments on commit 1aac6e7

Please sign in to comment.