Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Monitor - Add support for AMD GPU. #580

Merged
merged 10 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def run(self):
**x,
'develop': x['dev'] + x['test'],
'cpuworker': x['torch'],
'amdworker': x['torch'] + x['ort'],
'amdworker': x['torch'] + x['ort'] + x['amd'],
'nvworker': x['torch'] + x['ort'] + x['nvidia'],
}
)(
Expand Down Expand Up @@ -217,6 +217,7 @@ def run(self):
'onnxruntime-gpu; python_version>="3.10"',
],
'nvidia': ['py3nvml>=0.2.6'],
'amd': ['pyrsmi>=1.0.1'],
}
),
include_package_data=True,
Expand Down
243 changes: 234 additions & 9 deletions superbench/common/utils/device_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,138 @@

"""Device Managerment Library Utility."""

import py3nvml.py3nvml as nvml
from typing import Optional

from superbench.common.utils import logger
from superbench.common.utils import process
from superbench.common.devices import GPU

gpu = GPU()
if gpu.vendor == 'nvidia' or gpu.vendor == 'nvidia-graphics':
import py3nvml.py3nvml as nvml
elif gpu.vendor == 'amd' or gpu.vendor == 'amd-graphics':
from pyrsmi import rocml


class DeviceManager:
    """Device management base module.

    Vendor-neutral fallback: every query returns a placeholder (0 or None), so
    this class can be instantiated on hosts where no vendor library is imported.
    Vendor-specific subclasses override the getters with real queries.
    """
    def __init__(self):
        """Constructor.

        Only caches the device count. No vendor library call belongs here,
        because this base class is also instantiated when neither vendor
        binding has been imported; subclasses initialize their own library
        before delegating to this constructor.
        """
        # Dispatches to the subclass override when called via super().__init__().
        self._device_count = self.get_device_count()

    def get_device_count(self):
        """Get the number of devices.

        Return:
            count (int): count of devices, always 0 for the base implementation.
        """
        return 0

    def get_device_compute_capability(self):
        """Get the compute capability of device.

        Return:
            cap (float): the compute capability of device, None means failed to get the data.
        """
        return None

    def get_device_utilization(self, idx):
        """Get the utilization of device.

        Args:
            idx (int): device index.

        Return:
            util (int): the utilization of device, None means failed to get the data.
        """
        return None

    def get_device_temperature(self, idx):
        """Get the temperature of device, unit: celsius.

        Args:
            idx (int): device index.

        Return:
            temp (int): the temperature of device, None means failed to get the data.
        """
        return None

    def get_device_power(self, idx):
        """Get the realtime power of device, unit: watt.

        Args:
            idx (int): device index.

        Return:
            power (int): the realtime power of device, None means failed to get the data.
        """
        return None

    def get_device_power_limit(self, idx):
        """Get the power management limit of device, unit: watt.

        Args:
            idx (int): device index.

        Return:
            limit (int): the power management limit of device, None means failed to get the data.
        """
        return None

    def get_device_memory(self, idx):
        """Get the memory information of device, unit: byte.

        Args:
            idx (int): device index.

        Return:
            used (int): the used device memory in bytes, None means failed to get the data.
            total (int): the total device memory in bytes, None means failed to get the data.
        """
        return None, None

    def get_device_row_remapped_info(self, idx):
        """Get the row remapped information of device.

        Args:
            idx (int): device index.

        Return:
            remapped_metrics (dict): the row remapped information, None means failed to get the data.
        """
        return None

    def get_device_ecc_error(self, idx):
        """Get the ecc error information of device.

        Args:
            idx (int): device index.

        Return:
            corrected_ecc (int): the count of single bit ecc error, None means failed to get the data.
            uncorrected_ecc (int): the count of double bit ecc error, None means failed to get the data.
        """
        return None, None


class NvidiaDeviceManager(DeviceManager):
"""Device management module for Nvidia."""
def __init__(self):
"""Constructor."""
nvml.nvmlInit()
super().__init__()

self._device_handlers = list()
for i in range(self._device_count):
self._device_handlers.append(nvml.nvmlDeviceGetHandleByIndex(i))

def __del__(self):
"""Destructor."""
nvml.nvmlShutdown()

def get_device_count(self):
"""Get the compute capability of device.
"""Get the number of device.

Return:
count (int): count of device.
Expand Down Expand Up @@ -79,7 +193,7 @@ def get_device_power(self, idx):
idx (int): device index.

