Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Oct 3, 2024
1 parent 35fcc8a commit b688ca8
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 81 deletions.
132 changes: 54 additions & 78 deletions src/_ert/forward_model_runner/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from datetime import datetime as dt
from pathlib import Path
from subprocess import Popen, run
from typing import Dict, Generator, Optional, Sequence, Tuple
from typing import Dict, Generator, List, Optional, Sequence, Tuple, cast

from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess

from ert.config.forward_model_step import ForwardModelStepJSON

from .io import check_executable
from .reporting.message import (
Exited,
Expand All @@ -26,7 +28,7 @@
logger = logging.getLogger(__name__)


def killed_by_oom(pids: set[int]) -> bool:
def killed_by_oom(pids: Sequence[int]) -> bool:
"""Will try to detect if a process (or any of its descendants) was killed
by the Linux OOM-killer.
Expand Down Expand Up @@ -76,36 +78,38 @@ def killed_by_oom(pids: set[int]) -> bool:
class Job:
MEMORY_POLL_PERIOD = 5 # Seconds between memory polls

def __init__(self, job_data, index, sleep_interval=1) -> None:
def __init__(
self, job_data: ForwardModelStepJSON, index: int, sleep_interval: int = 1
) -> None:
self.sleep_interval = sleep_interval
self.job_data: Dict[str, str] = job_data
self.job_data = job_data
self.index = index
self.std_err = job_data.get("stderr")
self.std_out = job_data.get("stdout")

def run(self) -> Generator[Start | Exited | Running]:
def run(self) -> Generator[Start | Exited | Running | None]:
try:
for msg in self._run():
yield msg
except StopIteration as e:
raise e
except Exception as e:
yield Exited(self, exit_code=1).with_error(str(e))

def create_start_message_and_check_job_files(self) -> Start:
start_message = Start(self)

errors = self._check_job_files()

errors.extend(self._assert_arg_list())
errors = [*self._check_job_files()]

self._dump_exec_env()

if errors:
start_message = start_message.with_error("\n".join(errors))
return start_message

def _build_arg_list(self):
def _build_arg_list(self) -> List[str]:
executable = self.job_data.get("executable")

# assert executable is not None
combined_arg_list = [executable]
if arg_list := self.job_data.get("argList"):
combined_arg_list += arg_list
Expand All @@ -117,7 +121,7 @@ def _open_file_handles(
io.TextIOWrapper | None, io.TextIOWrapper | None, io.TextIOWrapper | None
]:
if self.job_data.get("stdin"):
stdin = open(self.job_data.get("stdin"), encoding="utf-8") # noqa
stdin = open(cast(Path, self.job_data.get("stdin")), encoding="utf-8") # noqa
else:
stdin = None

Expand All @@ -141,26 +145,26 @@ def _open_file_handles(

return (stdin, stdout, stderr)

def _create_environment(self) -> Dict:
environment = self.job_data.get("environment")
if environment is not None:
environment = {**os.environ, **environment}
return environment
def _create_environment(self) -> Dict[str, str]:
combined_environment = {}
if environment := self.job_data.get("environment"):
combined_environment = {**os.environ, **environment}
return combined_environment

def _run(self) -> contextlib.Generator[Start | Exited | Running]:
def _run(self) -> Generator[Start | Exited | Running | None]:
start_message = self.create_start_message_and_check_job_files()

yield start_message
if not start_message.success():
return
raise StopIteration()

arg_list = self._build_arg_list()

(stdin, stdout, stderr) = self._open_file_handles()
# stdin/stdout/stderr are closed at the end of this function

target_file = self.job_data.get("target_file")
target_file_mtime: int = _get_target_file_ntime(target_file)
target_file_mtime: Optional[int] = _get_target_file_ntime(target_file)

try:
proc = Popen(
Expand All @@ -177,7 +181,7 @@ def _run(self) -> contextlib.Generator[Start | Exited | Running]:
)
ensure_file_handles_closed([stdin, stdout, stderr])
yield exited_message
return
raise StopIteration() from None

exit_code = None

Expand Down Expand Up @@ -209,7 +213,7 @@ def _run(self) -> contextlib.Generator[Start | Exited | Running]:
}
if isinstance(exited_msg, Exited):
yield exited_msg
return
raise StopIteration() from None

ensure_file_handles_closed([stdin, stdout, stderr])
exited_message = self._create_exited_message_based_on_exit_code(
Expand All @@ -224,16 +228,12 @@ def _create_exited_message_based_on_exit_code(
exit_code: int,
fm_step_pids: Sequence[int],
) -> Exited:
# exit_code = proc.returncode

if exit_code != 0:
exited_message = self._create_exited_msg_for_non_zero_exit_code(
max_memory_usage, exit_code, fm_step_pids
)
return exited_message

# exit_code is 0

if self.job_data.get("error_file") and os.path.exists(
self.job_data["error_file"]
):
Expand Down Expand Up @@ -271,34 +271,33 @@ def _create_exited_msg_for_non_zero_exit_code(
)

def handle_process_timeout_and_create_exited_msg(
self, process: Process, proc: Popen
self, process: Process, proc: Popen[Process]
) -> Exited | None:
max_running_minutes = self.job_data.get("max_running_minutes")
run_start_time = dt.now()

run_time = dt.now() - run_start_time
if (
max_running_minutes is not None
and run_time.seconds > max_running_minutes * 60
):
# If the spawned process is not in the same process group as
# the callee (job_dispatch), we will kill the process group
# explicitly.
#
# Propagating the unsuccessful Exited message will kill the
# callee group. See job_dispatch.py.
process_group_id = os.getpgid(proc.pid)
this_group_id = os.getpgid(os.getpid())
if process_group_id != this_group_id:
os.killpg(process_group_id, signal.SIGKILL)

return Exited(self, proc.returncode).with_error(
(
f"Job:{self.name()} has been running "
f"for more than {max_running_minutes} "
"minutes - explicitly killed."
)
if max_running_minutes is None or run_time.seconds > max_running_minutes * 60:
return None

# If the spawned process is not in the same process group as
# the callee (job_dispatch), we will kill the process group
# explicitly.
#
# Propagating the unsuccessful Exited message will kill the
# callee group. See job_dispatch.py.
process_group_id = os.getpgid(proc.pid)
this_group_id = os.getpgid(os.getpid())
if process_group_id != this_group_id:
os.killpg(process_group_id, signal.SIGKILL)

return Exited(self, proc.returncode).with_error(
(
f"Job:{self.name()} has been running "
f"for more than {max_running_minutes} "
"minutes - explicitly killed."
)
)

def _handle_process_io_error_and_create_exited_message(
self, e: OSError, stderr: io.TextIOWrapper | None
Expand All @@ -314,44 +313,19 @@ def _handle_process_io_error_and_create_exited_message(
stderr.write(msg)
return Exited(self, e.errno).with_error(msg)

def _assert_arg_list(self) -> list[str]:
errors: list[str] = []
if "arg_types" in self.job_data:
arg_types = self.job_data["arg_types"]
arg_list = self.job_data.get("argList")
for index, arg_type in enumerate(arg_types):
if arg_type == "RUNTIME_FILE":
file_path = os.path.join(os.getcwd(), arg_list[index])
if not os.path.isfile(file_path):
errors.append(
f"In job {self.name()}: RUNTIME_FILE {arg_list[index]} "
"does not exist."
)
if arg_type == "RUNTIME_INT":
try:
int(arg_list[index])
except ValueError:
errors.append(
(
f"In job {self.name()}: argument with index {index} "
"is of incorrect type, should be integer."
)
)
return errors

def name(self) -> str:
return self.job_data["name"]

def _dump_exec_env(self) -> None:
exec_env = self.job_data.get("exec_env")
if exec_env:
exec_name, _ = os.path.splitext(
os.path.basename(self.job_data.get("executable"))
os.path.basename(cast(Path, self.job_data.get("executable")))
)
with open(f"{exec_name}_exec_env.json", "w", encoding="utf-8") as f_handle:
f_handle.write(json.dumps(exec_env, indent=4))

def _check_job_files(self)-> list[str]:
def _check_job_files(self) -> list[str]:
"""
Returns the empty list if no failed checks, or a list of errors in case
of failed checks.
Expand All @@ -361,21 +335,23 @@ def _check_job_files(self)-> list[str]:
errors.append(f'Could not locate stdin file: {self.job_data["stdin"]}')

if self.job_data.get("start_file") and not os.path.exists(
self.job_data["start_file"]
cast(Path, self.job_data["start_file"])
):
errors.append(f'Could not locate start_file:{self.job_data["start_file"]}')

if self.job_data.get("error_file") and os.path.exists(
self.job_data.get("error_file")
cast(Path, self.job_data.get("error_file"))
):
os.unlink(self.job_data.get("error_file"))
os.unlink(cast(Path, self.job_data.get("error_file")))

if executable_error := check_executable(self.job_data.get("executable")):
errors.append(str(executable_error))

return errors

def _check_target_file_is_written(self, target_file_mtime: int, timeout: int =5) -> None | str:
def _check_target_file_is_written(
self, target_file_mtime: int, timeout: int = 5
) -> None | str:
"""
Check whether or not a target_file eventually appear. Returns None in
case of success, an error message in the case of failure.
Expand Down
6 changes: 3 additions & 3 deletions src/_ert/forward_model_runner/reporting/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def with_error(self, message: str):
self.error_message = message
return self

def success(self):
def success(self) -> bool:
return self.error_message is None


Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(self):


class Start(Message):
def __init__(self, job):
def __init__(self, job: "Job"):
super().__init__(job)


Expand All @@ -127,7 +127,7 @@ def __init__(self, job: "Job", memory_status: ProcessTreeStatus):


class Exited(Message):
def __init__(self, job, exit_code):
def __init__(self, job, exit_code: int):
super().__init__(job)
self.exit_code = exit_code

Expand Down

0 comments on commit b688ca8

Please sign in to comment.