From 4f7842162a132d4fc9fb308fe54ea9eaf5165a8b Mon Sep 17 00:00:00 2001 From: Seth Hoenig Date: Tue, 20 Aug 2024 20:21:59 +0000 Subject: [PATCH] wip: avoid resources of migs --- nvml/client.go | 45 +++++++++++++++++++++++++++++++++----------- nvml/driver_linux.go | 19 +++++++++++++------ nvml/shared.go | 10 +++++++++- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/nvml/client.go b/nvml/client.go index 35f59c3..fefd399 100644 --- a/nvml/client.go +++ b/nvml/client.go @@ -4,7 +4,9 @@ package nvml import ( + "cmp" "fmt" + "slices" "github.com/shoenig/netlog" ) @@ -112,17 +114,22 @@ func (c *nvmlClient) GetFingerprintData() (*FingerprintData, error) { return nil, fmt.Errorf("nvidia nvml ListDeviceUUIDs() error: %v\n", err) } - allNvidiaGPUResources := make([]*FingerprintDeviceData, len(deviceUUIDs)) + allNvidiaGPUResources := make([]*FingerprintDeviceData, 0, len(deviceUUIDs)) - for i, element := range deviceUUIDs { - nl.Info("GFP", "i", i, "element", element) + for uuid, mode := range deviceUUIDs { + nl.Info("GFP", "uuid", uuid, "mode", mode) - deviceInfo, err := c.driver.DeviceInfoByUUID(element) + // do not care about phsyical parents of MIGs + if mode == parent { + continue + } + + deviceInfo, err := c.driver.DeviceInfoByUUID(uuid) if err != nil { return nil, fmt.Errorf("nvidia nvml DeviceInfoByUUID() error: %v\n", err) } - allNvidiaGPUResources[i] = &FingerprintDeviceData{ + allNvidiaGPUResources = append(allNvidiaGPUResources, &FingerprintDeviceData{ DeviceData: &DeviceData{ DeviceName: deviceInfo.Name, UUID: deviceInfo.UUID, @@ -136,8 +143,13 @@ func (c *nvmlClient) GetFingerprintData() (*FingerprintData, error) { DisplayState: deviceInfo.DisplayState, PersistenceMode: deviceInfo.PersistenceMode, PCIBusID: deviceInfo.PCIBusID, - } + }) + + slices.SortFunc(allNvidiaGPUResources, func(a, b *FingerprintDeviceData) int { + return cmp.Compare(a.DeviceData.UUID, b.DeviceData.UUID) + }) } + return &FingerprintData{ Devices: allNvidiaGPUResources, DriverVersion: driverVersion, @@ -170,15 +182,22 @@ func (c *nvmlClient) GetStatsData() ([]*StatsData, error) { return nil, fmt.Errorf("nvidia nvml ListDeviceUUIDs() error: %v\n", err) } - allNvidiaGPUStats := make([]*StatsData, len(deviceUUIDs)) + allNvidiaGPUStats := make([]*StatsData, 0, len(deviceUUIDs)) - for i, element := range deviceUUIDs { - deviceInfo, deviceStatus, err := c.driver.DeviceInfoAndStatusByUUID(element) + for uuid, mode := range deviceUUIDs { + + // MIG devices have no stats. + if mode == mig || mode == parent { + continue + } + + deviceInfo, deviceStatus, err := c.driver.DeviceInfoAndStatusByUUID(uuid) if err != nil { + nl.Error("A DeviceInfoAndStatusByUUID failed", "uuid", uuid, "mode", mode) return nil, fmt.Errorf("nvidia nvml DeviceInfoAndStatusByUUID() error: %v\n", err) } - allNvidiaGPUStats[i] = &StatsData{ + allNvidiaGPUStats = append(allNvidiaGPUStats, &StatsData{ DeviceData: &DeviceData{ DeviceName: deviceInfo.Name, UUID: deviceInfo.UUID, @@ -197,7 +216,11 @@ func (c *nvmlClient) GetStatsData() ([]*StatsData, error) { ECCErrorsL1Cache: deviceStatus.ECCErrorsL1Cache, ECCErrorsL2Cache: deviceStatus.ECCErrorsL2Cache, ECCErrorsDevice: deviceStatus.ECCErrorsDevice, - } + }) + + slices.SortFunc(allNvidiaGPUStats, func(a, b *StatsData) int { + return cmp.Compare(a.DeviceData.UUID, b.DeviceData.UUID) + }) } return allNvidiaGPUStats, nil } diff --git a/nvml/driver_linux.go b/nvml/driver_linux.go index 8673b41..2a650f3 100644 --- a/nvml/driver_linux.go +++ b/nvml/driver_linux.go @@ -40,15 +40,17 @@ func (n *nvmlDriver) SystemDriverVersion() (string, error) { return version, nil } -// List all compute device UUIDs in the system, includes MIG devices -// but excludes their "parent". -func (n *nvmlDriver) ListDeviceUUIDs() ([]string, error) { +// List all compute device UUIDs in the system. +// Includes all instances, including normal GPUs, MIGs, and their physical parents. +// Each UUID is associated with a mode indication which type it is. +func (n *nvmlDriver) ListDeviceUUIDs() (map[string]mode, error) { count, code := nvml.DeviceGetCount() if code != nvml.SUCCESS { return nil, decode("failed to get device count", code) } - var uuids []string + uuids := make(map[string]mode) + for i := 0; i < int(count); i++ { device, code := nvml.DeviceGetHandleByIndex(int(i)) if code != nvml.SUCCESS { @@ -66,7 +68,7 @@ func (n *nvmlDriver) ListDeviceUUIDs() ([]string, error) { return nil, decode("failed to get device %d uuid", code) } - uuids = append(uuids, uuid) + uuids[uuid] = normal continue } if code != nvml.SUCCESS { @@ -78,6 +80,11 @@ func (n *nvmlDriver) ListDeviceUUIDs() ([]string, error) { return nil, decode("failed to get device MIG device count", code) } + uuid, code := nvml.DeviceGetUUID(device) + if code == nvml.SUCCESS { + uuids[uuid] = parent + } + for j := 0; j < int(migCount); j++ { migDevice, code := nvml.DeviceGetMigDeviceHandleByIndex(device, int(j)) if code == nvml.ERROR_NOT_FOUND || code == nvml.ERROR_INVALID_ARGUMENT { @@ -91,7 +98,7 @@ func (n *nvmlDriver) ListDeviceUUIDs() ([]string, error) { if code != nvml.SUCCESS { return nil, decode(fmt.Sprintf("failed to get mig device uuid %d", j), code) } - uuids = append(uuids, uuid) + uuids[uuid] = mig } } diff --git a/nvml/shared.go b/nvml/shared.go index 2d675f1..17596a2 100644 --- a/nvml/shared.go +++ b/nvml/shared.go @@ -10,6 +10,14 @@ var ( UnavailableLib = errors.New("could not load NVML library") ) +type mode int + +const ( + normal mode = iota + parent + mig +) + // nvmlDriver implements NvmlDriver // Users are required to call Initialize method before using any other methods type nvmlDriver struct{} @@ -19,7 +27,7 @@ type NvmlDriver interface { Initialize() error Shutdown() error SystemDriverVersion() (string, error) - ListDeviceUUIDs() ([]string, error) + ListDeviceUUIDs() (map[string]mode, error) DeviceInfoByUUID(string) (*DeviceInfo, error) DeviceInfoAndStatusByUUID(string) (*DeviceInfo, *DeviceStatus, error) }