Skip to content

Commit

Permalink
Rewrite tests to new design
Browse files Browse the repository at this point in the history
  • Loading branch information
jedel1043 committed Aug 20, 2024
1 parent 4f6fd43 commit 336e2b8
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 54 deletions.
37 changes: 23 additions & 14 deletions lib/charms/hpc_libs/v0/slurm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def _on_install(self, _) -> None:
"format_key",
"SlurmOpsError",
"ServiceType",
"SlurmOpsManager",
"SlurmPackageManager",
"ServiceManager",
"ConfigurationManager",
"MungekeyManager",
"MungeManager",
"SlurmManagerBase",
"SnapManager",
"SnapPackageManager",
]

import json
Expand Down Expand Up @@ -171,6 +171,11 @@ def active(self) -> bool:
"""Return True if the service is active."""
...

@abstractmethod
def type(self) -> ServiceType:
"""Return the service type of the managed service."""
...


class ConfigurationManager(ABC):
"""Control configuration of a Slurm component."""
Expand Down Expand Up @@ -204,28 +209,28 @@ def unset(self, *keys: str) -> None:
class MungekeyManager(ABC):
"""Control the munge key."""

def get_key(self) -> str:
def get(self) -> str:
"""Get the current munge key.
Returns:
The current munge key as a base64-encoded string.
"""
...

def set_key(self, key: str) -> None:
def set(self, key: str) -> None:
"""Set a new munge key.
Args:
key: A new, base64-encoded munge key.
"""
...

def generate_key(self) -> None:
def generate(self) -> None:
"""Generate a new, cryptographically secure munge key."""
...


class SlurmOpsManager(ABC):
class SlurmPackageManager(ABC):
"""Manager to control the installation, creation and configuration of Slurm-related services."""

@abstractmethod
Expand Down Expand Up @@ -257,23 +262,23 @@ def mungekey_manager(self) -> MungekeyManager:
class MungeManager:
"""Manage `munged` service operations."""

def __init__(self, ops_manager: SlurmOpsManager) -> None:
def __init__(self, ops_manager: SlurmPackageManager) -> None:
self.service = ops_manager.service_for(ServiceType.MUNGED)
self.config = ops_manager.configuration_manager_for(ServiceType.MUNGED)
self.mungekey = ops_manager.mungekey_manager()
self.key = ops_manager.mungekey_manager()


class PrometheusExporterManager:
"""Manage `slurm-prometheus-exporter` service operations."""

def __init__(self, ops_manager: SlurmOpsManager) -> None:
def __init__(self, ops_manager: SlurmPackageManager) -> None:
self.service = ops_manager.service_for(ServiceType.PROMETHEUS_EXPORTER)


class SlurmManagerBase:
"""Base manager for Slurm services."""

def __init__(self, service: ServiceType, ops_manager: SlurmOpsManager) -> None:
def __init__(self, service: ServiceType, ops_manager: SlurmPackageManager) -> None:
self.service = ops_manager.service_for(service)
self.config = ops_manager.configuration_manager_for(service)
self.munge = MungeManager(ops_manager)
Expand All @@ -286,7 +291,7 @@ def hostname(self) -> str:
return socket.gethostname().split(".")[0]


class SnapManager(SlurmOpsManager):
class SnapPackageManager(SlurmPackageManager):
"""Slurm ops manager that uses Snap as its package manager."""

def install(self) -> None:
Expand Down Expand Up @@ -343,6 +348,10 @@ def active(self) -> bool:
# We don't do `"active" in state` because the word "active" is also part of "inactive" :)
return "inactive" not in services[f"slurm.{self._service.value}"]

def type(self) -> ServiceType:
"""Return the service type of the managed service."""
return self._service


