diff --git a/src/slurm_plugin/instance_manager.py b/src/slurm_plugin/instance_manager.py index c7ee575ef..26846b3bc 100644 --- a/src/slurm_plugin/instance_manager.py +++ b/src/slurm_plugin/instance_manager.py @@ -9,7 +9,6 @@ # OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and # limitations under the License. -import collections import itertools import logging @@ -349,7 +348,7 @@ def _parse_requested_nodes(self, node_list: List[str]) -> defaultdict[str, defau Valid NodeName format: {queue_name}-{st/dy}-{compute_resource_name}-{number} Sample NodeName: queue1-st-computeres1-2 """ - nodes_to_launch = collections.defaultdict(lambda: collections.defaultdict(list)) + nodes_to_launch = defaultdict(lambda: defaultdict(list)) for node in node_list: try: queue_name, node_type, compute_resource_name = parse_nodename(node) @@ -610,6 +609,20 @@ def _clear_unused_launched_instances(self): """Clear and reset unused launched instances list.""" self.unused_launched_instances = {} + def _update_dict(self, target_dict: dict, update: dict) -> dict: + logger.debug("Updating target dict (%s) with update (%s)", target_dict, update) + for update_key, update_value in update.items(): + if isinstance(update_value, dict): + target_dict[update_key] = self._update_dict(target_dict.get(update_key, {}), update_value) + elif isinstance(update_value, list): + target_dict[update_key] = target_dict.get(update_key, []) + update_value + elif isinstance(update_value, set): + target_dict[update_key] = target_dict.get(update_key, set()) | update_value + else: + target_dict[update_key] = update_value + logger.debug("Updated target dict is (%s)", target_dict) + return target_dict + def add_instances( self, node_list: List[str], @@ -783,10 +796,13 @@ def _scaling_for_jobs_multi_node( self, job_list, node_list, launch_batch_size, assign_node_batch_size, terminate_batch_size, update_node_address ): # Optimize job level scaling with preliminary scale-all nodes attempt - self.unused_launched_instances |= self._launch_instances( - nodes_to_launch=self._parse_requested_nodes(node_list), - launch_batch_size=launch_batch_size, - all_or_nothing_batch=True, + self._update_dict( + self.unused_launched_instances, + self._launch_instances( + nodes_to_launch=self._parse_requested_nodes(node_list), + launch_batch_size=launch_batch_size, + all_or_nothing_batch=True, + ), ) self._scaling_for_jobs( @@ -1001,7 +1017,7 @@ def _add_instances_for_job( ] ), ) - self.unused_launched_instances |= instances_launched + self._update_dict(self.unused_launched_instances, instances_launched) self._update_failed_nodes(set(parsed_requested_node), "LimitedInstanceCapacity", override=False) else: # No instances launched at all, e.g. CreateFleet API returns no EC2 instances @@ -1017,7 +1033,7 @@ def _launch_instances( all_or_nothing_batch: bool, job: SlurmResumeJob = None, ): - instances_launched = collections.defaultdict(lambda: collections.defaultdict(list)) + instances_launched = defaultdict(lambda: defaultdict(list)) for queue, compute_resources in nodes_to_launch.items(): for compute_resource, slurm_node_list in compute_resources.items(): slurm_node_list = self._resize_slurm_node_list( diff --git a/tests/slurm_plugin/test_instance_manager.py b/tests/slurm_plugin/test_instance_manager.py index be8fea8ec..075476d99 100644 --- a/tests/slurm_plugin/test_instance_manager.py +++ b/tests/slurm_plugin/test_instance_manager.py @@ -13,6 +13,7 @@ import os import re import subprocess +from collections import defaultdict from datetime import datetime, timezone from typing import Iterable from unittest.mock import call @@ -2437,7 +2438,8 @@ def test_update_slurm_node_addrs( @pytest.mark.parametrize( "job, launch_batch_size, assign_node_batch_size, update_node_address, all_or_nothing_batch, " - "expected_nodes_to_launch, mock_instances_launched, expect_assign_instances_to_nodes_called, " + "expected_nodes_to_launch, mock_instances_launched, initial_unused_launched_instances, " + "expected_unused_launched_instances, expect_assign_instances_to_nodes_called, " "expect_assign_instances_to_nodes_failure, expected_failed_nodes", [ ( @@ -2448,6 +2450,8 @@ def test_update_slurm_node_addrs( False, {}, {}, + {}, + {}, False, None, {}, @@ -2460,6 +2464,8 @@ def test_update_slurm_node_addrs( True, {"queue4": {"c5xlarge": ["queue4-st-c5xlarge-1"]}}, {}, + {}, + {}, False, None, {"InsufficientInstanceCapacity": {"queue4-st-c5xlarge-1"}}, @@ -2480,6 +2486,8 @@ def test_update_slurm_node_addrs( ] } }, + {}, + {}, True, None, {}, @@ -2500,6 +2508,8 @@ def test_update_slurm_node_addrs( ] } }, + {}, + {}, True, HostnameDnsStoreError(), {"Exception": {"queue4-st-c5xlarge-1"}}, @@ -2525,6 +2535,16 @@ def test_update_slurm_node_addrs( ] } }, + {}, + { + "queue4": { + "c5xlarge": [ + EC2Instance( + "i-12345", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ] + } + }, False, None, {"LimitedInstanceCapacity": {"queue1-st-c5xlarge-1", "queue4-st-c5xlarge-1"}}, @@ -2542,6 +2562,8 @@ def test_update_slurm_node_addrs( True, {"queue1": {"c5xlarge": ["queue1-st-c5xlarge-1"]}, "queue4": {"c5xlarge": ["queue4-st-c5xlarge-1"]}}, {}, + {}, + {}, False, None, {"InsufficientInstanceCapacity": {"queue1-st-c5xlarge-1", "queue4-st-c5xlarge-1"}}, @@ -2559,6 +2581,8 @@ def test_update_slurm_node_addrs( True, {"queue1": {"c5xlarge": ["queue1-st-c5xlarge-1"]}}, {}, + {}, + {}, False, None, {"InsufficientInstanceCapacity": {"queue1-st-c5xlarge-1"}}, @@ -2584,10 +2608,108 @@ def test_update_slurm_node_addrs( ] } }, + {}, + {}, True, None, {}, ), + ( + SlurmResumeJob( + 140819, + "queue1-st-c5xlarge-1, queue4-st-c5xlarge-1", + "queue1-st-c5xlarge-1, queue4-st-c5xlarge-1", + "NO", + ), + 1, + 2, + False, + True, + {"queue1": {"c5xlarge": ["queue1-st-c5xlarge-1"]}, "queue4": {"c5xlarge": ["queue4-st-c5xlarge-1"]}}, + { + "queue4": { + "c5xlarge": [ + EC2Instance( + "i-12345", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ] + } + }, + { + "queue10": { + "c5xlarge": [ + EC2Instance( + "i-12345", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ] + } + }, + { + "queue4": { + "c5xlarge": [ + EC2Instance( + "i-12345", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ] + }, + "queue10": { + "c5xlarge": [ + EC2Instance( + "i-12345", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ] + }, + }, + False, + None, + {"LimitedInstanceCapacity": {"queue1-st-c5xlarge-1", "queue4-st-c5xlarge-1"}}, + ), + ( + SlurmResumeJob( + 140819, + "queue1-st-c5xlarge-1, queue4-st-c5xlarge-1", + "queue1-st-c5xlarge-1, queue4-st-c5xlarge-1", + "NO", + ), + 1, + 2, + False, + True, + {"queue1": {"c5xlarge": ["queue1-st-c5xlarge-1"]}, "queue4": {"c5xlarge": ["queue4-st-c5xlarge-1"]}}, + { + "queue4": { + "c5xlarge": [ + EC2Instance( + "i-12345", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ] + } + }, + { + "queue4": { + "c5xlarge": [ + EC2Instance( + "i-12346", "ip.1.0.0.6", "ip-1-0-0-6", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ] + } + }, + { + "queue4": { + "c5xlarge": [ + EC2Instance( + "i-12346", "ip.1.0.0.6", "ip-1-0-0-6", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + EC2Instance( + "i-12345", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ] + } + }, + False, + None, + {"LimitedInstanceCapacity": {"queue1-st-c5xlarge-1", "queue4-st-c5xlarge-1"}}, + ), ], ) def test_add_instances_for_job( @@ -2601,15 +2723,18 @@ def test_add_instances_for_job( all_or_nothing_batch, expected_nodes_to_launch, mock_instances_launched, + initial_unused_launched_instances, + expected_unused_launched_instances, expect_assign_instances_to_nodes_called, expect_assign_instances_to_nodes_failure, expected_failed_nodes, ): - # patch internal functions + # patch internal functions and data instance_manager._launch_instances = mocker.MagicMock(return_value=mock_instances_launched) instance_manager._assign_instances_to_nodes = mocker.MagicMock( side_effect=expect_assign_instances_to_nodes_failure ) + instance_manager.unused_launched_instances = initial_unused_launched_instances instance_manager._add_instances_for_job( job, launch_batch_size, assign_node_batch_size, update_node_address, all_or_nothing_batch @@ -2626,16 +2751,15 @@ def test_add_instances_for_job( all_or_nothing_batch=all_or_nothing_batch, ) + assert_that(instance_manager.unused_launched_instances).is_equal_to(expected_unused_launched_instances) if expect_assign_instances_to_nodes_called: instance_manager._assign_instances_to_nodes.assert_called_once() - assert_that(instance_manager.unused_launched_instances).is_empty() if expect_assign_instances_to_nodes_failure: assert_that(instance_manager.failed_nodes).is_equal_to(expected_failed_nodes) else: assert_that(instance_manager.failed_nodes).is_empty() else: instance_manager._assign_instances_to_nodes.assert_not_called() - assert_that(instance_manager.unused_launched_instances).is_equal_to(mock_instances_launched) assert_that(instance_manager.failed_nodes).is_equal_to(expected_failed_nodes) @pytest.mark.parametrize( @@ -3343,7 +3467,6 @@ def test_scaling_for_nodes( ) def test_resize_slurm_node_list( self, - mocker, instance_manager, queue, compute_resource, @@ -3367,6 +3490,142 @@ def test_resize_slurm_node_list( assert_that(new_slurm_node_list).is_equal_to(expected_slurm_node_list) assert_that(instances_launched).is_equal_to(expected_instances_launched) + @pytest.mark.parametrize( + "target_dict, update, expected_dict", + [ + ( + {}, + {}, + {}, + ), + ( + {"q1": {"c1": ["a", "b"]}}, + {}, + {"q1": {"c1": ["a", "b"]}}, + ), + ( + {"q1": {"c1": ["a", "b"]}}, + {"q1": {"c1": ["c"]}}, + {"q1": {"c1": ["a", "b", "c"]}}, + ), + ( + {}, + {"q1": {"c1": ["a", "b"]}}, + {"q1": {"c1": ["a", "b"]}}, + ), + ( + {"q1": {"c1": {"a", "b"}}}, + {}, + {"q1": {"c1": {"a", "b"}}}, + ), + ( + {"q1": {"c1": {"a", "b"}}}, + {"q1": {"c1": {"c"}}}, + {"q1": {"c1": {"a", "b", "c"}}}, + ), + ( + {}, + {"q1": {"c1": {"a", "b"}}}, + {"q1": {"c1": {"a", "b"}}}, + ), + ( + {"q1": {"c1": 1}}, + {}, + {"q1": {"c1": 1}}, + ), + ( + {"q1": {"c1": 1}}, + {"q1": {"c1": 2}}, + {"q1": {"c1": 2}}, + ), + ( + {}, + {"q1": {"c1": 3}}, + {"q1": {"c1": 3}}, + ), + ( + {"q1": {"c1": ["a", "b"], "c2": ["c"]}, "q2": {"c1": ["d"]}}, + {"q2": {"c1": ["k"]}, "q3": {"c1": ["y"]}}, + {"q1": {"c1": ["a", "b"], "c2": ["c"]}, "q2": {"c1": ["d", "k"]}, "q3": {"c1": ["y"]}}, + ), + ( + defaultdict(lambda: defaultdict(list)), + defaultdict(lambda: defaultdict(list)), + defaultdict(lambda: defaultdict(list)), + ), + ( + defaultdict(lambda: defaultdict(list)), + {"q1": {"c1": ["a", "b"]}}, + {"q1": {"c1": ["a", "b"]}}, + ), + ( + {"q1": {"c1": ["a", "b", "c"]}}, + defaultdict(lambda: defaultdict(list)), + {"q1": {"c1": ["a", "b", "c"]}}, + ), + ( + { + "q1": { + "c1": [ + EC2Instance("q1c1-1", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc)) + ], + "c2": [ + EC2Instance("q1c2-1", "ip.1.0.0.2", "ip-1-0-0-2", datetime(2020, 1, 1, tzinfo=timezone.utc)) + ], + } + }, + {}, + { + "q1": { + "c1": [ + EC2Instance("q1c1-1", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc)) + ], + "c2": [ + EC2Instance("q1c2-1", "ip.1.0.0.2", "ip-1-0-0-2", datetime(2020, 1, 1, tzinfo=timezone.utc)) + ], + } + }, + ), + ( + { + "q1": { + "c1": [ + EC2Instance("q1c1-1", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc)) + ], + "c2": [ + EC2Instance("q1c2-1", "ip.1.0.0.2", "ip-1-0-0-2", datetime(2020, 1, 1, tzinfo=timezone.utc)) + ], + } + }, + { + "q1": { + "c2": [ + EC2Instance("q1c2-2", "ip.1.0.0.3", "ip-1-0-0-3", datetime(2020, 1, 1, tzinfo=timezone.utc)) + ] + } + }, + { + "q1": { + "c1": [ + EC2Instance("q1c1-1", "ip.1.0.0.1", "ip-1-0-0-1", datetime(2020, 1, 1, tzinfo=timezone.utc)) + ], + "c2": [ + EC2Instance( + "q1c2-1", "ip.1.0.0.2", "ip-1-0-0-2", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + EC2Instance( + "q1c2-2", "ip.1.0.0.3", "ip-1-0-0-3", datetime(2020, 1, 1, tzinfo=timezone.utc) + ), + ], + } + }, + ), + ], + ) + def test_update_dict(self, instance_manager, target_dict, update, expected_dict): + actual_dict = instance_manager._update_dict(target_dict, update) + assert_that(actual_dict).is_equal_to(expected_dict) + class TestNodeListScalingInstanceManager: @pytest.fixture