Return:
temp (float): the realtime power of device, None means failed to get the data.
temp (int): the realtime power of device, None means failed to get the data.
"""
try:
power = nvml.nvmlDeviceGetPowerUsage(self._device_handlers[idx])
Expand All @@ -95,7 +209,7 @@ def get_device_power_limit(self, idx):
idx (int): device index.

Return:
temp (float): the power management limit of device, None means failed to get the data.
temp (int): the power management limit of device, None means failed to get the data.
"""
try:
powerlimit = nvml.nvmlDeviceGetPowerManagementLimit(self._device_handlers[idx])
Expand All @@ -111,8 +225,8 @@ def get_device_memory(self, idx):
idx (int): device index.

Return:
used (float): the used device memory, None means failed to get the data.
total (float): the total device memory, None means failed to get the data.
used (int): the used device memory in bytes, None means failed to get the data.
total (int): the total device memory in bytes, None means failed to get the data.
"""
try:
mem = nvml.nvmlDeviceGetMemoryInfo(self._device_handlers[idx])
Expand Down Expand Up @@ -208,4 +322,115 @@ def get_device_ecc_error(self, idx):
return corrected_ecc, uncorrected_ecc


device_manager = DeviceManager()
class AmdDeviceManager(DeviceManager):
    """Device management module for AMD, backed by the ``rocml`` bindings.

    Temperature, power limit and ECC error queries return None because no
    ``rocml`` call is wired up for them here.
    """
    def __init__(self):
        """Constructor: initialize the SMI library, then run the base constructor."""
        rocml.smi_initialize()
        super().__init__()

    def __del__(self):
        """Destructor: shut the SMI library down."""
        rocml.smi_shutdown()

    def get_device_count(self):
        """Get the number of devices.

        Return:
            count (int): count of devices as reported by rocml.
        """
        return rocml.smi_get_device_count()

    def get_device_utilization(self, idx):
        """Get the utilization of device.

        Args:
            idx (int): device index.

        Return:
            util (int): the utilization of device, None means failed to get the data.
        """
        try:
            return rocml.smi_get_device_utilization(idx)
        except Exception as exc:
            logger.error('Get device utilization failed: {}'.format(str(exc)))
        return None

    def get_device_temperature(self, idx):
        """Get the temperature of device, unit: celsius.

        Args:
            idx (int): device index.

        Return:
            temp (int): always None for now, see the comment below.
        """
        # Currently no API provided in rocml.
        return None

    def get_device_power(self, idx):
        """Get the realtime power of device, unit: watt.

        Args:
            idx (int): device index.

        Return:
            power (int): the realtime power of device, None means failed to get the data.
        """
        try:
            raw_power = rocml.smi_get_device_average_power(idx)
        except Exception as exc:
            logger.error('Get device power failed: {}'.format(str(exc)))
            return None
        # NOTE(review): the division by 1000 presumes rocml reports milliwatts;
        # confirm against the pyrsmi documentation.
        scaled = int(raw_power) / 1000
        return int(scaled)

    def get_device_power_limit(self, idx):
        """Get the power management limit of device, unit: watt.

        Args:
            idx (int): device index.

        Return:
            limit (int): always None for now, see the comment below.
        """
        # Currently no API provided in rocml.
        return None

    def get_device_memory(self, idx):
        """Get the memory information of device, unit: byte.

        Args:
            idx (int): device index.

        Return:
            used (int): the used device memory in bytes, None means failed to get the data.
            total (int): the total device memory in bytes, None means failed to get the data.
        """
        try:
            usage = (rocml.smi_get_device_memory_used(idx), rocml.smi_get_device_memory_total(idx))
        except Exception as exc:
            logger.error('Get device memory failed: {}'.format(str(exc)))
            usage = (None, None)
        return usage

    def get_device_ecc_error(self, idx):
        """Get the ecc error information of device.

        Args:
            idx (int): device index.

        Return:
            corrected_ecc (int): the count of single bit ecc error, always None for now.
            uncorrected_ecc (int): the count of double bit ecc error, always None for now.
        """
        # Currently no API provided in rocml.
        return None, None


# Start from the vendor-neutral manager, then swap in the vendor-specific
# implementation when the detected GPU vendor matches one of the known names.
device_manager: Optional[DeviceManager] = DeviceManager()
if gpu.vendor in ('nvidia', 'nvidia-graphics'):
    device_manager = NvidiaDeviceManager()
elif gpu.vendor in ('amd', 'amd-graphics'):
    device_manager = AmdDeviceManager()
Loading