Skip to content

Commit

Permalink
Try to install NVIDIA driver if not present in the machine (#328)
Browse files Browse the repository at this point in the history
* Try to install NVIDIA driver if not present in the machine

Detect whether the NVIDIA driver is present by checking for the
/proc/driver/nvidia/version file, and:

* if present, do not try to install it
* if not present, try to install it using the ubuntu-drivers tool (ubuntu-drivers-common package)
  • Loading branch information
gabrielcocenza authored Oct 9, 2024
1 parent 1e9281d commit 98b502e
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 112 deletions.
2 changes: 1 addition & 1 deletion src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def exporters(self) -> List[BaseExporter]:
exporters.append(SmartCtlExporter(self.model.config))

if stored_tools & DCGMExporter.hw_tools():
exporters.append(DCGMExporter(self.charm_dir, self.model.config))
exporters.append(DCGMExporter(self.model.config))

return exporters

Expand Down
1 change: 1 addition & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class HWTool(str, Enum):
REDFISH = "redfish"
SMARTCTL_EXPORTER = "smartctl-exporter"
DCGM = "dcgm"
NVIDIA_DRIVER = "nvidia-driver"


TPR_RESOURCES: t.Dict[HWTool, str] = {
Expand Down
93 changes: 85 additions & 8 deletions src/hw_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def __init__(self, tool: HWTool, path: Path):
self.message = f"Tool: {tool} path: {path} size is zero"


class ResourceInstallationError(Exception):
    """Raised when installing a hardware tool does not succeed.

    The exception message names the tool whose installation failed.
    """

    def __init__(self, tool: HWTool):
        """Build the error message for the given tool.

        Args:
            tool: the hardware tool that could not be installed.
        """
        message = f"Installation failed for tool: {tool}"
        super().__init__(message)


def copy_to_snap_common_bin(source: Path, filename: str) -> None:
"""Copy file to $SNAP_COMMON/bin folder."""
Path(f"{SNAP_COMMON}/bin").mkdir(parents=False, exist_ok=True)
Expand Down Expand Up @@ -184,50 +192,119 @@ class SnapStrategy(StrategyABC):
channel: str

@property
def snap(self) -> str:
def snap_name(self) -> str:
"""Snap name."""
return self._name.value

@property
def snap_common(self) -> Path:
    """Return the snap's ``common`` directory under /var/snap."""
    common_dir = f"/var/snap/{self.snap_name}/common/"
    return Path(common_dir)

@property
def snap_client(self) -> snap.Snap:
    """Look up this strategy's snap in a newly created snap cache."""
    cache = snap.SnapCache()
    return cache[self.snap_name]

def install(self) -> None:
"""Install the snap from a channel."""
try:
snap.add(self.snap, channel=self.channel)
logger.info("Installed %s from channel: %s", self.snap, self.channel)
snap.add(self.snap_name, channel=self.channel)
logger.info("Installed %s from channel: %s", self.snap_name, self.channel)

# using the snap.SnapError will result into:
# TypeError: catching classes that do not inherit from BaseException is not allowed
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to install %s from channel: %s: %s", self.snap, self.channel, err)
logger.error(
"Failed to install %s from channel: %s: %s", self.snap_name, self.channel, err
)
raise err

def remove(self) -> None:
"""Remove the snap."""
try:
snap.remove([self.snap])
snap.remove([self.snap_name])

# using the snap.SnapError will result into:
# TypeError: catching classes that do not inherit from BaseException is not allowed
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to remove %s: %s", self.snap, err)
logger.error("Failed to remove %s: %s", self.snap_name, err)
raise err

def check(self) -> bool:
"""Check if all services are active."""
return all(
service.get("active", False)
for service in snap.SnapCache()[self.snap].services.values()
for service in snap.SnapCache()[self.snap_name].services.values()
)


class DCGMExporterStrategy(SnapStrategy):
"""DCGM strategy class."""
"""DCGM exporter strategy class."""

_name = HWTool.DCGM
metric_file = Path.cwd() / "src/gpu_metrics/dcgm_metrics.csv"

def __init__(self, channel: str) -> None:
    """Store the snap channel on the strategy.

    Args:
        channel: snap channel, kept as ``self.channel``.
    """
    self.channel = channel

def install(self) -> None:
    """Install the snap, then set up the custom metrics file.

    The parent class performs the snap installation; the custom metrics are
    configured only after that step returns without raising.
    """
    super().install()
    self._create_custom_metrics()

def _create_custom_metrics(self) -> None:
    """Copy the metrics file into the snap and configure the snap to use it.

    The file is copied to the snap common directory, the snap option
    ``dcgm-exporter-metrics-file`` is set to the file name, and the snap is
    restarted with ``reload=True``.

    Raises:
        Exception: any error raised by the steps above is logged and re-raised.
    """
    logger.info("Creating a custom metrics file and configuring the DCGM snap to use it")
    try:
        shutil.copy(self.metric_file, self.snap_common)
        metrics_option = {"dcgm-exporter-metrics-file": self.metric_file.name}
        self.snap_client.set(metrics_option)
        self.snap_client.restart(reload=True)
    except Exception as err:  # pylint: disable=broad-except
        logger.error("Failed to configure custom DCGM metrics: %s", err)
        raise err


class NVIDIADriverStrategy(APTStrategyABC):
    """Strategy that installs the NVIDIA driver through ``ubuntu-drivers``.

    The driver is treated as present when /proc/driver/nvidia/version exists.
    """

    _name = HWTool.NVIDIA_DRIVER
    pkg_pattern = r"nvidia(?:-[a-zA-Z-]*)?-(\d+)(?:-[a-zA-Z]*)?"

    def install(self) -> None:
        """Install the NVIDIA driver unless it is already present.

        Raises:
            subprocess.CalledProcessError: ``ubuntu-drivers`` exited with a
                non-zero status.
            ResourceInstallationError: ``ubuntu-drivers`` reported that no
                driver was found for installation.
        """
        driver_version_file = Path("/proc/driver/nvidia/version")
        if driver_version_file.exists():
            logger.info("NVIDIA driver already installed in the machine")
            return

        logger.info("Installing NVIDIA driver")
        apt.add_package("ubuntu-drivers-common", update_cache=True)

        install_cmd = ["ubuntu-drivers", "install", "--gpgpu"]
        try:
            # The command output is inspected below. This can become check_call,
            # without relying on the output, once this issue is fixed:
            # https://github.com/canonical/ubuntu-drivers-common/issues/106
            output = subprocess.check_output(install_cmd, text=True)
        except subprocess.CalledProcessError as err:
            logger.error("Failed to install the NVIDIA driver: %s", err)
            raise err

        if "No drivers found for installation" not in output:
            logger.info("NVIDIA driver installed")
            return

        logger.warning(
            "No drivers for the NVIDIA GPU were found. Manual installation is necessary"
        )
        raise ResourceInstallationError(self._name)

    def remove(self) -> None:
        """Do nothing: drivers shouldn't be removed by the strategy."""

    def check(self) -> bool:
        """Report whether /proc/driver/nvidia/version exists."""
        driver_version_file = Path("/proc/driver/nvidia/version")
        return driver_version_file.exists()


class SmartCtlExporterStrategy(SnapStrategy):
"""SmartCtl strategy class."""
Expand Down
83 changes: 47 additions & 36 deletions src/service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Exporter service helper."""

import os
import shutil
from abc import ABC, abstractmethod
from logging import getLogger
from pathlib import Path
from time import sleep
from typing import Any, Dict, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from charms.operator_libs_linux.v1 import systemd
from charms.operator_libs_linux.v2 import snap
Expand All @@ -22,7 +21,13 @@
HWTool,
)
from hardware import get_bmc_address
from hw_tools import DCGMExporterStrategy, SmartCtlExporterStrategy, SnapStrategy
from hw_tools import (
APTStrategyABC,
DCGMExporterStrategy,
NVIDIADriverStrategy,
SmartCtlExporterStrategy,
SnapStrategy,
)

logger = getLogger(__name__)

Expand Down Expand Up @@ -315,7 +320,7 @@ class SnapExporter(BaseExporter):
exporter_name: str
channel: str
port: int
strategy: SnapStrategy
strategies: List[Union[SnapStrategy, APTStrategyABC]]

def __init__(self, config: ConfigData):
"""Init."""
Expand All @@ -324,7 +329,7 @@ def __init__(self, config: ConfigData):
@property
def snap_client(self) -> snap.Snap:
"""Return the snap client."""
return snap.SnapCache()[self.strategy.snap]
return snap.SnapCache()[self.exporter_name]

@staticmethod
def hw_tools() -> Set[HWTool]:
Expand All @@ -337,10 +342,12 @@ def install(self) -> bool:
Returns true if the install is successful, false otherwise.
"""
try:
self.strategy.install()
for strategy in self.strategies:
strategy.install()
self.enable_and_start()
return self.snap_client.present is True
except Exception: # pylint: disable=broad-except
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to install %s: %s", strategy.name, err)
return False

def uninstall(self) -> bool:
Expand All @@ -349,12 +356,13 @@ def uninstall(self) -> bool:
Returns true if the uninstall is successful, false otherwise.
"""
try:
self.strategy.remove()
for strategy in self.strategies:
strategy.remove()

# using the snap.SnapError will result into:
# TypeError: catching classes that do not inherit from BaseException is not allowed
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to remove %s: %s", self.strategy.snap, err)
logger.error("Failed to remove %s: %s", strategy.name, err)
return False

return self.snap_client.present is False
Expand All @@ -379,7 +387,7 @@ def set(self, snap_config: dict) -> bool:
try:
self.snap_client.set(snap_config, typed=True)
except snap.SnapError as err:
logger.error("Failed to update snap configs %s: %s", self.strategy.snap, err)
logger.error("Failed to update snap configs %s: %s", self.exporter_name, err)
return False
return True

Expand All @@ -388,53 +396,56 @@ def check_health(self) -> bool:
Returns true if the service is healthy, false otherwise.
"""
return self.strategy.check()
return all(strategy.check() for strategy in self.strategies)

def configure(self) -> bool:
"""Set the necessary exporter configurations or change snap channel.
Returns true if the configure is successful, false otherwise.
"""
return self.install()
for strategy in self.strategies:
if isinstance(strategy, SnapStrategy):
try:
# refresh the snap for a new channel if necessary
strategy.install()
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to configure %s: %s", self.exporter_name, err)
return False
return True


class DCGMExporter(SnapExporter):
"""A class representing the DCGM exporter and the metric endpoints."""

exporter_name: str = "dcgm"
port: int = 9400
snap_common: Path = Path("/var/snap/dcgm/common/")
metric_config: str = "dcgm-exporter-metrics-file"

def __init__(self, charm_dir: Path, config: ConfigData):
def __init__(self, config: ConfigData):
"""Init."""
self.strategy = DCGMExporterStrategy(str(config["dcgm-snap-channel"]))
self.charm_dir = charm_dir
self.metrics_file = self.charm_dir / "src/gpu_metrics/dcgm_metrics.csv"
self.metric_config_value = self.metrics_file.name
self.strategies = [
DCGMExporterStrategy(str(config["dcgm-snap-channel"])),
NVIDIADriverStrategy(),
]
super().__init__(config)

def install(self) -> bool:
"""Install the DCGM exporter and configure custom metrics."""
if not super().install():
return False

logger.info("Creating a custom metrics file and configuring the DCGM snap to use it")
try:
shutil.copy(self.metrics_file, self.snap_common)
self.snap_client.set({self.metric_config: self.metric_config_value})
self.snap_client.restart(reload=True)
except Exception as err: # pylint: disable=broad-except
logger.error("Failed to configure custom DCGM metrics: %s", err)
return False

return True

@staticmethod
def hw_tools() -> Set[HWTool]:
    """Return the set of hardware tools this exporter watches."""
    watched_tools = {HWTool.DCGM}
    return watched_tools

def validate_exporter_configs(self) -> Tuple[bool, str]:
    """Validate whether the DCGM exporter is able to run.

    The parent validation runs first and its failure is returned unchanged.
    After that, the NVIDIA driver check must also pass.

    Returns:
        A ``(valid, message)`` tuple; ``valid`` is False when either the
        parent validation or the NVIDIA driver check fails.
    """
    base_valid, base_msg = super().validate_exporter_configs()
    if not base_valid:
        return False, base_msg

    if NVIDIADriverStrategy().check():
        return base_valid, base_msg

    return (
        False,
        "Failed to communicate with NVIDIA driver. See more details in the logs",
    )


class SmartCtlExporter(SnapExporter):
"""A class representing the smartctl exporter and the metric endpoints."""
Expand All @@ -445,7 +456,7 @@ def __init__(self, config: ConfigData) -> None:
"""Initialize the SmartctlExporter class."""
self.port = int(config["smartctl-exporter-port"])
self.log_level = str(config["exporter-log-level"])
self.strategy = SmartCtlExporterStrategy(str(config["smartctl-exporter-snap-channel"]))
self.strategies = [SmartCtlExporterStrategy(str(config["smartctl-exporter-snap-channel"]))]
super().__init__(config)

@staticmethod
Expand Down
Loading

0 comments on commit 98b502e

Please sign in to comment.