Skip to content

Commit

Permalink
wip: avoid resources of migs
Browse files Browse the repository at this point in the history
  • Loading branch information
shoenig committed Aug 20, 2024
1 parent 6c243c3 commit 4f78421
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 18 deletions.
45 changes: 34 additions & 11 deletions nvml/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package nvml

import (
"cmp"
"fmt"
"slices"

"github.com/shoenig/netlog"
)
Expand Down Expand Up @@ -112,17 +114,22 @@ func (c *nvmlClient) GetFingerprintData() (*FingerprintData, error) {
return nil, fmt.Errorf("nvidia nvml ListDeviceUUIDs() error: %v\n", err)
}

allNvidiaGPUResources := make([]*FingerprintDeviceData, len(deviceUUIDs))
allNvidiaGPUResources := make([]*FingerprintDeviceData, 0, len(deviceUUIDs))

for i, element := range deviceUUIDs {
nl.Info("GFP", "i", i, "element", element)
for uuid, mode := range deviceUUIDs {
nl.Info("GFP", "uuid", uuid, "mode", mode)

deviceInfo, err := c.driver.DeviceInfoByUUID(element)
// do not care about physical parents of MIGs
if mode == parent {
continue
}

deviceInfo, err := c.driver.DeviceInfoByUUID(uuid)
if err != nil {
return nil, fmt.Errorf("nvidia nvml DeviceInfoByUUID() error: %v\n", err)
}

allNvidiaGPUResources[i] = &FingerprintDeviceData{
allNvidiaGPUResources = append(allNvidiaGPUResources, &FingerprintDeviceData{
DeviceData: &DeviceData{
DeviceName: deviceInfo.Name,
UUID: deviceInfo.UUID,
Expand All @@ -136,8 +143,13 @@ func (c *nvmlClient) GetFingerprintData() (*FingerprintData, error) {
DisplayState: deviceInfo.DisplayState,
PersistenceMode: deviceInfo.PersistenceMode,
PCIBusID: deviceInfo.PCIBusID,
}
})

slices.SortFunc(allNvidiaGPUResources, func(a, b *FingerprintDeviceData) int {
return cmp.Compare(a.DeviceData.UUID, b.DeviceData.UUID)
})
}

return &FingerprintData{
Devices: allNvidiaGPUResources,
DriverVersion: driverVersion,
Expand Down Expand Up @@ -170,15 +182,22 @@ func (c *nvmlClient) GetStatsData() ([]*StatsData, error) {
return nil, fmt.Errorf("nvidia nvml ListDeviceUUIDs() error: %v\n", err)
}

allNvidiaGPUStats := make([]*StatsData, len(deviceUUIDs))
allNvidiaGPUStats := make([]*StatsData, 0, len(deviceUUIDs))

for i, element := range deviceUUIDs {
deviceInfo, deviceStatus, err := c.driver.DeviceInfoAndStatusByUUID(element)
for uuid, mode := range deviceUUIDs {

// MIG devices have no stats.
if mode == mig || mode == parent {
continue
}

deviceInfo, deviceStatus, err := c.driver.DeviceInfoAndStatusByUUID(uuid)
if err != nil {
nl.Error("A DeviceInfoAndStatusByUUID failed", "uuid", uuid, "mode", mode)
return nil, fmt.Errorf("nvidia nvml DeviceInfoAndStatusByUUID() error: %v\n", err)
}

allNvidiaGPUStats[i] = &StatsData{
allNvidiaGPUStats = append(allNvidiaGPUStats, &StatsData{
DeviceData: &DeviceData{
DeviceName: deviceInfo.Name,
UUID: deviceInfo.UUID,
Expand All @@ -197,7 +216,11 @@ func (c *nvmlClient) GetStatsData() ([]*StatsData, error) {
ECCErrorsL1Cache: deviceStatus.ECCErrorsL1Cache,
ECCErrorsL2Cache: deviceStatus.ECCErrorsL2Cache,
ECCErrorsDevice: deviceStatus.ECCErrorsDevice,
}
})

slices.SortFunc(allNvidiaGPUStats, func(a, b *StatsData) int {
return cmp.Compare(a.DeviceData.UUID, b.DeviceData.UUID)
})
}
return allNvidiaGPUStats, nil
}
19 changes: 13 additions & 6 deletions nvml/driver_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,17 @@ func (n *nvmlDriver) SystemDriverVersion() (string, error) {
return version, nil
}

// List all compute device UUIDs in the system, includes MIG devices
// but excludes their "parent".
func (n *nvmlDriver) ListDeviceUUIDs() ([]string, error) {
// ListDeviceUUIDs lists all compute device UUIDs in the system.
// This includes every instance: normal GPUs, MIGs, and their physical parents.
// Each UUID is associated with a mode indicating which type it is.
func (n *nvmlDriver) ListDeviceUUIDs() (map[string]mode, error) {
count, code := nvml.DeviceGetCount()
if code != nvml.SUCCESS {
return nil, decode("failed to get device count", code)
}

var uuids []string
uuids := make(map[string]mode)

for i := 0; i < int(count); i++ {
device, code := nvml.DeviceGetHandleByIndex(int(i))
if code != nvml.SUCCESS {
Expand All @@ -66,7 +68,7 @@ func (n *nvmlDriver) ListDeviceUUIDs() ([]string, error) {
return nil, decode("failed to get device %d uuid", code)
}

uuids = append(uuids, uuid)
uuids[uuid] = normal
continue
}
if code != nvml.SUCCESS {
Expand All @@ -78,6 +80,11 @@ func (n *nvmlDriver) ListDeviceUUIDs() ([]string, error) {
return nil, decode("failed to get device MIG device count", code)
}

uuid, code := nvml.DeviceGetUUID(device)
if code == nvml.SUCCESS {
uuids[uuid] = parent
}

for j := 0; j < int(migCount); j++ {
migDevice, code := nvml.DeviceGetMigDeviceHandleByIndex(device, int(j))
if code == nvml.ERROR_NOT_FOUND || code == nvml.ERROR_INVALID_ARGUMENT {
Expand All @@ -91,7 +98,7 @@ func (n *nvmlDriver) ListDeviceUUIDs() ([]string, error) {
if code != nvml.SUCCESS {
return nil, decode(fmt.Sprintf("failed to get mig device uuid %d", j), code)
}
uuids = append(uuids, uuid)
uuids[uuid] = mig
}
}

Expand Down
10 changes: 9 additions & 1 deletion nvml/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ var (
UnavailableLib = errors.New("could not load NVML library")
)

// mode classifies a device UUID reported by ListDeviceUUIDs, which returns
// a map from UUID to mode.
type mode int

// Possible device modes. Callers use these to decide which UUIDs to query:
// fingerprinting skips parent entries, and stats collection skips both
// parent and mig entries.
const (
// normal marks a regular GPU (neither a MIG device nor a MIG parent).
normal mode = iota
// parent marks the physical device whose MIG devices are enumerated.
parent
// mig marks a MIG device handle obtained from a parent device.
mig
)

// nvmlDriver implements NvmlDriver
// Users are required to call Initialize method before using any other methods
type nvmlDriver struct{}
Expand All @@ -19,7 +27,7 @@ type NvmlDriver interface {
Initialize() error
Shutdown() error
SystemDriverVersion() (string, error)
ListDeviceUUIDs() ([]string, error)
ListDeviceUUIDs() (map[string]mode, error)
DeviceInfoByUUID(string) (*DeviceInfo, error)
DeviceInfoAndStatusByUUID(string) (*DeviceInfo, *DeviceStatus, error)
}
Expand Down

0 comments on commit 4f78421

Please sign in to comment.