Skip to content

Commit

Permalink
Document info/shared_cli, MemoryParamType int
Browse files Browse the repository at this point in the history
* Fleshed out documentation for the `gempyor.info/shared_cli` modules.
* Added the ability to optionally return the converted memory as an int
  instead of as a float to `MemoryParamType`.
* Removed pre-instantiated custom click param types.
  • Loading branch information
TimothyWillard committed Nov 15, 2024
1 parent 66ab33d commit 2ecab37
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 46 deletions.
21 changes: 7 additions & 14 deletions flepimop/gempyor_pkg/src/gempyor/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
from .logging import get_script_logger
from .utils import _format_cli_options, _git_checkout, _git_head, _shutil_which, config
from .shared_cli import (
MEMORY_MB,
NONNEGATIVE_DURATION,
DurationParamType,
MemoryParamType,
cli,
config_files_argument,
config_file_options,
config_files_argument,
log_cli_inputs,
mock_context,
parse_config_files,
Expand Down Expand Up @@ -903,13 +903,13 @@ def _submit_scenario_job(
),
click.Option(
param_decls=["--simulation-time", "simulation_time"],
type=NONNEGATIVE_DURATION,
type=DurationParamType(True),
default="3min",
help="The time limit per a simulation.",
),
click.Option(
param_decls=["--initial-time", "initial_time"],
type=NONNEGATIVE_DURATION,
type=DurationParamType(True),
default="20min",
help="The initialization time limit.",
),
Expand Down Expand Up @@ -971,7 +971,7 @@ def _submit_scenario_job(
),
click.Option(
param_decls=["--memory", "memory"],
type=MEMORY_MB,
type=MemoryParamType("mb", as_int=True),
default=None,
help="Override for the amount of memory per node to use in MB.",
),
Expand Down Expand Up @@ -1126,19 +1126,12 @@ def _click_submit(ctx: click.Context = mock_context, **kwargs: Any) -> None:
logger.info("Setting a total job time limit of %s minutes", job_time_limit.format())

# Job resources
memory = None if kwargs["memory"] is None else math.ceil(kwargs["memory"])
if memory != kwargs["memory"]:
logger.warning(
"The requested memory of %.3fMB has been rounded up to %uMB for submission",
kwargs["memory"],
memory,
)
job_resources = JobResources.from_presets(
job_size,
inference_method,
nodes=kwargs["nodes"],
cpus=kwargs["cpus"],
memory=memory,
memory=kwargs["memory"],
)
logger.info("Requesting the resources %s for this job.", job_resources)

Expand Down
74 changes: 70 additions & 4 deletions flepimop/gempyor_pkg/src/gempyor/info.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
"""
Retrieving static information from developer managed yaml files.
Currently, it includes utilities for handling cluster-specific information, but it can
be extended to other categories as needed.
Classes:
Module: Represents a software module with a name and optional version.
PathExport: Represents a path export with a path, prepend flag, and error handling.
Cluster: Represents a cluster with a name, list of modules, and list of path
exports.
Functions:
get_cluster_info: Retrieves cluster-specific information.
Examples:
>>> from pprint import pprint
>>> from gempyor.info import get_cluster_info
>>> cluster_info = get_cluster_info("longleaf")
>>> cluster_info.name
'longleaf'
>>> pprint(cluster_info.modules)
[Module(name='gcc', version='9.1.0'),
Module(name='anaconda', version='2023.03'),
Module(name='git', version=None),
Module(name='aws', version=None)]
"""

__all__ = ["Cluster", "Module", "PathExport", "get_cluster_info"]


Expand All @@ -12,23 +40,55 @@


class Module(BaseModel):
"""
A model representing a module to load.
Attributes:
name: The name of the module to load.
version: The specific version of the module to load if there is one.
See Also:
[Lmod](https://lmod.readthedocs.io/en/latest/)
"""

name: str
version: str | None = None


class PathExport(BaseModel):
"""
A model representing the export path configuration.
Attributes:
path: The file system path of the path to add to the `$PATH` environment
variable.
prepend: A flag indicating whether to prepend additional information to the
`$PATH` environment variable.
error_if_missing: A flag indicating whether to raise an error if the path is
missing.
"""

path: Path
prepend: bool = True
error_if_missing: bool = False


class Cluster(BaseModel):
"""
A model representing a cluster configuration.
Attributes:
name: The name of the cluster.
modules: A list of modules associated with the cluster.
path_exports: A list of path exports for the cluster.
"""

name: str
modules: list[Module] = []
path_exports: list[PathExport] = []


T = TypeVar("T", bound=BaseModel)
_BASE_MODEL_TYPE = TypeVar("T", bound=BaseModel)


_CLUSTER_FQDN_REGEXES: tuple[tuple[str, Pattern], ...] = (
Expand All @@ -38,8 +98,8 @@ class Cluster(BaseModel):


def _get_info(
category: str, name: str, model: type[T], flepi_path: os.PathLike | None
) -> T:
category: str, name: str, model: type[_BASE_MODEL_TYPE], flepi_path: os.PathLike | None
) -> _BASE_MODEL_TYPE:
"""
Get and parse an information yaml file.
Expand Down Expand Up @@ -79,8 +139,14 @@ def get_cluster_info(name: str | None, flepi_path: os.PathLike | None = None) ->
flepi_path: Either a path like determine the directory to look for the info
directory in or `None` to use the `FLEPI_PATH` environment variable.
Returns
Returns:
An object containing the information about the `name` cluster.
Examples:
>>> from gempyor.info import get_cluster_info
>>> cluster_info = get_cluster_info("longleaf")
>>> cluster_info.name
'longleaf'
"""
name = _infer_cluster_from_fqdn() if name is None else name
return _get_info("cluster", name, Cluster, flepi_path)
Expand Down
114 changes: 98 additions & 16 deletions flepimop/gempyor_pkg/src/gempyor/shared_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@


from datetime import timedelta
from math import ceil
import multiprocessing
import pathlib
import re
Expand Down Expand Up @@ -143,6 +144,23 @@ def cli(ctx: click.Context) -> None:


class DurationParamType(click.ParamType):
"""
A custom Click parameter type for parsing duration strings into `timedelta` objects.
Attributes:
name: The name of the parameter type.
Examples:
>>> from gempyor.shared_cli import DurationParamType
>>> duration_param_type = DurationParamType(False)
>>> duration_param_type.convert("23min", None, None)
datetime.timedelta(seconds=1380)
>>> duration_param_type.convert("2.5hr", None, None)
datetime.timedelta(seconds=9000)
>>> duration_param_type.convert("-2", None, None)
datetime.timedelta(days=-1, seconds=86280)
"""

name = "duration"
_abbreviations = {
"s": "seconds",
Expand Down Expand Up @@ -173,6 +191,14 @@ def __init__(
nonnegative: bool,
default_unit: Literal["seconds", "minutes", "hours", "days", "weeks"] = "minutes",
) -> None:
"""
Initialize the instance based on parameter settings.
Args:
nonnegative: If `True` negative durations are not allowed.
default_unit: The default unit to use if no unit is specified in the input
string.
"""
super().__init__()
self._nonnegative = nonnegative
self._duration_regex = re.compile(
Expand All @@ -184,6 +210,24 @@ def __init__(
def convert(
self, value: Any, param: click.Parameter | None, ctx: click.Context | None
) -> timedelta:
"""
Converts a string representation of a duration into a `timedelta` object.
Args:
value: The value to convert, expected to be a string like representation of
a duration.
param: The Click parameter object for context in errors.
ctx: The Click context object for context in errors.
Returns:
The converted duration as a `timedelta` object.
Raises:
click.BadParameter: If the value is not a valid duration based on the
format.
click.BadParameter: If the duration is negative and the class was
initialized with `nonnegative` set to `True`.
"""
value = str(value).strip()
if (m := self._duration_regex.match(value)) is None:
self.fail(f"{value!r} is not a valid duration", param, ctx)
Expand All @@ -195,11 +239,24 @@ def convert(
return timedelta(**kwargs)


DURATION = DurationParamType(nonnegative=False)
NONNEGATIVE_DURATION = DurationParamType(nonnegative=True)
class MemoryParamType(click.ParamType):
"""
A custom Click parameter type for parsing duration strings into `timedelta` objects.
Attributes:
name: The name of the parameter type.
Examples:
>>> from gempyor.shared_cli import DurationParamType
>>> duration_param_type = DurationParamType(False)
>>> duration_param_type.convert("23min", None, None)
datetime.timedelta(seconds=1380)
>>> duration_param_type.convert("2.5hr", None, None)
datetime.timedelta(seconds=9000)
>>> duration_param_type.convert("-2", None, None)
datetime.timedelta(days=-1, seconds=86280)
"""

class MemoryParamType(click.ParamType):
name = "memory"
_units = {
"kb": 1024.0**1.0,
Expand All @@ -212,38 +269,63 @@ class MemoryParamType(click.ParamType):
"tb": 1024.0**4.0,
}

def __init__(self, unit: str) -> None:
def __init__(self, unit: str, as_int: bool = False) -> None:
"""
Initialize the instance based on parameter settings.
Args:
unit: The output unit to use in the `convert` method.
as_int: if `True` the `convert` method returns an integer instead of a
float.
Raises:
ValueError: If `unit` is not a valid memory unit size.
"""
super().__init__()
if (unit := unit.lower()) not in self._units.keys():
raise ValueError(
f"The `unit` given is not valid, given '{unit}' and "
"must be one of: {', '.join(self._units.keys())}."
f"must be one of: {', '.join(self._units.keys())}."
)
self._unit = unit
self._regex = re.compile(
rf"^(([0-9]+)?(\.[0-9]+)?)({'|'.join(self._units.keys())})?$",
flags=re.IGNORECASE,
)
self._as_int = as_int

def convert(
self, value: Any, param: click.Parameter | None, ctx: click.Context | None
) -> float:
) -> float | int:
"""
Converts a string representation of a memory size into a numeric.
Args:
value: The value to convert, expected to be a string like representation of
memory size.
param: The Click parameter object for context in errors.
ctx: The Click context object for context in errors.
Returns:
The converted memory size as a numeric. Specifically an integer if the
`as_int` attribute is `True` and float otherwise.
Raises:
click.BadParameter: If the value is not a valid memory size based on the
format.
"""
value = str(value).strip()
if (m := self._regex.match(value)) is None:
self.fail(f"{value!r} is not a valid memory size.", param, ctx)
number, _, _, unit = m.groups()
unit = unit.lower()
if unit == self._unit:
return float(number)
return (self._units.get(unit, self._unit) * float(number)) / (
self._units.get(self._unit)
)


MEMORY_KB = MemoryParamType("kb")
MEMORY_MB = MemoryParamType("mb")
MEMORY_GB = MemoryParamType("gb")
MEMORY_TB = MemoryParamType("tb")
result = float(number)
else:
result = (self._units.get(unit, self._unit) * float(number)) / (
self._units.get(self._unit)
)
return ceil(result) if self._as_int else result


def click_helpstring(
Expand Down
Loading

0 comments on commit 2ecab37

Please sign in to comment.