Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support mig devices #53

Merged
merged 9 commits into from
Aug 22, 2024
Merged
3 changes: 2 additions & 1 deletion .go-version
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
1.22.4
1.22.6

3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ RPC. GPUs can be excluded from fingerprinting by setting the `ignored_gpu_ids`
field (see below). Plugin sends statistics for fingerprinted devices every
`stats_period` period.

The plugin will be able to distinguish whether the GPU has [`Multi-Instance GPU (MIG)`](https://www.nvidia.com/en-us/technologies/multi-instance-gpu/) enabled.
When enabled all instances will be fingerprinted as individual GPUs that can be addressed accordingly.
shoenig marked this conversation as resolved.
Show resolved Hide resolved

## Config

The plugin is configured in the Nomad client's
Expand Down
2 changes: 1 addition & 1 deletion device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package nvidia
import (
"testing"

hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad-device-nvidia/nvml"
"github.com/hashicorp/nomad/plugins/device"
"github.com/shoenig/test/must"
Expand Down
64 changes: 44 additions & 20 deletions nvml/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package nvml

import (
"cmp"
"fmt"
"slices"
)

// DeviceData represents common fields for Nvidia device
Expand Down Expand Up @@ -95,28 +97,32 @@ func (c *nvmlClient) GetFingerprintData() (*FingerprintData, error) {
*/

// Assumed that this method is called with receiver retrieved from
// NewNvmlClient
// because this method handles initialization of NVML library
// NewNvmlClient because this method handles initialization of NVML library

driverVersion, err := c.driver.SystemDriverVersion()
if err != nil {
return nil, fmt.Errorf("nvidia nvml SystemDriverVersion() error: %v\n", err)
}

numDevices, err := c.driver.DeviceCount()
deviceUUIDs, err := c.driver.ListDeviceUUIDs()
if err != nil {
return nil, fmt.Errorf("nvidia nvml DeviceCount() error: %v\n", err)
return nil, fmt.Errorf("nvidia nvml ListDeviceUUIDs() error: %v\n", err)
}

allNvidiaGPUResources := make([]*FingerprintDeviceData, numDevices)
allNvidiaGPUResources := make([]*FingerprintDeviceData, 0, len(deviceUUIDs))

for i := 0; i < int(numDevices); i++ {
deviceInfo, err := c.driver.DeviceInfoByIndex(uint(i))
for uuid, mode := range deviceUUIDs {
// do not care about phsyical parents of MIGs
if mode == parent {
continue
}

deviceInfo, err := c.driver.DeviceInfoByUUID(uuid)
if err != nil {
return nil, fmt.Errorf("nvidia nvml DeviceInfoByIndex() error: %v\n", err)
return nil, fmt.Errorf("nvidia nvml DeviceInfoByUUID() error: %v\n", err)
}

allNvidiaGPUResources[i] = &FingerprintDeviceData{
allNvidiaGPUResources = append(allNvidiaGPUResources, &FingerprintDeviceData{
DeviceData: &DeviceData{
DeviceName: deviceInfo.Name,
UUID: deviceInfo.UUID,
Expand All @@ -130,8 +136,13 @@ func (c *nvmlClient) GetFingerprintData() (*FingerprintData, error) {
DisplayState: deviceInfo.DisplayState,
PersistenceMode: deviceInfo.PersistenceMode,
PCIBusID: deviceInfo.PCIBusID,
}
})

slices.SortFunc(allNvidiaGPUResources, func(a, b *FingerprintDeviceData) int {
return cmp.Compare(a.DeviceData.UUID, b.DeviceData.UUID)
})
}

return &FingerprintData{
Devices: allNvidiaGPUResources,
DriverVersion: driverVersion,
Expand All @@ -156,23 +167,32 @@ func (c *nvmlClient) GetStatsData() ([]*StatsData, error) {
*/

// Assumed that this method is called with receiver retrieved from
// NewNvmlClient
// because this method handles initialization of NVML library
// NewNvmlClient because this method handles initialization of NVML library

numDevices, err := c.driver.DeviceCount()
deviceUUIDs, err := c.driver.ListDeviceUUIDs()
if err != nil {
return nil, fmt.Errorf("nvidia nvml DeviceCount() error: %v\n", err)
return nil, fmt.Errorf("nvidia nvml ListDeviceUUIDs() error: %v\n", err)
}

allNvidiaGPUStats := make([]*StatsData, numDevices)
allNvidiaGPUStats := make([]*StatsData, 0, len(deviceUUIDs))

for i := 0; i < int(numDevices); i++ {
deviceInfo, deviceStatus, err := c.driver.DeviceInfoAndStatusByIndex(uint(i))
for uuid, mode := range deviceUUIDs {

// A30/A100 MIG devices have no stats.
//
// https://docs.nvidia.com/datacenter/tesla/mig-user-guide/#telemetry
//
// Is this fixed on H100 or later? Maybe?
if mode == mig || mode == parent {
continue
}
Comment on lines +181 to +188
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be safe to attempt to call DeviceInfoAndStatusByUUID and log/continue on error? I'd just hate for this to be something NVidia fixes in a driver update and then our plugin languishes for months without support because we don't even try.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we log we end up spamming a log line for each MIG device for each period. In the sad case that's 7 MIG devices for 8 GPUs every 30 seconds which is a lot of log spam for the hope Nvidia will fix their stuff.


deviceInfo, deviceStatus, err := c.driver.DeviceInfoAndStatusByUUID(uuid)
if err != nil {
return nil, fmt.Errorf("nvidia nvml DeviceInfoAndStatusByIndex() error: %v\n", err)
return nil, fmt.Errorf("nvidia nvml DeviceInfoAndStatusByUUID() error: %v\n", err)
}

allNvidiaGPUStats[i] = &StatsData{
allNvidiaGPUStats = append(allNvidiaGPUStats, &StatsData{
DeviceData: &DeviceData{
DeviceName: deviceInfo.Name,
UUID: deviceInfo.UUID,
Expand All @@ -191,7 +211,11 @@ func (c *nvmlClient) GetStatsData() ([]*StatsData, error) {
ECCErrorsL1Cache: deviceStatus.ECCErrorsL1Cache,
ECCErrorsL2Cache: deviceStatus.ECCErrorsL2Cache,
ECCErrorsDevice: deviceStatus.ECCErrorsDevice,
}
})

slices.SortFunc(allNvidiaGPUStats, func(a, b *StatsData) int {
return cmp.Compare(a.DeviceData.UUID, b.DeviceData.UUID)
})
}
return allNvidiaGPUStats, nil
}
Loading