diff --git a/lib/charms/hpc_libs/v0/slurm_ops.py b/lib/charms/hpc_libs/v0/slurm_ops.py index 0dab9bf..c41458a 100644 --- a/lib/charms/hpc_libs/v0/slurm_ops.py +++ b/lib/charms/hpc_libs/v0/slurm_ops.py @@ -58,13 +58,13 @@ def _on_install(self, _) -> None: "format_key", "SlurmOpsError", "ServiceType", - "SlurmOpsManager", + "SlurmPackageManager", "ServiceManager", "ConfigurationManager", "MungekeyManager", "MungeManager", "SlurmManagerBase", - "SnapManager", + "SnapPackageManager", ] import json @@ -171,6 +171,11 @@ def active(self) -> bool: """Return True if the service is active.""" ... + @abstractmethod + def type(self) -> ServiceType: + """Return the service type of the managed service.""" + ... + class ConfigurationManager(ABC): """Control configuration of a Slurm component.""" @@ -204,7 +209,7 @@ def unset(self, *keys: str) -> None: class MungekeyManager(ABC): """Control the munge key.""" - def get_key(self) -> str: + def get(self) -> str: """Get the current munge key. Returns: @@ -212,7 +217,7 @@ def get_key(self) -> str: """ ... - def set_key(self, key: str) -> None: + def set(self, key: str) -> None: """Set a new munge key. Args: @@ -220,12 +225,12 @@ def set_key(self, key: str) -> None: """ ... - def generate_key(self) -> None: + def generate(self) -> None: """Generate a new, cryptographically secure munge key.""" ... -class SlurmOpsManager(ABC): +class SlurmPackageManager(ABC): """Manager to control the installation, creation and configuration of Slurm-related services.""" @abstractmethod @@ -257,23 +262,23 @@ def mungekey_manager(self) -> MungekeyManager: class MungeManager: """Manage `munged` service operations.""" - def __init__(self, ops_manager: SlurmOpsManager) -> None: + def __init__(self, ops_manager: SlurmPackageManager) -> None: self.service = ops_manager.service_for(ServiceType.MUNGED) self.config = ops_manager.configuration_manager_for(ServiceType.MUNGED) - self.mungekey = ops_manager.mungekey_manager() + self.key = ops_manager.mungekey_manager() class PrometheusExporterManager: """Manage `slurm-prometheus-exporter` service operations.""" - def __init__(self, ops_manager: SlurmOpsManager) -> None: + def __init__(self, ops_manager: SlurmPackageManager) -> None: self.service = ops_manager.service_for(ServiceType.PROMETHEUS_EXPORTER) class SlurmManagerBase: """Base manager for Slurm services.""" - def __init__(self, service: ServiceType, ops_manager: SlurmOpsManager) -> None: + def __init__(self, service: ServiceType, ops_manager: SlurmPackageManager) -> None: self.service = ops_manager.service_for(service) self.config = ops_manager.configuration_manager_for(service) self.munge = MungeManager(ops_manager) @@ -286,7 +291,7 @@ def hostname(self) -> str: return socket.gethostname().split(".")[0] -class SnapManager(SlurmOpsManager): +class SnapPackageManager(SlurmPackageManager): """Slurm ops manager that uses Snap as its package manager.""" def install(self) -> None: @@ -343,6 +348,10 @@ def active(self) -> bool: # We don't do `"active" in state` because the word "active" is also part of "inactive" :) return "inactive" not in services[f"slurm.{self._service.value}"] + def type(self) -> ServiceType: + """Return the service type of the managed service.""" + return self._service + class _SnapConfigurationManager(ConfigurationManager): """Control configuration of a Slurm component using Snap.""" @@ -382,7 +391,7 @@ def _mungectl(self, *args: str, stdin: Optional[str] = None) -> str: """ return _call("slurm.mungectl", *args, stdin=stdin) - def get_key(self) -> str: + def get(self) -> str: """Get the current munge key. Returns: @@ -390,7 +399,7 @@ def get_key(self) -> str: """ return self._mungectl("key", "get") - def set_key(self, key: str) -> None: + def set(self, key: str) -> None: """Set a new munge key. Args: @@ -398,7 +407,7 @@ def set_key(self, key: str) -> None: """ self._mungectl("key", "set", stdin=key) - def generate_key(self) -> None: + def generate(self) -> None: """Generate a new, cryptographically secure munge key.""" self._mungectl("key", "generate") diff --git a/tests/integration/slurm_ops/test_manager.py b/tests/integration/slurm_ops/test_manager.py index bc77dfc..51f0b33 100644 --- a/tests/integration/slurm_ops/test_manager.py +++ b/tests/integration/slurm_ops/test_manager.py @@ -7,32 +7,32 @@ import pytest import lib.charms.hpc_libs.v0.slurm_ops as slurm -from lib.charms.hpc_libs.v0.slurm_ops import ServiceType, SlurmManagerBase +from lib.charms.hpc_libs.v0.slurm_ops import ServiceType, SlurmManagerBase, SnapPackageManager @pytest.fixture def slurmctld() -> SlurmManagerBase: - return SlurmManagerBase(ServiceType.SLURMCTLD) + return SlurmManagerBase(ServiceType.SLURMCTLD, SnapPackageManager()) @pytest.mark.order(1) def test_install(slurmctld: SlurmManagerBase) -> None: """Install Slurm using the manager.""" - slurm.install() - slurmctld.munge.generate_key() + slurmctld.ops.install() + slurmctld.munge.key.generate() with open("/var/snap/slurm/common/etc/munge/munge.key", "rb") as f: key: str = base64.b64encode(f.read()).decode() - assert key == slurmctld.munge.get_key() + assert key == slurmctld.munge.key.get() @pytest.mark.order(2) def test_rotate_key(slurmctld: SlurmManagerBase) -> None: """Test that the munge key can be rotated.""" - old_key = slurmctld.munge.get_key() - slurmctld.munge.generate_key() - new_key = slurmctld.munge.get_key() + old_key = slurmctld.munge.key.get() + slurmctld.munge.key.generate() + new_key = slurmctld.munge.key.get() assert old_key != new_key @@ -58,14 +58,14 @@ def test_slurm_config(slurmctld: SlurmManagerBase) -> None: @pytest.mark.order(4) def test_enable_service(slurmctld: SlurmManagerBase) -> None: """Test that the slurmctl daemon can be enabled.""" - slurmctld.enable() - assert slurmctld.active() + slurmctld.service.enable() + assert slurmctld.service.active() @pytest.mark.order(5) -def test_version() -> None: +def test_version(slurmctld: SlurmManagerBase) -> None: """Test that the Slurm manager can report its version.""" - version = slurm.version() + version = slurmctld.ops.version() # We are interested in knowing that this does not return a falsy value (`None`, `''`, `[]`, etc.) assert version diff --git a/tests/unit/test_slurm_ops.py b/tests/unit/test_slurm_ops.py index 6d12d77..8fc6435 100644 --- a/tests/unit/test_slurm_ops.py +++ b/tests/unit/test_slurm_ops.py @@ -10,7 +10,12 @@ from unittest.mock import patch import charms.hpc_libs.v0.slurm_ops as slurm -from charms.hpc_libs.v0.slurm_ops import ServiceType, SlurmManagerBase, SlurmOpsError +from charms.hpc_libs.v0.slurm_ops import ( + ServiceType, + SlurmManagerBase, + SlurmOpsError, + SnapPackageManager, +) MUNGEKEY = b"1234567890" MUNGEKEY_BASE64 = base64.b64encode(MUNGEKEY) @@ -60,15 +65,25 @@ @patch("charms.hpc_libs.v0.slurm_ops.subprocess.check_output") class TestSlurmOps(TestCase): - def test_format_key(self, _) -> None: """Test that `kebabize` properly formats slurm keys.""" self.assertEqual(slurm.format_key("CPUs"), "cpus") self.assertEqual(slurm.format_key("AccountingStorageHost"), "accounting-storage-host") + def test_error_message(self, *_) -> None: + """Test that `SlurmOpsError` stores the correct message.""" + message = "error message!" + self.assertEqual(SlurmOpsError(message).message, message) + + +@patch("charms.hpc_libs.v0.slurm_ops.subprocess.check_output") +class TestSnapPackageManager(TestCase): + def setUp(self): + self.manager = SnapPackageManager() + def test_install(self, subcmd) -> None: """Test that `slurm_ops` calls the correct install command.""" - slurm.install() + self.manager.install() args = subcmd.call_args[0][0] self.assertEqual(args[:3], ["snap", "install", "slurm"]) self.assertIn("--classic", args[3:]) # codespell:ignore @@ -76,7 +91,7 @@ def test_install(self, subcmd) -> None: def test_version(self, subcmd) -> None: """Test that `slurm_ops` gets the correct version using the correct command.""" subcmd.return_value = SLURM_INFO.encode() - version = slurm.version() + version = self.manager.version() args = subcmd.call_args[0][0] self.assertEqual(args, ["snap", "info", "slurm"]) self.assertEqual(version, "23.11.7") @@ -85,7 +100,7 @@ def test_version_not_installed(self, subcmd) -> None: """Test that `slurm_ops` throws when getting the installed version if the slurm snap is not installed.""" subcmd.return_value = SLURM_INFO_NOT_INSTALLED.encode() with self.assertRaises(slurm.SlurmOpsError): - slurm.version() + self.manager.version() args = subcmd.call_args[0][0] self.assertEqual(args, ["snap", "info", "slurm"]) @@ -93,12 +108,7 @@ def test_call_error(self, subcmd) -> None: """Test that `slurm_ops` propagates errors when a command fails.""" subcmd.side_effect = subprocess.CalledProcessError(-1, cmd=[""], stderr=b"error") with self.assertRaises(slurm.SlurmOpsError): - slurm.install() - - def test_error_message(self, *_) -> None: - """Test that `SlurmOpsError` stores the correct message.""" - message = "error message!" - self.assertEqual(SlurmOpsError(message).message, message) + self.manager.install() @patch("charms.hpc_libs.v0.slurm_ops.subprocess.check_output") @@ -107,43 +117,43 @@ class SlurmOpsBase: def test_config_name(self, *_) -> None: """Test that the config name is correctly set.""" - self.assertEqual(self.manager._service.config_name, self.config_name) + self.assertEqual(self.manager.service.type().config_name, self.config_name) def test_enable(self, subcmd, *_) -> None: """Test that the manager calls the correct enable command.""" - self.manager.enable() + self.manager.service.enable() args = subcmd.call_args[0][0] self.assertEqual( - args, ["snap", "start", "--enable", f"slurm.{self.manager._service.value}"] + args, ["snap", "start", "--enable", f"slurm.{self.manager.service.type().value}"] ) def test_disable(self, subcmd, *_) -> None: """Test that the manager calls the correct disable command.""" - self.manager.disable() + self.manager.service.disable() args = subcmd.call_args[0][0] self.assertEqual( - args, ["snap", "stop", "--disable", f"slurm.{self.manager._service.value}"] + args, ["snap", "stop", "--disable", f"slurm.{self.manager.service.type().value}"] ) def test_restart(self, subcmd, *_) -> None: """Test that the manager calls the correct restart command.""" - self.manager.restart() + self.manager.service.restart() args = subcmd.call_args[0][0] - self.assertEqual(args, ["snap", "restart", f"slurm.{self.manager._service.value}"]) + self.assertEqual(args, ["snap", "restart", f"slurm.{self.manager.service.type().value}"]) def test_active(self, subcmd, *_) -> None: """Test that the manager can detect that a service is active.""" subcmd.return_value = SLURM_INFO.encode() - self.assertTrue(self.manager.active()) + self.assertTrue(self.manager.service.active()) def test_active_not_installed(self, subcmd, *_) -> None: """Test that the manager throws an error when calling `active` if the snap is not installed.""" subcmd.return_value = SLURM_INFO_NOT_INSTALLED.encode() with self.assertRaises(slurm.SlurmOpsError): - self.manager.active() + self.manager.service.active() args = subcmd.call_args[0][0] self.assertEqual(args, ["snap", "info", "slurm"]) @@ -194,13 +204,13 @@ def test_unset_config_all(self, subcmd) -> None: def test_generate_munge_key(self, subcmd, *_) -> None: """Test that the manager calls the correct `mungectl` command.""" - self.manager.munge.generate_key() + self.manager.munge.key.generate() args = subcmd.call_args[0][0] self.assertEqual(args, ["slurm.mungectl", "key", "generate"]) def test_set_munge_key(self, subcmd, *_) -> None: """Test that the manager sets the munge key with the correct command.""" - self.manager.munge.set_key(MUNGEKEY_BASE64) + self.manager.munge.key.set(MUNGEKEY_BASE64) args = subcmd.call_args[0][0] # MUNGEKEY_BASE64 is piped to `stdin` to avoid exposure. self.assertEqual(args, ["slurm.mungectl", "key", "set"]) @@ -208,7 +218,7 @@ def test_set_munge_key(self, subcmd, *_) -> None: def test_get_munge_key(self, subcmd, *_) -> None: """Test that the manager gets the munge key with the correct command.""" subcmd.return_value = MUNGEKEY_BASE64 - key = self.manager.munge.get_key() + key = self.manager.munge.key.get() args = subcmd.call_args[0][0] self.assertEqual(args, ["slurm.mungectl", "key", "get"]) self.assertEqual(key, MUNGEKEY_BASE64) @@ -229,14 +239,14 @@ def test_hostname(self, gethostname, *_) -> None: parameters = [ - (SlurmManagerBase(ServiceType.SLURMCTLD), "slurm"), - (SlurmManagerBase(ServiceType.SLURMD), "slurmd"), - (SlurmManagerBase(ServiceType.SLURMDBD), "slurmdbd"), - (SlurmManagerBase(ServiceType.SLURMRESTD), "slurmrestd"), + (SlurmManagerBase(ServiceType.SLURMCTLD, SnapPackageManager()), "slurm"), + (SlurmManagerBase(ServiceType.SLURMD, SnapPackageManager()), "slurmd"), + (SlurmManagerBase(ServiceType.SLURMDBD, SnapPackageManager()), "slurmdbd"), + (SlurmManagerBase(ServiceType.SLURMRESTD, SnapPackageManager()), "slurmrestd"), ] for manager, config_name in parameters: - cls_name = f"Test{manager._service.value.capitalize()}Ops" + cls_name = f"Test{manager.service.type().name.capitalize()}Ops" globals()[cls_name] = type( cls_name, (SlurmOpsBase, TestCase),