diff --git a/lib/charms/hpc_libs/v0/slurm_ops.py b/lib/charms/hpc_libs/v0/slurm_ops.py index 949125f..5247f36 100644 --- a/lib/charms/hpc_libs/v0/slurm_ops.py +++ b/lib/charms/hpc_libs/v0/slurm_ops.py @@ -486,7 +486,8 @@ def edit(cls: Type["_SlurmConfig"], file: Union[str, os.PathLike]) -> "_SlurmCon Args: file: Path to configuration file to edit. """ - return cls(slurmconfig.edit(file)) + with slurmconfig.edit(file) as config: + yield cls(config) @classmethod def from_dict(cls, data: Dict[str, Any]): @@ -534,7 +535,8 @@ def edit(cls: Type["_SlurmdbdConfig"], file: Union[str, os.PathLike]) -> "_Slurm Args: file: Path to configuration file to edit. """ - return cls(slurmdbdconfig.edit(file)) + with slurmdbdconfig.edit(file) as config: + yield cls(config) @classmethod def from_dict(cls, data: Dict[str, Any]): @@ -559,9 +561,6 @@ def update(self, other: "_SlurmdbdConfig") -> None: return self._inner.update(other._inner) -T = TypeVar("T", bound=_SlurmBaseConfigManager) - - class _SnapFsConfigManager(ConfigManager): """Control configuration of a Slurm component using a file.""" @@ -575,6 +574,9 @@ def get(self, key: Optional[str] = None) -> Any: """Get specific configuration value for Slurm component.""" config = self._class.load(self._file).dict() + if key is None: + return config + try: for prop in key.split("."): if isinstance(config, dict): @@ -598,15 +600,15 @@ def set(self, config: Mapping[str, Any]) -> None: def unset(self, *keys: str) -> None: """Unset configuration for Slurm component.""" if len(keys) == 0: - Path(self._file).unlink() + Path(self._file).open("w").close() return with self._class.edit(self._file) as current_config: for key in keys: - rest, last = key.rsplit(".", 1) - base = current_config + *rest, last = key.split(".") + base = current_config._inner try: - for prop in rest.split("."): + for prop in rest: base = getattr(base, prop) del base[last] except ModelError as e: diff --git a/tests/unit/test_slurm_ops.py b/tests/unit/test_slurm_ops.py index 93352dc..ac3a379 100644 --- a/tests/unit/test_slurm_ops.py +++ b/tests/unit/test_slurm_ops.py @@ -18,6 +18,7 @@ SnapManager, ) from pyfakefs.fake_filesystem_unittest import TestCase as FsTestCase +from slurmutils.editors import slurmconfig MUNGEKEY = b"1234567890" MUNGEKEY_BASE64 = base64.b64encode(MUNGEKEY) @@ -284,44 +285,55 @@ def setUp(self): def test_get_options(self, subcmd) -> None: """Test that the manager correctly collects all requested configuration options.""" - value = self.manager.config.get_options( + options = self.manager.config.get_options( "SlurmdLogFile", "Nodes.juju-c9fc6f-2.NodeAddr", "DownNodes.0.State" ) - print(value) - - # def test_get_config(self, subcmd, *_) -> None: - # """Test that the manager calls the correct `snap get ...` command.""" - # subcmd.return_value = '{"%s.key": "value"}' % self.config_name - # value = self.manager.config.get("key") - # args = subcmd.call_args[0][0] - # self.assertEqual(args, ["snap", "get", "-d", "slurm", f"{self.config_name}.key"]) - # self.assertEqual(value, "value") + self.assertEqual(options["SlurmdLogFile"], "/var/log/slurm/slurmd.log") + self.assertEqual(options["NodeAddr"], "10.152.28.48") + self.assertEqual(options["State"], "DOWN") + + def test_get_config(self, subcmd, *_) -> None: + """Test that the manager gets the correct configuration values.""" + self.assertEqual(self.manager.config.get("InactiveLimit"), "120") + self.assertEqual(self.manager.config.get("Nodes.juju-c9fc6f-2.RealMemory"), "1000") + self.assertEqual(self.manager.config.get("DownNodes.0.DownNodes"), ["juju-c9fc6f-5"]) + self.assertEqual( + self.manager.config.get(), slurmconfig.loads(self.EXAMPLE_SLURM_CONF).dict() + ) - # def test_get_config_all(self, subcmd) -> None: - # """Test that manager calls the correct `snap get ...` with no arguments given.""" - # subcmd.return_value = '{"%s": "value"}' % self.config_name - # value = self.manager.config.get() - # args = subcmd.call_args[0][0] - # self.assertEqual(args, ["snap", "get", "-d", "slurm", self.config_name]) - # self.assertEqual(value, "value") + def test_set_config(self, subcmd, *_) -> None: + """Test that the manager sets the correct configuration values.""" + self.manager.config.set( + { + "SlurmctldPort": "8081", + "Nodes": { + "juju-c9fc6f-2": { + "CPUs": "10", + }, + "juju-c9fc6f-20": { + "CPUs": "1", + }, + }, + "DownNodes": [ + {"DownNodes": ["juju-c9fc6f-3"], "State": "DOWN", "Reason": "New nodes"} + ], + } + ) + self.assertEqual(self.manager.config.get("SlurmctldPort"), "8081") + self.assertEqual(self.manager.config.get("Nodes.juju-c9fc6f-2.CPUs"), "10") + self.assertEqual(self.manager.config.get("Nodes.juju-c9fc6f-20.CPUs"), "1") + self.assertEqual(self.manager.config.get("DownNodes.1.DownNodes"), ["juju-c9fc6f-3"]) - # def test_set_config(self, subcmd, *_) -> None: - # """Test that the manager calls the correct `snap set ...` command.""" - # self.manager.config.set({"key": "value"}) - # args = subcmd.call_args[0][0] - # self.assertEqual(args, ["snap", "set", "slurm", f'{self.config_name}.key="value"']) + def test_unset_config(self, subcmd) -> None: + """Test that the manager unsets the correct configuration values.""" + self.manager.config.unset("ReturnToService", "Nodes.juju-c9fc6f-2", "DownNodes.0.Reason") - # def test_unset_config(self, subcmd) -> None: - # """Test that the manager calls the correct `snap unset ...` command.""" - # self.manager.config.unset("key") - # args = subcmd.call_args[0][0] - # self.assertEqual(args, ["snap", "unset", "slurm", f"{self.config_name}.key"]) + self.assertEqual(self.manager.config.get("ReturnToService"), None) + self.assertEqual(self.manager.config.get("Nodes.juju-c9fc6f-2"), None) + self.assertEqual(self.manager.config.get("DownNodes.0.Reason"), None) - # def test_unset_config_all(self, subcmd) -> None: - # """Test the manager calls the correct `snap unset ...` with no arguments given.""" - # self.manager.config.unset() - # args = subcmd.call_args[0][0] - # self.assertEqual(args, ["snap", "unset", "slurm", self.config_name]) + self.manager.config.unset() + self.assertEqual(self.manager.config.get(), slurmconfig.SlurmConfig().dict()) @patch("charms.hpc_libs.v0.slurm_ops.subprocess.check_output")