diff --git a/pegasus_wrapper/key_value.py b/pegasus_wrapper/key_value.py
index cb8365e..86304c8 100644
--- a/pegasus_wrapper/key_value.py
+++ b/pegasus_wrapper/key_value.py
@@ -79,10 +79,7 @@ def __call__(
         final_transform = self.transforms[-1]
         cur_value = input_zip
         for transform in self.transforms:
-            if transform is final_transform:
-                step_output_locator = output_locator
-            else:
-                step_output_locator = None
+            step_output_locator = output_locator if transform is final_transform else None
             cur_value = transform(cur_value, output_locator=step_output_locator)
         return cur_value
diff --git a/pegasus_wrapper/resource_request.py b/pegasus_wrapper/resource_request.py
index 77da0fd..4016fdf 100644
--- a/pegasus_wrapper/resource_request.py
+++ b/pegasus_wrapper/resource_request.py
@@ -15,6 +15,35 @@
 SCAVENGE = "scavenge"
 EPHEMERAL = "ephemeral"
 
+_SLURM_DEFAULT_MEMORY = MemoryAmount.parse("2G")
+_PROJECT_PARTITION_JOB_TIME_IN_MINUTES = 1440
+
+
+@attrs(frozen=True, slots=True)
+class Partition:
+    """
+    Representation of a SAGA partition.
+    """
+
+    name: str = attrib(validator=instance_of(str))
+    max_walltime: int = attrib(validator=instance_of(int), kw_only=True)
+
+    def __eq__(self, other) -> bool:
+        return self.name == other.name
+
+    def __str__(self) -> str:
+        return self.name
+
+    @staticmethod
+    def from_str(name: str) -> "Partition":
+        _partition_to_max_walltime = {"ephemeral": 720, "scavenge": 60}
+
+        return Partition(
+            name=name,
+            max_walltime=_partition_to_max_walltime.get(
+                name, _PROJECT_PARTITION_JOB_TIME_IN_MINUTES
+            ),
+        )
 
 
 class ResourceRequest(Protocol):
@@ -61,22 +90,20 @@ def from_parameters(params: Parameters) -> "ResourceRequest":
         raise RuntimeError(f"Invalid backend option {backend}")
 
 
-_SLURM_DEFAULT_MEMORY = MemoryAmount.parse("2G")
-_DEFAULT_JOB_TIME_IN_MINUTES = 1440
-
-
 @attrs(frozen=True, slots=True)
 class SlurmResourceRequest(ResourceRequest):
     """
     A `ResourceRequest` for a job running on a SLURM cluster.
     """
 
+    partition: Optional[Partition] = attrib(
+        converter=lambda x: Partition.from_str(x) if x else None,
+        kw_only=True,
+        default=None,
+    )
     memory: Optional[MemoryAmount] = attrib(
         validator=optional(instance_of(MemoryAmount)), kw_only=True, default=None
     )
-    partition: Optional[str] = attrib(
-        validator=optional(instance_of(str)), kw_only=True, default=None
-    )
     num_cpus: Optional[int] = attrib(
         validator=optional(in_(Range.at_least(1))), default=None, kw_only=True
     )
@@ -84,9 +111,7 @@ class SlurmResourceRequest(ResourceRequest):
         validator=optional(in_(Range.at_least(0))), default=None, kw_only=True
     )
     job_time_in_minutes: Optional[int] = attrib(
-        validator=optional(instance_of(int)),
-        default=_DEFAULT_JOB_TIME_IN_MINUTES,
-        kw_only=True,
+        validator=optional(instance_of(int)), default=None, kw_only=True
     )
     exclude_list: Optional[str] = attrib(
         validator=optional(instance_of(str)), kw_only=True, default=None
     )
@@ -95,6 +120,24 @@ class SlurmResourceRequest(ResourceRequest):
         validator=optional(instance_of(str)), kw_only=True, default=None
     )
 
+    def __attrs_post_init__(self):
+        if not self.job_time_in_minutes:
+            partition_job_time = None
+            if not self.partition:
+                logging.warning(
+                    "No partition specified. Defaulting job with no job time specified to the max project partition walltime."
+                )
+                partition_job_time = _PROJECT_PARTITION_JOB_TIME_IN_MINUTES
+            else:
+                logging.warning(
+                    "Defaulting job with no job time specified to max walltime of selected partition '%s'",
+                    self.partition.name,
+                )
+                partition_job_time = self.partition.max_walltime
+            # Workaround suggested by maintainers of attrs.
+            # See https://www.attrs.org/en/stable/how-does-it-work.html#how-frozen
+            object.__setattr__(self, "job_time_in_minutes", partition_job_time)
+
     @run_on_single_node.validator
     def check(self, _, value: str):
         if value and len(value.split(",")) != 1:
@@ -116,22 +159,18 @@ def from_parameters(params: Parameters) -> ResourceRequest:
 
     def unify(self, other: ResourceRequest) -> ResourceRequest:
         if isinstance(other, SlurmResourceRequest):
-            partition = other.partition if other.partition else self.partition
+            partition = other.partition or self.partition
         else:
            partition = self.partition
         return SlurmResourceRequest(
-            partition=partition,
-            memory=other.memory if other.memory else self.memory,
-            num_cpus=other.num_cpus if other.num_cpus else self.num_cpus,
+            partition=partition.name if partition else None,
+            memory=other.memory or self.memory,
+            num_cpus=other.num_cpus or self.num_cpus,
             num_gpus=other.num_gpus if other.num_gpus is not None else self.num_gpus,
-            job_time_in_minutes=other.job_time_in_minutes
-            if other.job_time_in_minutes
-            else self.job_time_in_minutes,
-            exclude_list=other.exclude_list if other.exclude_list else self.exclude_list,
-            run_on_single_node=other.run_on_single_node
-            if other.run_on_single_node
-            else self.run_on_single_node,
+            job_time_in_minutes=other.job_time_in_minutes or self.job_time_in_minutes,
+            exclude_list=other.exclude_list or self.exclude_list,
+            run_on_single_node=other.run_on_single_node or self.run_on_single_node,
         )
 
     def convert_time_to_slurm_format(self, job_time_in_minutes: int) -> str:
@@ -142,25 +181,24 @@ def apply_to_job(self, job: Job, *, job_name: str) -> None:
         if not self.partition:
             raise RuntimeError("A partition to run on must be specified.")
 
+        if self.partition.max_walltime < self.job_time_in_minutes:
+            raise ValueError(
+                f"Partition '{self.partition.name}' has a max walltime of {self.partition.max_walltime} mins, which is less than the time given ({self.job_time_in_minutes} mins) for job: {job_name}."
+            )
+
         qos_or_account = (
-            f"qos {self.partition}"
-            if self.partition in (SCAVENGE, EPHEMERAL)
-            else f"account {self.partition}"
+            f"qos {self.partition.name}"
+            if self.partition.name in (SCAVENGE, EPHEMERAL)
+            else f"account {self.partition.name}"
         )
         slurm_resource_content = SLURM_RESOURCE_STRING.format(
             qos_or_account=qos_or_account,
-            partition=self.partition,
-            num_cpus=self.num_cpus if self.num_cpus else 1,
+            partition=self.partition.name,
+            num_cpus=self.num_cpus or 1,
             num_gpus=self.num_gpus if self.num_gpus is not None else 0,
             job_name=job_name,
-            mem_str=to_slurm_memory_string(
-                self.memory if self.memory else _SLURM_DEFAULT_MEMORY
-            ),
-            time=self.convert_time_to_slurm_format(
-                self.job_time_in_minutes
-                if self.job_time_in_minutes
-                else _DEFAULT_JOB_TIME_IN_MINUTES
-            ),
+            mem_str=to_slurm_memory_string(self.memory or _SLURM_DEFAULT_MEMORY),
+            time=self.convert_time_to_slurm_format(self.job_time_in_minutes),
         )
 
         if (
diff --git a/pegasus_wrapper/workflow.py b/pegasus_wrapper/workflow.py
index 80bd37b..003c21d 100644
--- a/pegasus_wrapper/workflow.py
+++ b/pegasus_wrapper/workflow.py
@@ -192,10 +192,7 @@ def run_python_on_parameters(
             for out_file in parent_dependency.output_files:
                 job.uses(out_file, link=Link.INPUT)
 
-        if resource_request is not None:
-            resource_request = self.default_resource_request.unify(resource_request)
-        else:
-            resource_request = self.default_resource_request
+        resource_request = self.set_resource_request(resource_request)
 
         if category:
             job.profile(Namespace.DAGMAN, "category", category)
@@ -230,6 +227,14 @@ def run_python_on_parameters(
         logging.info("Scheduled Python job %s", job_name)
         return dependency_node
 
+    def set_resource_request(self, resource_request: ResourceRequest) -> ResourceRequest:
+        if resource_request is not None:
+            resource_request = self.default_resource_request.unify(resource_request)
+        else:
+            resource_request = self.default_resource_request
+
+        return resource_request
+
     def limit_jobs_for_category(self, category: str, max_jobs: int):
         """
         Limit the number of jobs in the given category that can run concurrently to max_jobs.
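
Note for reviewers: the sketch below illustrates how the new walltime defaulting and validation above are meant to compose. It is a minimal, hypothetical usage example, not part of the diff; it assumes pegasus_wrapper is importable and that the names match resource_request.py as changed above.

# Minimal sketch (not part of the diff) of the Partition walltime defaults,
# assuming the definitions from pegasus_wrapper/resource_request.py above.
from pegasus_wrapper.resource_request import Partition, SlurmResourceRequest

# Known partitions carry hard-coded max walltimes...
assert Partition.from_str("ephemeral").max_walltime == 720
assert Partition.from_str("scavenge").max_walltime == 60
# ...and any other name falls back to the project-partition default of 1440.
assert Partition.from_str("gaia").max_walltime == 1440

# When no job time is given, __attrs_post_init__ fills in the partition's
# max walltime (emitting a warning as it does so).
request = SlurmResourceRequest(partition="scavenge")
assert request.job_time_in_minutes == 60

# A request exceeding the partition's max walltime is still constructible,
# but apply_to_job() will now raise a ValueError for it.
too_long = SlurmResourceRequest(partition="scavenge", job_time_in_minutes=120)

In short, construction stays permissive; the hard failure is deferred to apply_to_job(), where the job name is available for the error message.
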
diff --git a/tests/key_value_test.py b/tests/key_value_test.py
index 16f9108..10a0b26 100644
--- a/tests/key_value_test.py
+++ b/tests/key_value_test.py
@@ -27,7 +27,7 @@ def subtract2(values, **kwargs):  # pylint:disable=unused-argument
             "workflow_directory": str(tmp_path / "working"),
             "site": "saga",
             "namespace": "test",
-            "partition": "scavenge",
+            "partition": "gaia",
         }
     )
diff --git a/tests/workflow_builder_test.py b/tests/workflow_builder_test.py
index c6f9b89..cee4059 100644
--- a/tests/workflow_builder_test.py
+++ b/tests/workflow_builder_test.py
@@ -27,7 +27,7 @@ def test_simple_dax(tmp_path):
             "workflow_directory": str(tmp_path / "working"),
             "site": "saga",
             "namespace": "test",
-            "partition": "scavenge",
+            "partition": "gaia",
             "experiment_name": "fred",
         }
     )
@@ -65,12 +65,12 @@ def test_dax_with_job_on_saga(tmp_path):
             "workflow_directory": str(tmp_path / "working"),
             "site": "saga",
             "namespace": "test",
-            "partition": "scavenge",
+            "partition": "gaia",
             "experiment_name": "fred",
         }
     )
     slurm_params = Parameters.from_mapping(
-        {"partition": "scavenge", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
+        {"partition": "gaia", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
     )
     multiply_input_file = tmp_path / "raw_nums.txt"
     random = Random()
@@ -158,11 +158,11 @@ def test_dax_with_checkpointed_jobs_on_saga(tmp_path):
             "workflow_directory": str(tmp_path / "working"),
             "site": "saga",
             "namespace": "test",
-            "partition": "scavenge",
+            "partition": "gaia",
         }
     )
     slurm_params = Parameters.from_mapping(
-        {"partition": "scavenge", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
+        {"partition": "gaia", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
     )
     resources = SlurmResourceRequest.from_parameters(slurm_params)
     workflow_builder = WorkflowBuilder.from_parameters(workflow_params)
@@ -332,7 +332,8 @@ def endElement(self, name):
         # the expected category.
         elif name == "profile" and self._in_target_job_category:
             category = "".join(self._job_category_content).strip()
-            if category == self.category:
+            # The profile value parsed from the DAX is always a str, so coerce self.category before comparing.
+            if category == str(self.category):
                 self._job_has_category = True
             self._job_category_content = []
             self._in_target_job_category = False
@@ -367,7 +368,7 @@ def test_dax_with_categories(tmp_path):
             "workflow_directory": str(tmp_path / "working"),
             "site": "saga",
             "namespace": "test",
-            "partition": "scavenge",
+            "partition": "gaia",
         }
     )
     workflow_builder = WorkflowBuilder.from_parameters(workflow_params)
@@ -407,10 +408,10 @@ def test_dax_with_saga_categories(tmp_path):
             "workflow_directory": str(tmp_path / "working"),
             "site": "saga",
             "namespace": "test",
-            "partition": "scavenge",
+            "partition": "gaia",
         }
     )
-    multiply_partition = "scavenge"
+    multiply_partition = "gaia"
     multiply_slurm_params = Parameters.from_mapping(
         {"partition": multiply_partition, "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
     )
@@ -436,7 +437,7 @@ def test_dax_with_saga_categories(tmp_path):
         locator=Locator("multiply"),
     )
 
-    sort_partition = "ephemeral"
+    sort_partition = "lestat"
     sort_slurm_params = Parameters.from_mapping(
         {"partition": sort_partition, "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
     )
@@ -475,11 +476,11 @@ def test_category_max_jobs(tmp_path):
             "workflow_directory": str(tmp_path / "working"),
             "site": "saga",
             "namespace": "test",
-            "partition": "scavenge",
+            "partition": "gaia",
         }
     )
     multiply_slurm_params = Parameters.from_mapping(
-        {"partition": "scavenge", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
+        {"partition": "gaia", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
     )
     multiply_resources = SlurmResourceRequest.from_parameters(multiply_slurm_params)
     workflow_builder = WorkflowBuilder.from_parameters(workflow_params)
@@ -504,7 +505,13 @@ def test_category_max_jobs(tmp_path):
     )
 
     sort_slurm_params = Parameters.from_mapping(
-        {"partition": "ephemeral", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
+        {
+            "partition": "ephemeral",
+            "num_cpus": 1,
+            "num_gpus": 0,
+            "memory": "4G",
+            "job_time_in_minutes": 120,
+        }
     )
     sort_resources = SlurmResourceRequest.from_parameters(sort_slurm_params)
 
@@ -521,7 +528,7 @@ def test_category_max_jobs(tmp_path):
         resource_request=sort_resources,
     )
 
-    workflow_builder.limit_jobs_for_category("scavenge", 1)
+    workflow_builder.limit_jobs_for_category("gaia", 1)
     workflow_builder.write_dax_to_dir()
 
     config = workflow_params.existing_directory("workflow_directory") / "pegasus.conf"
@@ -530,8 +537,8 @@ def test_category_max_jobs(tmp_path):
     # Make sure the config contains the appropriate maxjobs lines and no inappropriate maxjobs lines
     with config.open("r") as f:
         lines = f.readlines()
-        assert any(["dagman.scavenge.maxjobs=1" in line for line in lines])
-        assert not any(["dagman.ephemeral.maxjobs=" in line for line in lines])
+        assert any("dagman.gaia.maxjobs=1" in line for line in lines)
+        assert all("dagman.ephemeral.maxjobs=" not in line for line in lines)
 
 
 def test_dax_test_exclude_nodes_on_saga(tmp_path):
@@ -547,12 +554,12 @@ def test_dax_test_exclude_nodes_on_saga(tmp_path):
             "workflow_directory": str(tmp_path / "working"),
             "site": "saga",
             "namespace": "test",
-            "partition": "scavenge",
+            "partition": "gaia",
             "exclude_list": sample_exclude,
         }
     )
     slurm_params = Parameters.from_mapping(
-        {"partition": "scavenge", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
+        {"partition": "gaia", "num_cpus": 1, "num_gpus": 0, "memory": "4G"}
     )
     multiply_input_file = tmp_path / "raw_nums.txt"
/ "raw_nums.txt" random = Random()