From 98b502e57583373d4e2e3ab1e65bb0f7c719b43e Mon Sep 17 00:00:00 2001 From: Gabriel Cocenza Date: Wed, 9 Oct 2024 16:21:40 -0300 Subject: [PATCH] Try to install NVIDIA driver if not present in the machine (#328) * Try to install NVIDIA driver if not present in the machine detect if NVIDIA driver is present by the /proc/driver/nvidia/version file and: * if present, doesn't try to install * if not present, try to install using the ubuntu-drivers pkg --- src/charm.py | 2 +- src/config.py | 1 + src/hw_tools.py | 93 +++++++++++++++++++++-- src/service.py | 83 +++++++++++--------- tests/unit/test_hw_tools.py | 147 +++++++++++++++++++++++++++++++++++- tests/unit/test_service.py | 122 ++++++++++++++---------------- 6 files changed, 336 insertions(+), 112 deletions(-) diff --git a/src/charm.py b/src/charm.py index 11900a41..ec8a8534 100755 --- a/src/charm.py +++ b/src/charm.py @@ -84,7 +84,7 @@ def exporters(self) -> List[BaseExporter]: exporters.append(SmartCtlExporter(self.model.config)) if stored_tools & DCGMExporter.hw_tools(): - exporters.append(DCGMExporter(self.charm_dir, self.model.config)) + exporters.append(DCGMExporter(self.model.config)) return exporters diff --git a/src/config.py b/src/config.py index a5fe935b..8af9b950 100644 --- a/src/config.py +++ b/src/config.py @@ -65,6 +65,7 @@ class HWTool(str, Enum): REDFISH = "redfish" SMARTCTL_EXPORTER = "smartctl-exporter" DCGM = "dcgm" + NVIDIA_DRIVER = "nvidia-driver" TPR_RESOURCES: t.Dict[HWTool, str] = { diff --git a/src/hw_tools.py b/src/hw_tools.py index a3ec7bee..3d8f73ab 100644 --- a/src/hw_tools.py +++ b/src/hw_tools.py @@ -61,6 +61,14 @@ def __init__(self, tool: HWTool, path: Path): self.message = f"Tool: {tool} path: {path} size is zero" +class ResourceInstallationError(Exception): + """Exception raised when a hardware tool installation fails.""" + + def __init__(self, tool: HWTool): + """Init.""" + super().__init__(f"Installation failed for tool: {tool}") + + def copy_to_snap_common_bin(source: Path, filename: str) -> None: """Copy file to $SNAP_COMMON/bin folder.""" Path(f"{SNAP_COMMON}/bin").mkdir(parents=False, exist_ok=True) @@ -184,50 +192,119 @@ class SnapStrategy(StrategyABC): channel: str @property - def snap(self) -> str: + def snap_name(self) -> str: """Snap name.""" return self._name.value + @property + def snap_common(self) -> Path: + """Snap common directory.""" + return Path(f"/var/snap/{self.snap_name}/common/") + + @property + def snap_client(self) -> snap.Snap: + """Return the snap client.""" + return snap.SnapCache()[self.snap_name] + def install(self) -> None: """Install the snap from a channel.""" try: - snap.add(self.snap, channel=self.channel) - logger.info("Installed %s from channel: %s", self.snap, self.channel) + snap.add(self.snap_name, channel=self.channel) + logger.info("Installed %s from channel: %s", self.snap_name, self.channel) # using the snap.SnapError will result into: # TypeError: catching classes that do not inherit from BaseException is not allowed except Exception as err: # pylint: disable=broad-except - logger.error("Failed to install %s from channel: %s: %s", self.snap, self.channel, err) + logger.error( + "Failed to install %s from channel: %s: %s", self.snap_name, self.channel, err + ) raise err def remove(self) -> None: """Remove the snap.""" try: - snap.remove([self.snap]) + snap.remove([self.snap_name]) # using the snap.SnapError will result into: # TypeError: catching classes that do not inherit from BaseException is not allowed except Exception as err: # pylint: disable=broad-except - logger.error("Failed to remove %s: %s", self.snap, err) + logger.error("Failed to remove %s: %s", self.snap_name, err) raise err def check(self) -> bool: """Check if all services are active.""" return all( service.get("active", False) - for service in snap.SnapCache()[self.snap].services.values() + for service in snap.SnapCache()[self.snap_name].services.values() ) class DCGMExporterStrategy(SnapStrategy): - """DCGM strategy class.""" + """DCGM exporter strategy class.""" _name = HWTool.DCGM + metric_file = Path.cwd() / "src/gpu_metrics/dcgm_metrics.csv" def __init__(self, channel: str) -> None: """Init.""" self.channel = channel + def install(self) -> None: + """Install the snap and the custom metrics.""" + super().install() + self._create_custom_metrics() + + def _create_custom_metrics(self) -> None: + logger.info("Creating a custom metrics file and configuring the DCGM snap to use it") + try: + shutil.copy(self.metric_file, self.snap_common) + self.snap_client.set({"dcgm-exporter-metrics-file": self.metric_file.name}) + self.snap_client.restart(reload=True) + except Exception as err: # pylint: disable=broad-except + logger.error("Failed to configure custom DCGM metrics: %s", err) + raise err + + +class NVIDIADriverStrategy(APTStrategyABC): + """NVIDIA driver strategy class.""" + + _name = HWTool.NVIDIA_DRIVER + pkg_pattern = r"nvidia(?:-[a-zA-Z-]*)?-(\d+)(?:-[a-zA-Z]*)?" + + def install(self) -> None: + """Install the NVIDIA driver if not present.""" + if Path("/proc/driver/nvidia/version").exists(): + logger.info("NVIDIA driver already installed in the machine") + return + + logger.info("Installing NVIDIA driver") + apt.add_package("ubuntu-drivers-common", update_cache=True) + + try: + # This can be changed to check_call and not rely in the output if this is fixed + # https://github.com/canonical/ubuntu-drivers-common/issues/106 + result = subprocess.check_output("ubuntu-drivers install --gpgpu".split(), text=True) + + except subprocess.CalledProcessError as err: + logger.error("Failed to install the NVIDIA driver: %s", err) + raise err + + if "No drivers found for installation" in result: + logger.warning( + "No drivers for the NVIDIA GPU were found. Manual installation is necessary" + ) + raise ResourceInstallationError(self._name) + + logger.info("NVIDIA driver installed") + + def remove(self) -> None: + """Drivers shouldn't be removed by the strategy.""" + return None + + def check(self) -> bool: + """Check if driver was installed.""" + return Path("/proc/driver/nvidia/version").exists() + class SmartCtlExporterStrategy(SnapStrategy): """SmartCtl strategy class.""" diff --git a/src/service.py b/src/service.py index d71f4e9e..38b30a0a 100644 --- a/src/service.py +++ b/src/service.py @@ -1,12 +1,11 @@ """Exporter service helper.""" import os -import shutil from abc import ABC, abstractmethod from logging import getLogger from pathlib import Path from time import sleep -from typing import Any, Dict, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union from charms.operator_libs_linux.v1 import systemd from charms.operator_libs_linux.v2 import snap @@ -22,7 +21,13 @@ HWTool, ) from hardware import get_bmc_address -from hw_tools import DCGMExporterStrategy, SmartCtlExporterStrategy, SnapStrategy +from hw_tools import ( + APTStrategyABC, + DCGMExporterStrategy, + NVIDIADriverStrategy, + SmartCtlExporterStrategy, + SnapStrategy, +) logger = getLogger(__name__) @@ -315,7 +320,7 @@ class SnapExporter(BaseExporter): exporter_name: str channel: str port: int - strategy: SnapStrategy + strategies: List[Union[SnapStrategy, APTStrategyABC]] def __init__(self, config: ConfigData): """Init.""" @@ -324,7 +329,7 @@ def __init__(self, config: ConfigData): @property def snap_client(self) -> snap.Snap: """Return the snap client.""" - return snap.SnapCache()[self.strategy.snap] + return snap.SnapCache()[self.exporter_name] @staticmethod def hw_tools() -> Set[HWTool]: @@ -337,10 +342,12 @@ def install(self) -> bool: Returns true if the install is successful, false otherwise. """ try: - self.strategy.install() + for strategy in self.strategies: + strategy.install() self.enable_and_start() return self.snap_client.present is True - except Exception: # pylint: disable=broad-except + except Exception as err: # pylint: disable=broad-except + logger.error("Failed to install %s: %s", strategy.name, err) return False def uninstall(self) -> bool: @@ -349,12 +356,13 @@ def uninstall(self) -> bool: Returns true if the uninstall is successful, false otherwise. """ try: - self.strategy.remove() + for strategy in self.strategies: + strategy.remove() # using the snap.SnapError will result into: # TypeError: catching classes that do not inherit from BaseException is not allowed except Exception as err: # pylint: disable=broad-except - logger.error("Failed to remove %s: %s", self.strategy.snap, err) + logger.error("Failed to remove %s: %s", strategy.name, err) return False return self.snap_client.present is False @@ -379,7 +387,7 @@ def set(self, snap_config: dict) -> bool: try: self.snap_client.set(snap_config, typed=True) except snap.SnapError as err: - logger.error("Failed to update snap configs %s: %s", self.strategy.snap, err) + logger.error("Failed to update snap configs %s: %s", self.exporter_name, err) return False return True @@ -388,14 +396,22 @@ def check_health(self) -> bool: Returns true if the service is healthy, false otherwise. """ - return self.strategy.check() + return all(strategy.check() for strategy in self.strategies) def configure(self) -> bool: """Set the necessary exporter configurations or change snap channel. Returns true if the configure is successful, false otherwise. """ - return self.install() + for strategy in self.strategies: + if isinstance(strategy, SnapStrategy): + try: + # refresh the snap for a new channel if necessary + strategy.install() + except Exception as err: # pylint: disable=broad-except + logger.error("Failed to configure %s: %s", self.exporter_name, err) + return False + return True class DCGMExporter(SnapExporter): @@ -403,38 +419,33 @@ class DCGMExporter(SnapExporter): exporter_name: str = "dcgm" port: int = 9400 - snap_common: Path = Path("/var/snap/dcgm/common/") - metric_config: str = "dcgm-exporter-metrics-file" - def __init__(self, charm_dir: Path, config: ConfigData): + def __init__(self, config: ConfigData): """Init.""" - self.strategy = DCGMExporterStrategy(str(config["dcgm-snap-channel"])) - self.charm_dir = charm_dir - self.metrics_file = self.charm_dir / "src/gpu_metrics/dcgm_metrics.csv" - self.metric_config_value = self.metrics_file.name + self.strategies = [ + DCGMExporterStrategy(str(config["dcgm-snap-channel"])), + NVIDIADriverStrategy(), + ] super().__init__(config) - def install(self) -> bool: - """Install the DCGM exporter and configure custom metrics.""" - if not super().install(): - return False - - logger.info("Creating a custom metrics file and configuring the DCGM snap to use it") - try: - shutil.copy(self.metrics_file, self.snap_common) - self.snap_client.set({self.metric_config: self.metric_config_value}) - self.snap_client.restart(reload=True) - except Exception as err: # pylint: disable=broad-except - logger.error("Failed to configure custom DCGM metrics: %s", err) - return False - - return True - @staticmethod def hw_tools() -> Set[HWTool]: """Return hardware tools to watch.""" return {HWTool.DCGM} + def validate_exporter_configs(self) -> Tuple[bool, str]: + """Validate the if the DCGM exporter is able to run.""" + valid, msg = super().validate_exporter_configs() + if not valid: + return False, msg + + if not NVIDIADriverStrategy().check(): + return ( + False, + "Failed to communicate with NVIDIA driver. See more details in the logs", + ) + return valid, msg + class SmartCtlExporter(SnapExporter): """A class representing the smartctl exporter and the metric endpoints.""" @@ -445,7 +456,7 @@ def __init__(self, config: ConfigData) -> None: """Initialize the SmartctlExporter class.""" self.port = int(config["smartctl-exporter-port"]) self.log_level = str(config["exporter-log-level"]) - self.strategy = SmartCtlExporterStrategy(str(config["smartctl-exporter-snap-channel"])) + self.strategies = [SmartCtlExporterStrategy(str(config["smartctl-exporter-snap-channel"]))] super().__init__(config) @staticmethod diff --git a/tests/unit/test_hw_tools.py b/tests/unit/test_hw_tools.py index 0def3264..55d29dc3 100644 --- a/tests/unit/test_hw_tools.py +++ b/tests/unit/test_hw_tools.py @@ -22,13 +22,16 @@ from config import SNAP_COMMON, TOOLS_DIR, TPR_RESOURCES, HWTool, StorageVendor, SystemVendor from hw_tools import ( APTStrategyABC, + DCGMExporterStrategy, HWToolHelper, IPMIDCMIStrategy, IPMISELStrategy, IPMISENSORStrategy, + NVIDIADriverStrategy, PercCLIStrategy, ResourceChecksumError, ResourceFileSizeZeroError, + ResourceInstallationError, SAS2IRCUStrategy, SAS3IRCUStrategy, SnapStrategy, @@ -1037,7 +1040,7 @@ def mock_snap_lib(): def test_snap_strategy_name(snap_exporter): - assert snap_exporter.snap == "my-snap" + assert snap_exporter.snap_name == "my-snap" def test_snap_strategy_channel(snap_exporter): @@ -1046,7 +1049,9 @@ def test_snap_strategy_channel(snap_exporter): def test_snap_strategy_install_success(snap_exporter, mock_snap_lib): snap_exporter.install() - mock_snap_lib.add.assert_called_once_with(snap_exporter.snap, channel=snap_exporter.channel) + mock_snap_lib.add.assert_called_once_with( + snap_exporter.snap_name, channel=snap_exporter.channel + ) def test_snap_strategy_install_fail(snap_exporter, mock_snap_lib): @@ -1058,7 +1063,7 @@ def test_snap_strategy_install_fail(snap_exporter, mock_snap_lib): def test_snap_strategy_remove_success(snap_exporter, mock_snap_lib): snap_exporter.remove() - mock_snap_lib.remove.assert_called_once_with([snap_exporter.snap]) + mock_snap_lib.remove.assert_called_once_with([snap_exporter.snap_name]) def test_snap_strategy_remove_fail(snap_exporter, mock_snap_lib): @@ -1143,6 +1148,142 @@ def test_snap_strategy_check(snap_exporter, mock_snap_lib, services, expected): assert snap_exporter.check() is expected +@pytest.fixture +def mock_check_output(): + with mock.patch("hw_tools.subprocess.check_output") as mocked_check_output: + yield mocked_check_output + + +@pytest.fixture +def mock_check_call(): + with mock.patch("hw_tools.subprocess.check_call") as mocked_check_call: + yield mocked_check_call + + +@pytest.fixture +def mock_apt_lib(): + with mock.patch("hw_tools.apt") as mocked_apt_lib: + yield mocked_apt_lib + + +@pytest.fixture +def mock_path(): + with mock.patch("hw_tools.Path") as mocked_path: + yield mocked_path + + +@pytest.fixture +def mock_shutil_copy(): + with mock.patch("hw_tools.shutil.copy") as mocked_copy: + yield mocked_copy + + +@pytest.fixture +def nvidia_driver_strategy(mock_check_output, mock_apt_lib, mock_path, mock_check_call): + strategy = NVIDIADriverStrategy() + strategy.installed_pkgs = mock_path + yield strategy + + +@pytest.fixture +def dcgm_exporter_strategy(mock_snap_lib, mock_shutil_copy): + yield DCGMExporterStrategy("latest/stable") + + +@mock.patch("hw_tools.DCGMExporterStrategy._create_custom_metrics") +def test_dcgm_exporter_install(mock_custom_metrics, dcgm_exporter_strategy): + assert dcgm_exporter_strategy.install() is None + mock_custom_metrics.assert_called_once() + + +def test_dcgm_create_custom_metrics(dcgm_exporter_strategy, mock_shutil_copy, mock_snap_lib): + assert dcgm_exporter_strategy._create_custom_metrics() is None + mock_shutil_copy.assert_called_once_with( + Path.cwd() / "src/gpu_metrics/dcgm_metrics.csv", Path("/var/snap/dcgm/common") + ) + dcgm_exporter_strategy.snap_client.set.assert_called_once_with( + {"dcgm-exporter-metrics-file": "dcgm_metrics.csv"} + ) + dcgm_exporter_strategy.snap_client.restart.assert_called_once_with(reload=True) + + +def test_dcgm_create_custom_metrics_copy_fail( + dcgm_exporter_strategy, mock_shutil_copy, mock_snap_lib +): + mock_shutil_copy.side_effect = FileNotFoundError + with pytest.raises(FileNotFoundError): + dcgm_exporter_strategy._create_custom_metrics() + + dcgm_exporter_strategy.snap_client.set.assert_not_called() + dcgm_exporter_strategy.snap_client.restart.assert_not_called() + + +def test_nvidia_driver_strategy_install_success( + mock_path, mock_check_output, mock_apt_lib, nvidia_driver_strategy +): + nvidia_version = mock.MagicMock() + nvidia_version.exists.return_value = False + mock_path.return_value = nvidia_version + + nvidia_driver_strategy.install() + + mock_apt_lib.add_package.assert_called_once_with("ubuntu-drivers-common", update_cache=True) + mock_check_output.assert_called_once_with("ubuntu-drivers install --gpgpu".split(), text=True) + + +def test_install_nvidia_drivers_already_installed( + mock_path, mock_apt_lib, nvidia_driver_strategy, mock_check_output +): + nvidia_version = mock.MagicMock() + nvidia_version.exists.return_value = True + mock_path.return_value = nvidia_version + + nvidia_driver_strategy.install() + + mock_apt_lib.add_package.assert_not_called() + mock_check_output.assert_not_called() + + +def test_install_nvidia_drivers_subprocess_exception( + mock_path, mock_check_output, mock_apt_lib, nvidia_driver_strategy +): + nvidia_version = mock.MagicMock() + nvidia_version.exists.return_value = False + mock_path.return_value = nvidia_version + mock_check_output.side_effect = subprocess.CalledProcessError(1, []) + + with pytest.raises(subprocess.CalledProcessError): + nvidia_driver_strategy.install() + + mock_apt_lib.add_package.assert_called_once_with("ubuntu-drivers-common", update_cache=True) + + +def test_install_nvidia_drivers_no_drivers_found( + mock_path, mock_check_output, mock_apt_lib, nvidia_driver_strategy +): + nvidia_version = mock.MagicMock() + nvidia_version.exists.return_value = False + mock_path.return_value = nvidia_version + mock_check_output.return_value = "No drivers found for installation" + + with pytest.raises(ResourceInstallationError): + nvidia_driver_strategy.install() + + mock_apt_lib.add_package.assert_called_once_with("ubuntu-drivers-common", update_cache=True) + + +def test_nvidia_strategy_remove(nvidia_driver_strategy): + assert nvidia_driver_strategy.remove() is None + + +@pytest.mark.parametrize("present, expected", [(True, True), (False, False)]) +def test_nvidia_strategy_check(nvidia_driver_strategy, mock_path, present, expected): + nvidia_version = mock.MagicMock() + nvidia_version.exists.return_value = present + mock_path.return_value = nvidia_version + assert nvidia_driver_strategy.check() is expected + + @mock.patch("hw_tools.Path.unlink") @mock.patch("hw_tools.Path.exists") @mock.patch("hw_tools.shutil") diff --git a/tests/unit/test_service.py b/tests/unit/test_service.py index a0fb8835..3c75577b 100644 --- a/tests/unit/test_service.py +++ b/tests/unit/test_service.py @@ -721,19 +721,18 @@ class TestDCGMSnapExporter(unittest.TestCase): def setUp(self) -> None: """Set up harness for each test case.""" snap_lib_patcher = mock.patch.object(service, "snap") - shutil_lib_patcher = mock.patch.object(service, "shutil") + self.mock_snap = snap_lib_patcher.start() - self.mock_shutil = shutil_lib_patcher.start() self.addCleanup(snap_lib_patcher.stop) - self.addCleanup(shutil_lib_patcher.stop) - - search_path = pathlib.Path(f"{__file__}/../../..").resolve() - self.mock_config = { - "dcgm-snap-channel": "latest/stable", - } - self.exporter = service.DCGMExporter(search_path, self.mock_config) - self.exporter.strategy = mock.MagicMock() + self.exporter = service.DCGMExporter( + { + "dcgm-snap-channel": "latest/stable", + } + ) + self.snap_strategy = mock.MagicMock(spec=service.DCGMExporterStrategy) + self.nvidia_strategy = mock.MagicMock(spec=service.NVIDIADriverStrategy) + self.exporter.strategies = [self.snap_strategy, self.nvidia_strategy] def test_exporter_name(self): self.assertEqual(self.exporter.exporter_name, "dcgm") @@ -741,39 +740,26 @@ def test_exporter_name(self): def test_hw_tools(self): self.assertEqual(self.exporter.hw_tools(), {HWTool.DCGM}) - def test_install_failed(self): - self.exporter.snap_client.present = False - - exporter_install_ok = self.exporter.install() - - self.exporter.strategy.install.assert_called() - self.mock_shutil.copy.assert_not_called() - self.assertFalse(exporter_install_ok) - - def test_install_success(self): - self.exporter.snap_client.present = True - - exporter_install_ok = self.exporter.install() + @mock.patch("service.NVIDIADriverStrategy.check", return_value=True) + def test_validate_exporter_configs_success(self, _): + valid, msg = self.exporter.validate_exporter_configs() + self.assertTrue(valid) + self.assertEqual(msg, "Exporter config is valid.") - self.exporter.strategy.install.assert_called() - self.mock_shutil.copy.assert_called_with( - self.exporter.metrics_file, self.exporter.snap_common - ) - self.exporter.snap_client.set.assert_called_with( - {self.exporter.metric_config: self.exporter.metric_config_value} + @mock.patch("service.NVIDIADriverStrategy.check", return_value=False) + def test_validate_exporter_configs_fails(self, _): + valid, msg = self.exporter.validate_exporter_configs() + self.assertFalse(valid) + self.assertEqual( + msg, "Failed to communicate with NVIDIA driver. See more details in the logs" ) - self.exporter.snap_client.restart.assert_called_with(reload=True) - self.assertTrue(exporter_install_ok) - - def test_install_metrics_copy_fail(self): - self.exporter.snap_client.present = True - self.mock_shutil.copy.side_effect = FileNotFoundError - exporter_install_ok = self.exporter.install() - - self.exporter.strategy.install.assert_called() - self.exporter.snap_client.restart.assert_not_called() - self.assertFalse(exporter_install_ok) + @mock.patch.object(service.BaseExporter, "validate_exporter_configs") + def test_validate_exporter_configs_fails_parent(self, mock_parent_validate): + mock_parent_validate.return_value = False, "Invalid config: exporter's port" + valid, msg = self.exporter.validate_exporter_configs() + self.assertFalse(valid) + self.assertEqual(msg, "Invalid config: exporter's port") class TestWriteToFile(unittest.TestCase): @@ -845,25 +831,27 @@ def test_write_to_file_not_a_directory_error(self, mock_open): @pytest.fixture def snap_exporter(): - my_strategy = mock.MagicMock(spec=service.SnapStrategy) + my_snap_strategy = mock.MagicMock(spec=service.SnapStrategy) + my_apt_strategy = mock.MagicMock(spec=service.APTStrategyABC) class MySnapExporter(service.SnapExporter): exporter_name = "my-exporter" channel = "my-channel" - strategy = my_strategy - - mock_config = { - "dcgm-snap-channel": "latest/stable", - } + strategies = [my_snap_strategy, my_apt_strategy] with mock.patch("service.snap.SnapCache"): - exporter = MySnapExporter(mock_config) + exporter = MySnapExporter( + { + "dcgm-snap-channel": "latest/stable", + } + ) exporter.snap_client.services = {"service1": {}, "service2": {}} yield exporter - my_strategy.reset_mock() + my_snap_strategy.reset_mock() + my_apt_strategy.reset_mock() def test_snap_exporter_hw_tools(snap_exporter): @@ -871,16 +859,16 @@ def test_snap_exporter_hw_tools(snap_exporter): assert snap_exporter.hw_tools() == set() -def test_snap_exporter_install(snap_exporter): - snap_exporter.strategy.install.return_value = True +def test_snap_exporter_install_success(snap_exporter): snap_exporter.snap_client.present = True assert snap_exporter.install() is True - snap_exporter.strategy.install.assert_called_once() + for strategy in snap_exporter.strategies: + strategy.install.assert_called_once() def test_snap_exporter_install_fail(snap_exporter): - snap_exporter.strategy.install.side_effect = ValueError + snap_exporter.strategies[0].install.side_effect = ValueError assert snap_exporter.install() is False @@ -889,11 +877,12 @@ def test_snap_exporter_uninstall(snap_exporter): snap_exporter.snap_client.present = False assert snap_exporter.uninstall() is True - snap_exporter.strategy.remove.assert_called_once() + for strategy in snap_exporter.strategies: + strategy.remove.assert_called_once() def test_snap_exporter_uninstall_fail(snap_exporter): - snap_exporter.strategy.remove.side_effect = ValueError + snap_exporter.strategies[0].remove.side_effect = ValueError assert snap_exporter.uninstall() is False @@ -902,7 +891,8 @@ def test_snap_exporter_uninstall_present(snap_exporter): snap_exporter.snap_client.present = True assert snap_exporter.uninstall() is False - snap_exporter.strategy.remove.assert_called_once() + for strategy in snap_exporter.strategies: + strategy.remove.assert_called_once() def test_snap_exporter_enable_and_start(snap_exporter): @@ -935,20 +925,25 @@ def test_snap_exporter_set_failed(snap_exporter): def test_snap_exporter_check_health(snap_exporter): snap_exporter.check_health() - snap_exporter.strategy.check.assert_called_once() + for strategy in snap_exporter.strategies: + strategy.check.assert_called_once() + +@mock.patch("service.isinstance", return_value=True) +def test_snap_exporter_configure(_, snap_exporter): + assert snap_exporter.configure() is True + for strategy in snap_exporter.strategies: + strategy.install.assert_called_once() -@pytest.mark.parametrize("install_result, expected_result", [(True, True), (False, False)]) -@mock.patch("service.SnapExporter.install") -def test_snap_exporter_configure(mock_install, snap_exporter, install_result, expected_result): - mock_install.return_value = install_result - assert snap_exporter.configure() is expected_result - mock_install.assert_called_once() +@mock.patch("service.isinstance", return_value=True) +def test_snap_exporter_configure_exception(_, snap_exporter): + snap_exporter.strategies[0].install.side_effect = snap.SnapError + assert snap_exporter.configure() is False @pytest.mark.parametrize("result, expected_result", [(True, True), (False, False)]) -@mock.patch("service.SnapExporter.install") +@mock.patch("service.SmartCtlExporterStrategy.install") @mock.patch("service.SnapExporter.set") def test_smartctl_exporter_configure(mock_set, mock_install, result, expected_result): mock_config = { @@ -956,7 +951,6 @@ def test_smartctl_exporter_configure(mock_set, mock_install, result, expected_re "exporter-log-level": "info", "smartctl-exporter-snap-channel": "latest/stable", } - mock_set.return_value = result mock_install.return_value = result exporter = service.SmartCtlExporter(mock_config)