Skip to content

Commit

Permalink
feat(slurmd): implement option to install ROCm drivers
Browse files Browse the repository at this point in the history
Leaving this as an option that can be set by the user for now. We can
investigate in the future how to automatically install the drivers.
  • Loading branch information
jedel1043 committed Sep 18, 2024
1 parent 246f1da commit ae95a40
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 2 deletions.
14 changes: 14 additions & 0 deletions charms/slurmd/charmcraft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ config:
$ juju config slurmd nhc-conf="$(cat extra-nhc.conf)"
```
gpu:
default: ""
type: string
description: >
Type of GPU driver to install.
Available options: [`amd`].
Example usage:
```bash
$ juju config slurmd gpu=amd
```
actions:
node-configured:
description: Remove a node from DownNodes when the reason is `New node`.
Expand Down
12 changes: 11 additions & 1 deletion charms/slurmd/src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,18 @@ def _on_install(self, event: InstallEvent) -> None:

self._check_status()

def _on_config_changed(self, event: ConfigChangedEvent) -> None:
def _on_config_changed(self, event: ConfigChangedEvent) -> None: # noqa: C901
"""Handle charm configuration changes."""
if gpu := self.model.config.get("gpu"):
if gpu == "amd":
if not self._slurmd_manager.rocm_manager.is_installed():
self._slurmd_manager.rocm_manager.install()
self.unit.reboot(now=True)
self._slurmd_manager.rocm_manager.post_install()
else:
self.unit.status = BlockedStatus("Invalid value for option `gpu`")
return

if nhc_conf := self.model.config.get("nhc-conf"):
if nhc_conf != self._stored.nhc_conf:
self._stored.nhc_conf = nhc_conf
Expand Down
55 changes: 55 additions & 0 deletions charms/slurmd/src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,58 @@
=cs1s
-----END PGP PUBLIC KEY BLOCK-----
"""

ROCM_PPA_KEY = """
-----BEGIN PGP PUBLIC KEY BLOCK-----
Version: GnuPG v1
mQINBFefsSABEADmVqQyRi5bcUs/eG8mnKLdY+V+xuKuHLuujlXinSaMFRO640Md
C2HNYLSd58Z8cB1rKfiN639CZp+SkDWq60cFXDCcX9djT0JmBzsTD/gwoMr16tMY
O+Z2mje2pEYgDJdmYrephhXn29BfebW1IQKdA+4C7l675mJ/T8yVMUNXC0hqfGDA
h1MJUQy/lz1S2fGdjCKX0PiYOnCOyhNa7aTpw9PkZWgEa/s4BhplFZxvLohrCcf6
ks0gUITHfeEhJvj2KurRfL68DgFifGnG+/fsMHgW1Xp19GsnIVaoh6cV7/iFHhrb
6YHI1fdOq/mwOfG8mJnXmDXC/o24Q7mRRwvoJcsT0j+thRirs8trV01mKY+7Hxd2
CamWttibo062pjWN2aEUMPmEU2kmGOupsZtlpqn6SGCd2+6maOPMNEq/F0EWxhul
q6mgezVb8pvJ3bwvph2/lMSgfT9fHs6UIh4i/3rnA5/JaejFonlnS9xEuglKjklj
UoikSPBOwjvoPW2u99WCflURFSXVvuk7Ci+XkbVPIZyD6gFJjeY02Ic5MAv5tj/z
0fpgr/CfwEllms+z7qz768xRweA0kmPTTARdufVTna6EV3K3njxvCIIfnrp1cF6S
e3VrREd98gO0Rmzy74UFqkXl9Tb/+UILx1qVRmOBinwacKGqzo+k9jPUKQARAQAB
tChBTUQgTUxTRSBEZXZPcHMgPGRsLk1MU0UuRGV2T3BzQGFtZC5jb20+iQI+BBMB
AgAoAhsDBgsJCAcDAgYVCAIJCgsEFgIDAQIeAQIXgAUCYfuRkwUJE8Hh5wAKCRCT
hrSKGmk8XI1AEACSJLVGHCLJOOKz9fbUR4KWl7Gpv0RWccwxhH01jNZTSXUCEnKA
2KYmaqFvrT5szxWILobmCNYtAlbdkpUfb0mMaF3UtTu+1UMOw2ExzxHw1FyA+z6d
vLqDKXLldsOFUfojDUhD5cK6uvONPc1orCf/4ve6wnRG838bAzb4VrFR64IxfPjx
NukH+jo2nEXNpnNv44DEiq65CcObaPuwAVBFnRYD/ByPO4ZArxFXqNzHRxpoZkKv
iwzhbPG4cirioqzRR9y2SsC+a2sO4a/jH0wOL2+n4L86xShYcuCBxXvS/AwrV/aO
JxKOfAUV4VQegAOQz64L+iz7PslNSTILJGdvGcC5Ckgpo6evdWBT7KdGXhzf4S1f
wZjYyP9sfQa7LxqyrkLHZqYt4If4Jmukx7cApBYp1nPnuCQrLU6D4Arq0ZVWQuNV
hbABLeqwdVQcX+vG/Kr/ZC+Vkv3Z8oElwVGAAQ6HNXr/u8ud2bu6iNJ5mcQbM1HD
KTNt5LUrk0p588a8dk0/TyC5xeKSv51iNL+aOVaTr0pRwgaHtEVar2i0FPC1mkr4
1hhIDddx8WLoUt/52f1juyr/4CpL1M5f1cbMVjV6i0kqIEx/hxrryc+fZZQT5R4M
vysxcsh8ttgpABG5vzz2rLOCanmQ4eDdmlugzn/u0ngoDdnC0gEfnVVutLkCDQRX
n7EgARAAlsWVKSOQicuBxBlo3U5tre5whSyAOWHuy6/heGwCkGssTahbIL8pRwOL
5nKJCPCKKJ4YYoZ+Jzer9WTsDRZU/zpQXK9C5WdfF6DN/Fai3lqhgeDDVyF0hUDr
NQigm/w66JEYTGtMcC5PnYv7S6Zrn9WN4anv9n5thNwfsqxpbbg6sAQ2aLHLsW96
myQE9v1s0YoSZYc7rFYBwszE+tFX0kLlyBYSRVns/USQifu66RObO706d8DHp6Ro
vO6WgsTu+0RR2FEUabBx1q6iKe1cqK0FYtWd8tXCpqQBm0zGC6UwTp4Z4GMCX2Pk
3xAMmrItW5kPKCANB+P/8ZoOoZLIX5Fr9axQ496lUh0ZDhOACewJfj9Szk9GN5rq
+2QKnRepatevGBVaN0lCAEwg2q9/9xmrT6CixFrbnw2T6mWHM3jQrvduqmC0c1Cd
uMZBGDKSpjouaN0UKtC+udwWiY7w452pcjCnUjzjk7tR1IarSCnLLYeb+MDCK83M
CFH60SmBfdqjRiTiLas34KSKNnmbfUfrTYswf0Oed/qXAUSlYOCmWl4sV8n+Ebpy
XfY80/fzu95RbpMEZMhUTRtvr64O5jaWM/lFnubnegGTW3Bk/fBR2VRsBx56ZHlc
JH23f6IREjQ1x4B2UsINYfyYpmzb+R4qpMzycBVHv9ipiYQsQ8sAEQEAAYkCJQQY
AQoADwIbDAUCYfuRtwUJE8HiEAAKCRCThrSKGmk8XMAcEACd0jYXjnu7qoEY4U9Q
47X2SeJmWsuTavCrU5AWxjYwWd0mtDqK8EynxDPq7UFs+8+OukqrE++p0bfBbDl9
TwnwmSSdizAZriHMSgeg9GR5KVL4mreNhFQdk/6mTFdlRhi5s7ZuvPayLSMIAWaj
ET5gFMeO1B/ABSpaKEZwQjRcXrto/hCUJ++7qoosblhcgwX7fiqZZbMxcoCEQIQQ
7ZasLxpVtaeDVfetp2zO5F0/e3D/sNbvBrlDofSt6D5V2cmKjLqONFVc6JrzSNeK
k9Gn8UVzAKfRfLaQyDaoFV0MbBf3q111UQQPkvwZYp0lPT6t2/G8zoubwFhHsM31
K5ZBbt0384hI9RJITo9/krXVXLYFeCLcoPKn/fGWgAwyYAYr6C7JcocxTNUyCd1I
AVg4SO/JuC3NWFQK5LhknN/gJkFlLZdB2cWqu9dDIkx1cHXThaM2n/7GSxv7fzrI
Br1jhZjUPWJ2iOd8iHgVEkIEvZql8z+huSxcNemodEN1emmUUoIyY3Fh0lJmozDt
ZPATk3iPpksOApsDVhWXP96RjTYEozYCxgTxCnk+kX/iJIlt53BPNWm9HMTcmtDI
v3s7OEcw0DN3U2VKcL9Q4Sg3uNfhwQsw/xBJaxAHQn5lN/8t0eLt+U653ooEEr0o
ta5TfPumStSQ1UjP8pPny4l+JQ==
=UOE+
-----END PGP PUBLIC KEY BLOCK-----
"""
85 changes: 84 additions & 1 deletion charms/slurmd/src/slurmd_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any, Dict

import distro
from constants import MUNGE_KEY_PATH, SLURM_GROUP, SLURM_USER, UBUNTU_HPC_PPA_KEY
from constants import MUNGE_KEY_PATH, ROCM_PPA_KEY, SLURM_GROUP, SLURM_USER, UBUNTU_HPC_PPA_KEY

import charms.operator_libs_linux.v0.apt as apt # type: ignore [import-untyped]
import charms.operator_libs_linux.v1.systemd as systemd # type: ignore [import-untyped]
Expand Down Expand Up @@ -130,6 +130,88 @@ def install(self) -> bool:
return package_installed


class ROCmPackagesLifecycleManager:
"""Facilitate ROCm packages lifecycles."""

_LD_FILE = Path("/etc/ld.so.conf.d/rocm.conf")
_PATH_FILE = Path("PATH=$PATH:/opt/rocm-6.2.0/bin")

def install(self) -> None:
"""Install ROCm required packages.
NOTE: The machine will require a restart after installation to enable the `amdgpu-dkms` module.
Raises:
SlurmdException if the installation failed.
"""
try:
arch = subprocess.check_output(["dpkg", "--print-architecture"], text=True).strip()
except subprocess.CalledProcessError as e:
raise SlurmdException(
f"failed to determine the system architecture. reason: {e.stderr}"
)

amdgpu_repo = apt.DebianRepository(
enabled=True,
repotype="deb",
uri="https://repo.radeon.com/amdgpu/6.2/ubuntu",
release=distro.codename(),
groups=["main"],
options={"arch": arch},
)
amdgpu_repo.import_key(ROCM_PPA_KEY)

rocm_repo = apt.DebianRepository(
enabled=True,
repotype="deb",
uri="https://repo.radeon.com/rocm/apt/6.2",
release=distro.codename(),
groups=["main"],
options={"arch": arch},
)
rocm_repo.import_key(ROCM_PPA_KEY)

repositories = apt.RepositoryMapping()
repositories.add(amdgpu_repo)
repositories.add(rocm_repo)

rocm_pin = Path("/etc/apt/preferences.d/rocm-pin-600")
rocm_pin.parent.mkdir(parents=True, exist_ok=True)
rocm_pin.write_text("Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600")

apt.update()
apt.add_package(["amdgpu-dkms", "rocm"])

def is_installed(self) -> bool:
"""Check if the ROCm packages have been installed."""
try:
for package in ["amdgpu-dkms", "rocm"]:
apt.DebianPackage.from_installed_package(package)
except apt.PackageNotFoundError:
return False

return True

def post_install(self) -> None:
"""Execute post-installation steps for the ROCm packages.
Raises:
SlurmdException if the post-installation steps failed.
"""
self._LD_FILE.parent.mkdir(parents=True, exist_ok=True)
self._LD_FILE.write_text("/opt/rocm/lib\n/opt/rocm/lib64")

try:
subprocess.check_call(["ldconfig"], text=True)
except subprocess.CalledProcessError as e:
raise SlurmdException(
f"could not execute the post-installation step. reason: {e.stdout}"
)

self._PATH_FILE.parent.mkdir(parents=True, exist_ok=True)
self._PATH_FILE.write_text("PATH=$PATH:/opt/rocm-6.2.0/bin")


class SlurmdManager:
"""SlurmdManager."""

Expand All @@ -138,6 +220,7 @@ def __init__(self):
self._slurmd_package = CharmedHPCPackageLifecycleManager("slurmd")
self._slurm_client_package = CharmedHPCPackageLifecycleManager("slurm-client")
self._common_packages = CommonPackagesLifecycleManager()
self.rocm_manager = ROCmPackagesLifecycleManager()

def install(self) -> bool:
"""Install slurmd, slurm-client, munge, and common packages to the system."""
Expand Down

0 comments on commit ae95a40

Please sign in to comment.