class _SnapConfigurationManager(ConfigurationManager):
"""Control configuration of a Slurm component using Snap."""
Expand Down Expand Up @@ -382,23 +391,23 @@ def _mungectl(self, *args: str, stdin: Optional[str] = None) -> str:
"""
return _call("slurm.mungectl", *args, stdin=stdin)

def get_key(self) -> str:
def get(self) -> str:
"""Get the current munge key.
Returns:
The current munge key as a base64-encoded string.
"""
return self._mungectl("key", "get")

def set_key(self, key: str) -> None:
def set(self, key: str) -> None:
"""Set a new munge key.
Args:
key: A new, base64-encoded munge key.
"""
self._mungectl("key", "set", stdin=key)

def generate_key(self) -> None:
def generate(self) -> None:
"""Generate a new, cryptographically secure munge key."""
self._mungectl("key", "generate")

Expand Down
24 changes: 12 additions & 12 deletions tests/integration/slurm_ops/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,32 @@
import pytest

import lib.charms.hpc_libs.v0.slurm_ops as slurm
from lib.charms.hpc_libs.v0.slurm_ops import ServiceType, SlurmManagerBase
from lib.charms.hpc_libs.v0.slurm_ops import ServiceType, SlurmManagerBase, SnapPackageManager


@pytest.fixture
def slurmctld() -> SlurmManagerBase:
return SlurmManagerBase(ServiceType.SLURMCTLD)
return SlurmManagerBase(ServiceType.SLURMCTLD, SnapPackageManager())


@pytest.mark.order(1)
def test_install(slurmctld: SlurmManagerBase) -> None:
"""Install Slurm using the manager."""
slurm.install()
slurmctld.munge.generate_key()
slurmctld.ops.install()
slurmctld.munge.key.generate()

with open("/var/snap/slurm/common/etc/munge/munge.key", "rb") as f:
key: str = base64.b64encode(f.read()).decode()

assert key == slurmctld.munge.get_key()
assert key == slurmctld.munge.key.get()


@pytest.mark.order(2)
def test_rotate_key(slurmctld: SlurmManagerBase) -> None:
"""Test that the munge key can be rotated."""
old_key = slurmctld.munge.get_key()
slurmctld.munge.generate_key()
new_key = slurmctld.munge.get_key()
old_key = slurmctld.munge.key.get()
slurmctld.munge.key.generate()
new_key = slurmctld.munge.key.get()
assert old_key != new_key


Expand All @@ -58,14 +58,14 @@ def test_slurm_config(slurmctld: SlurmManagerBase) -> None:
@pytest.mark.order(4)
def test_enable_service(slurmctld: SlurmManagerBase) -> None:
"""Test that the slurmctl daemon can be enabled."""
slurmctld.enable()
assert slurmctld.active()
slurmctld.service.enable()
assert slurmctld.service.active()


@pytest.mark.order(5)
def test_version() -> None:
def test_version(slurmctld: SlurmManagerBase) -> None:
"""Test that the Slurm manager can report its version."""
version = slurm.version()
version = slurmctld.ops.version()

# We are interested in knowing that this does not return a falsy value (`None`, `''`, `[]`, etc.)
assert version
66 changes: 38 additions & 28 deletions tests/unit/test_slurm_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from unittest.mock import patch

import charms.hpc_libs.v0.slurm_ops as slurm
from charms.hpc_libs.v0.slurm_ops import ServiceType, SlurmManagerBase, SlurmOpsError
from charms.hpc_libs.v0.slurm_ops import (
ServiceType,
SlurmManagerBase,
SlurmOpsError,
SnapPackageManager,
)

MUNGEKEY = b"1234567890"
MUNGEKEY_BASE64 = base64.b64encode(MUNGEKEY)
Expand Down Expand Up @@ -60,23 +65,33 @@

@patch("charms.hpc_libs.v0.slurm_ops.subprocess.check_output")
class TestSlurmOps(TestCase):

def test_format_key(self, _) -> None:
"""Test that `kebabize` properly formats slurm keys."""
self.assertEqual(slurm.format_key("CPUs"), "cpus")
self.assertEqual(slurm.format_key("AccountingStorageHost"), "accounting-storage-host")

def test_error_message(self, *_) -> None:
"""Test that `SlurmOpsError` stores the correct message."""
message = "error message!"
self.assertEqual(SlurmOpsError(message).message, message)


@patch("charms.hpc_libs.v0.slurm_ops.subprocess.check_output")
class TestSnapPackageManager(TestCase):
def setUp(self):
self.manager = SnapPackageManager()

def test_install(self, subcmd) -> None:
"""Test that `slurm_ops` calls the correct install command."""
slurm.install()
self.manager.install()
args = subcmd.call_args[0][0]
self.assertEqual(args[:3], ["snap", "install", "slurm"])
self.assertIn("--classic", args[3:]) # codespell:ignore

def test_version(self, subcmd) -> None:
"""Test that `slurm_ops` gets the correct version using the correct command."""
subcmd.return_value = SLURM_INFO.encode()
version = slurm.version()
version = self.manager.version()
args = subcmd.call_args[0][0]
self.assertEqual(args, ["snap", "info", "slurm"])
self.assertEqual(version, "23.11.7")
Expand All @@ -85,20 +100,15 @@ def test_version_not_installed(self, subcmd) -> None:
"""Test that `slurm_ops` throws when getting the installed version if the slurm snap is not installed."""
subcmd.return_value = SLURM_INFO_NOT_INSTALLED.encode()
with self.assertRaises(slurm.SlurmOpsError):
slurm.version()
self.manager.version()
args = subcmd.call_args[0][0]
self.assertEqual(args, ["snap", "info", "slurm"])

def test_call_error(self, subcmd) -> None:
"""Test that `slurm_ops` propagates errors when a command fails."""
subcmd.side_effect = subprocess.CalledProcessError(-1, cmd=[""], stderr=b"error")
with self.assertRaises(slurm.SlurmOpsError):
slurm.install()

def test_error_message(self, *_) -> None:
"""Test that `SlurmOpsError` stores the correct message."""
message = "error message!"
self.assertEqual(SlurmOpsError(message).message, message)
self.manager.install()


@patch("charms.hpc_libs.v0.slurm_ops.subprocess.check_output")
Expand All @@ -107,43 +117,43 @@ class SlurmOpsBase:

def test_config_name(self, *_) -> None:
"""Test that the config name is correctly set."""
self.assertEqual(self.manager._service.config_name, self.config_name)
self.assertEqual(self.manager.service.type().config_name, self.config_name)

def test_enable(self, subcmd, *_) -> None:
"""Test that the manager calls the correct enable command."""
self.manager.enable()
self.manager.service.enable()

args = subcmd.call_args[0][0]
self.assertEqual(
args, ["snap", "start", "--enable", f"slurm.{self.manager._service.value}"]
args, ["snap", "start", "--enable", f"slurm.{self.manager.service.type().value}"]
)

def test_disable(self, subcmd, *_) -> None:
"""Test that the manager calls the correct disable command."""
self.manager.disable()
self.manager.service.disable()

args = subcmd.call_args[0][0]
self.assertEqual(
args, ["snap", "stop", "--disable", f"slurm.{self.manager._service.value}"]
args, ["snap", "stop", "--disable", f"slurm.{self.manager.service.type().value}"]
)

def test_restart(self, subcmd, *_) -> None:
"""Test that the manager calls the correct restart command."""
self.manager.restart()
self.manager.service.restart()

args = subcmd.call_args[0][0]
self.assertEqual(args, ["snap", "restart", f"slurm.{self.manager._service.value}"])
self.assertEqual(args, ["snap", "restart", f"slurm.{self.manager.service.type().value}"])

def test_active(self, subcmd, *_) -> None:
"""Test that the manager can detect that a service is active."""
subcmd.return_value = SLURM_INFO.encode()
self.assertTrue(self.manager.active())
self.assertTrue(self.manager.service.active())

def test_active_not_installed(self, subcmd, *_) -> None:
"""Test that the manager throws an error when calling `active` if the snap is not installed."""
subcmd.return_value = SLURM_INFO_NOT_INSTALLED.encode()
with self.assertRaises(slurm.SlurmOpsError):
self.manager.active()
self.manager.service.active()
args = subcmd.call_args[0][0]
self.assertEqual(args, ["snap", "info", "slurm"])

Expand Down Expand Up @@ -194,21 +204,21 @@ def test_unset_config_all(self, subcmd) -> None:

def test_generate_munge_key(self, subcmd, *_) -> None:
"""Test that the manager calls the correct `mungectl` command."""
self.manager.munge.generate_key()
self.manager.munge.key.generate()
args = subcmd.call_args[0][0]
self.assertEqual(args, ["slurm.mungectl", "key", "generate"])

def test_set_munge_key(self, subcmd, *_) -> None:
"""Test that the manager sets the munge key with the correct command."""
self.manager.munge.set_key(MUNGEKEY_BASE64)
self.manager.munge.key.set(MUNGEKEY_BASE64)
args = subcmd.call_args[0][0]
# MUNGEKEY_BASE64 is piped to `stdin` to avoid exposure.
self.assertEqual(args, ["slurm.mungectl", "key", "set"])

def test_get_munge_key(self, subcmd, *_) -> None:
"""Test that the manager gets the munge key with the correct command."""
subcmd.return_value = MUNGEKEY_BASE64
key = self.manager.munge.get_key()
key = self.manager.munge.key.get()
args = subcmd.call_args[0][0]
self.assertEqual(args, ["slurm.mungectl", "key", "get"])
self.assertEqual(key, MUNGEKEY_BASE64)
Expand All @@ -229,14 +239,14 @@ def test_hostname(self, gethostname, *_) -> None:


parameters = [
(SlurmManagerBase(ServiceType.SLURMCTLD), "slurm"),
(SlurmManagerBase(ServiceType.SLURMD), "slurmd"),
(SlurmManagerBase(ServiceType.SLURMDBD), "slurmdbd"),
(SlurmManagerBase(ServiceType.SLURMRESTD), "slurmrestd"),
(SlurmManagerBase(ServiceType.SLURMCTLD, SnapPackageManager()), "slurm"),
(SlurmManagerBase(ServiceType.SLURMD, SnapPackageManager()), "slurmd"),
(SlurmManagerBase(ServiceType.SLURMDBD, SnapPackageManager()), "slurmdbd"),
(SlurmManagerBase(ServiceType.SLURMRESTD, SnapPackageManager()), "slurmrestd"),
]

for manager, config_name in parameters:
cls_name = f"Test{manager._service.value.capitalize()}Ops"
cls_name = f"Test{manager.service.type().name.capitalize()}Ops"
globals()[cls_name] = type(
cls_name,
(SlurmOpsBase, TestCase),
Expand Down

0 comments on commit 336e2b8

Please sign in to comment.