[develop][wip-1] Job Level Scaling for Node Sharing case #558

Merged: 5 commits, Sep 7, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ This file is used to list changes made in each version of the aws-parallelcluste
**ENHANCEMENTS**

**CHANGES**
- WIP Perform default job-level scaling for all jobs, by reading job information from `SLURM_RESUME_FILE`.

**BUG FIXES**

12 changes: 7 additions & 5 deletions src/slurm_plugin/fleet_manager.py
@@ -11,6 +11,7 @@
import contextlib
import copy
import logging
import secrets
import time
from abc import ABC, abstractmethod

@@ -383,15 +384,16 @@ def _get_instances_info(self, instance_ids: list):
instances = []
partial_instance_ids = instance_ids

retry = 4
retries = 5
attempt_count = 0
# Wait for instances to be available in EC2
time.sleep(0.1)
while retry > 0 and partial_instance_ids:
while attempt_count < retries and partial_instance_ids:
complete_instances, partial_instance_ids = self._retrieve_instances_info_from_ec2(partial_instance_ids)
instances.extend(complete_instances)
retry = retry - 1
if retry > 0:
time.sleep(0.3)
attempt_count += 1
if attempt_count < retries:
time.sleep(0.3 * 2**attempt_count + (secrets.randbelow(500) / 1000))

return instances, partial_instance_ids

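The fixed 0.3 s pause between describe_instances retries is replaced by exponential backoff with random jitter, and the retry budget grows from 4 to 5 attempts. A minimal standalone sketch of the resulting wait schedule, reusing the constants from the loop above (the print is illustration only):

import secrets

retries = 5
for attempt_count in range(1, retries):
    # Base delay doubles on every attempt: 0.6 s, 1.2 s, 2.4 s, 4.8 s.
    base_delay = 0.3 * 2**attempt_count
    # Up to ~0.5 s of jitter so concurrent resume calls do not retry in lockstep.
    jitter = secrets.randbelow(500) / 1000
    print(f"attempt {attempt_count}: wait {base_delay + jitter:.3f} s")

With five attempts the cumulative sleep before giving up is roughly 9 to 11 seconds, up from about 1 second with the previous fixed delay.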
87 changes: 64 additions & 23 deletions src/slurm_plugin/instance_manager.py
@@ -9,7 +9,6 @@
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import collections
import itertools
import logging

@@ -72,6 +71,7 @@
run_instances_overrides: dict = None,
create_fleet_overrides: dict = None,
job_level_scaling: bool = False,
temp_jls_for_node_sharing: bool = False,
):
if job_level_scaling:
return JobLevelScalingInstanceManager(
@@ -87,6 +87,7 @@
fleet_config=fleet_config,
run_instances_overrides=run_instances_overrides,
create_fleet_overrides=create_fleet_overrides,
temp_jls_for_node_sharing=temp_jls_for_node_sharing,
)
else:
return NodeListScalingInstanceManager(
@@ -347,7 +348,7 @@
Valid NodeName format: {queue_name}-{st/dy}-{compute_resource_name}-{number}
Sample NodeName: queue1-st-computeres1-2
"""
nodes_to_launch = collections.defaultdict(lambda: collections.defaultdict(list))
nodes_to_launch = defaultdict(lambda: defaultdict(list))
for node in node_list:
try:
queue_name, node_type, compute_resource_name = parse_nodename(node)
@@ -585,6 +586,7 @@
fleet_config: Dict[str, any] = None,
run_instances_overrides: dict = None,
create_fleet_overrides: dict = None,
temp_jls_for_node_sharing: bool = False,
):
super().__init__(
region=region,
@@ -601,11 +603,26 @@
create_fleet_overrides=create_fleet_overrides,
)
self.unused_launched_instances = {}
self.temp_jls_for_node_sharing = temp_jls_for_node_sharing

def _clear_unused_launched_instances(self):
"""Clear and reset unused launched instances list."""
self.unused_launched_instances = {}

def _update_dict(self, target_dict: dict, update: dict) -> dict:
logger.debug("Updating target dict (%s) with update (%s)", target_dict, update)
for update_key, update_value in update.items():
if isinstance(update_value, dict):
target_dict[update_key] = self._update_dict(target_dict.get(update_key, {}), update_value)
elif isinstance(update_value, list):
target_dict[update_key] = target_dict.get(update_key, []) + update_value
elif isinstance(update_value, set):
target_dict[update_key] = target_dict.get(update_key, set()) | update_value
else:
target_dict[update_key] = update_value
logger.debug("Updated target dict is (%s)", target_dict)
return target_dict

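# Merge semantics of _update_dict, shown with hypothetical queue/compute-resource values:
#   target = {"q1": {"cr1": ["node-1"]}}
#   update = {"q1": {"cr1": ["node-2"], "cr2": ["node-3"]}}
#   self._update_dict(target, update)
#   # target is now {"q1": {"cr1": ["node-1", "node-2"], "cr2": ["node-3"]}}
# Nested dicts are merged key by key, lists are concatenated, and sets are unioned;
# a plain dict union (target | update) would instead replace the whole "q1" entry
# and drop the instances already tracked under it.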
def add_instances(
self,
node_list: List[str],
@@ -762,22 +779,30 @@
update_node_address=update_node_address,
)

# node scaling for oversubscribe nodes
self._scaling_for_nodes(
node_list=slurm_resume_data.nodes_oversubscribe,
launch_batch_size=launch_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
)
if not self.temp_jls_for_node_sharing:
# node scaling for oversubscribe nodes
self._scaling_for_nodes(
node_list=list(
dict.fromkeys(
slurm_resume_data.single_node_oversubscribe + slurm_resume_data.multi_node_oversubscribe
)
),
launch_batch_size=launch_batch_size,
update_node_address=update_node_address,
all_or_nothing_batch=all_or_nothing_batch,
)

def _scaling_for_jobs_multi_node(
self, job_list, node_list, launch_batch_size, assign_node_batch_size, terminate_batch_size, update_node_address
):
# Optimize job level scaling with preliminary scale-all nodes attempt
self.unused_launched_instances |= self._launch_instances(
nodes_to_launch=self._parse_requested_nodes(node_list),
launch_batch_size=launch_batch_size,
all_or_nothing_batch=True,
self._update_dict(

self.unused_launched_instances,
self._launch_instances(
nodes_to_launch=self._parse_requested_nodes(node_list),
launch_batch_size=launch_batch_size,
all_or_nothing_batch=True,
),
)

self._scaling_for_jobs(
@@ -795,7 +820,8 @@
SlurmResumeData object contains the following:
* the node list for jobs with oversubscribe != NO
* the node list for jobs with oversubscribe == NO
* the job list with oversubscribe != NO
* the job list with single node allocation with oversubscribe != NO
* the job list with multi node allocation with oversubscribe != NO
* the job list with single node allocation with oversubscribe == NO
* the job list with multi node allocation with oversubscribe == NO

@@ -838,10 +864,12 @@
"""
jobs_single_node_no_oversubscribe = []
jobs_multi_node_no_oversubscribe = []
jobs_oversubscribe = []
jobs_single_node_oversubscribe = []
jobs_multi_node_oversubscribe = []
single_node_no_oversubscribe = []
multi_node_no_oversubscribe = []
nodes_oversubscribe = []
single_node_oversubscribe = []
multi_node_oversubscribe = []

slurm_resume_jobs = self._parse_slurm_resume(slurm_resume)

@@ -854,12 +882,21 @@
jobs_multi_node_no_oversubscribe.append(job)
multi_node_no_oversubscribe.extend(job.nodes_resume)
else:
jobs_oversubscribe.append(job)
nodes_oversubscribe.extend(job.nodes_resume)
if len(job.nodes_resume) == 1:
jobs_single_node_oversubscribe.append(job)
single_node_oversubscribe.extend(job.nodes_resume)
else:
jobs_multi_node_oversubscribe.append(job)
multi_node_oversubscribe.extend(job.nodes_resume)

nodes_difference = list(
set(node_list)
- (set(nodes_oversubscribe) | set(single_node_no_oversubscribe) | set(multi_node_no_oversubscribe))
- (
set(single_node_oversubscribe)
| set(multi_node_oversubscribe)
| set(single_node_no_oversubscribe)
| set(multi_node_no_oversubscribe)
)
)
if nodes_difference:
logger.warning(
@@ -868,8 +905,12 @@
)
self._update_failed_nodes(set(nodes_difference), "InvalidNodenameError")
return SlurmResumeData(
nodes_oversubscribe=list(dict.fromkeys(nodes_oversubscribe)),
jobs_oversubscribe=jobs_oversubscribe,
# With Oversubscribe
single_node_oversubscribe=list(dict.fromkeys(single_node_oversubscribe)),
multi_node_oversubscribe=list(dict.fromkeys(multi_node_oversubscribe)),
jobs_single_node_oversubscribe=jobs_single_node_oversubscribe,
jobs_multi_node_oversubscribe=jobs_multi_node_oversubscribe,
# With No Oversubscribe
single_node_no_oversubscribe=single_node_no_oversubscribe,
multi_node_no_oversubscribe=multi_node_no_oversubscribe,
jobs_single_node_no_oversubscribe=jobs_single_node_no_oversubscribe,
@@ -976,7 +1017,7 @@
]
),
)
self.unused_launched_instances |= instances_launched
self._update_dict(self.unused_launched_instances, instances_launched)
self._update_failed_nodes(set(parsed_requested_node), "LimitedInstanceCapacity", override=False)
else:
# No instances launched at all, e.g. CreateFleet API returns no EC2 instances
@@ -992,7 +1033,7 @@
all_or_nothing_batch: bool,
job: SlurmResumeJob = None,
):
instances_launched = collections.defaultdict(lambda: collections.defaultdict(list))
instances_launched = defaultdict(lambda: defaultdict(list))
for queue, compute_resources in nodes_to_launch.items():
for compute_resource, slurm_node_list in compute_resources.items():
slurm_node_list = self._resize_slurm_node_list(
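For context, a compact sketch of the four-way split that _get_slurm_resume_data now produces. The helper below is illustrative only: it assumes each job exposes a nodes_resume list and a boolean oversubscribe flag, a simplification of the real SlurmResumeJob objects.

from collections import defaultdict

def partition_resume_jobs(jobs):
    """Group resume jobs by allocation size and oversubscribe setting (illustrative sketch)."""
    buckets = defaultdict(list)
    for job in jobs:
        share = "oversubscribe" if job.oversubscribe else "no_oversubscribe"
        size = "single_node" if len(job.nodes_resume) == 1 else "multi_node"
        buckets[f"jobs_{size}_{share}"].append(job)
    return buckets

The matching node lists are deduplicated while preserving order with list(dict.fromkeys(...)), the same pattern _add_instances_for_resume_file uses when it scales the single-node and multi-node oversubscribe lists together.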
5 changes: 5 additions & 0 deletions src/slurm_plugin/resume.py
@@ -47,6 +47,7 @@ class SlurmResumeConfig:
"fleet_config_file": "/etc/parallelcluster/slurm_plugin/fleet-config.json",
"all_or_nothing_batch": False,
"job_level_scaling": True,
"temp_jls_for_node_sharing": False,
}

def __init__(self, config_file_path):
@@ -95,6 +96,9 @@ def _get_config(self, config_file_path):
self.job_level_scaling = config.getboolean(
"slurm_resume", "job_level_scaling", fallback=self.DEFAULTS.get("job_level_scaling")
)
self.temp_jls_for_node_sharing = config.getboolean(
"slurm_resume", "temp_jls_for_node_sharing", fallback=self.DEFAULTS.get("temp_jls_for_node_sharing")
)
fleet_config_file = config.get(
"slurm_resume", "fleet_config_file", fallback=self.DEFAULTS.get("fleet_config_file")
)
@@ -197,6 +201,7 @@ def _resume(arg_nodes, resume_config, slurm_resume):
run_instances_overrides=resume_config.run_instances_overrides,
create_fleet_overrides=resume_config.create_fleet_overrides,
job_level_scaling=resume_config.job_level_scaling,
temp_jls_for_node_sharing=resume_config.temp_jls_for_node_sharing,
)
instance_manager.add_instances(
slurm_resume=slurm_resume,
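The new option is read like the other [slurm_resume] booleans, so clusters that never set it fall back to the default of False. A small sketch of that parsing behaviour with configparser (the inline stanza contents are made up for illustration):

import configparser

config = configparser.ConfigParser()
# Hypothetical slurm_resume stanza; only one key is present on purpose.
config.read_string(
    "[slurm_resume]\n"
    "job_level_scaling = true\n"
)
# Mirrors SlurmResumeConfig._get_config: a missing key falls back to the default.
temp_jls = config.getboolean(
    "slurm_resume", "temp_jls_for_node_sharing", fallback=False
)
print(temp_jls)  # False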
12 changes: 8 additions & 4 deletions src/slurm_plugin/slurm_resources.py
@@ -154,14 +154,18 @@ class SlurmResumeData:
jobs_single_node_no_oversubscribe: List[SlurmResumeJob]
# List of exclusive job allocated to more than 1 node each
jobs_multi_node_no_oversubscribe: List[SlurmResumeJob]
# List of non-exclusive job
jobs_oversubscribe: List[SlurmResumeJob]
# List of non-exclusive job allocated to 1 node each
jobs_single_node_oversubscribe: List[SlurmResumeJob]
# List of non-exclusive job allocated to more than 1 node each
jobs_multi_node_oversubscribe: List[SlurmResumeJob]
# List of node allocated to single node exclusive job
single_node_no_oversubscribe: List[str]
# List of node allocated to multiple node exclusive job
multi_node_no_oversubscribe: List[str]
# List of node allocated to non-exclusive job
nodes_oversubscribe: List[str]
# List of node allocated to single node non-exclusive job
single_node_oversubscribe: List[str]
# List of node allocated to multiple node non-exclusive job
multi_node_oversubscribe: List[str]


class SlurmNode(metaclass=ABCMeta):
4 changes: 2 additions & 2 deletions tests/slurm_plugin/test_fleet_manager.py
@@ -834,7 +834,7 @@ def test_launch_instances(
generate_error=False,
),
]
+ 3
+ 4
* [
MockedBoto3Request(
method="describe_instances",
@@ -887,7 +887,7 @@ def test_launch_instances(
# client error
(
["i-12345"],
4
5
* [
MockedBoto3Request(
method="describe_instances",