Skip to content

Commit

Permalink
manual testing
Browse files Browse the repository at this point in the history
  • Loading branch information
letmutx committed May 21, 2022
1 parent d610e22 commit 2fa1764
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 38 deletions.
53 changes: 24 additions & 29 deletions device.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@ var (
Factory: func(ctx context.Context, l log.Logger) interface{} { return NewPlugin(ctx, l) },
}

// configSpec is the specification of the schema for this plugin's config.
// this is used to validate the HCL for the plugin provided
// as part of the client config:
// https://www.nomadproject.io/docs/configuration/plugin.html
// options are here:
// https://github.com/hashicorp/nomad/blob/v0.10.0/plugins/shared/hclspec/hcl_spec.proto
// configSpec is the specification of the plugin's configuration
configSpec = hclspec.NewObject(map[string]*hclspec.Spec{
"enabled": hclspec.NewDefault(
hclspec.NewAttr("enabled", "bool", false),
Expand All @@ -72,50 +65,55 @@ var (
})
)

type NvidiaDevicePlugin = device.DevicePlugin

// Config contains configuration information for the plugin.
type Config struct {
Vgpus int `codec:"vgpus"`
}

// NvidiaVgpuDevice contains a skeleton for most of the implementation of a
// device plugin.
type NvidiaVgpuDevice struct {
*nvidia.NvidiaDevice
// NvidiaVgpuPlugin is a wrapper for NvidiaDevicePlugin.
// It handles fingerprinting, stats and allocation of virtual devices.
type NvidiaVgpuPlugin struct {
NvidiaDevicePlugin
vgpus int

devices map[string]struct{}
deviceLock sync.RWMutex

log log.Logger
}

// NewPlugin returns a device plugin, used primarily by the main wrapper
//
// Plugin configuration isn't available yet, so there will typically be
// a limit to the initialization that can be performed at this point.
func NewPlugin(ctx context.Context, log log.Logger) *NvidiaVgpuDevice {
return &NvidiaVgpuDevice{
NvidiaDevice: nvidia.NewNvidiaDevice(ctx, log),
devices: map[string]struct{}{},
func NewPlugin(ctx context.Context, log log.Logger) *NvidiaVgpuPlugin {
return &NvidiaVgpuPlugin{
NvidiaDevicePlugin: nvidia.NewNvidiaDevice(ctx, log),
devices: map[string]struct{}{},
log: log,
}
}

// PluginInfo returns information describing the plugin.
//
// This is called during Nomad client startup, while discovering and loading
// plugins.
func (d *NvidiaVgpuDevice) PluginInfo() (*base.PluginInfoResponse, error) {
func (d *NvidiaVgpuPlugin) PluginInfo() (*base.PluginInfoResponse, error) {
return pluginInfo, nil
}

// ConfigSchema returns the configuration schema for the plugin.
//
// This is called during Nomad client startup, immediately before parsing
// plugin config and calling SetConfig
func (d *NvidiaVgpuDevice) ConfigSchema() (*hclspec.Spec, error) {
func (d *NvidiaVgpuPlugin) ConfigSchema() (*hclspec.Spec, error) {
return configSpec, nil
}

// SetConfig is called by the client to pass the configuration for the plugin.
func (d *NvidiaVgpuDevice) SetConfig(c *base.Config) (err error) {
func (d *NvidiaVgpuPlugin) SetConfig(c *base.Config) (err error) {
var config Config

// decode the plugin config
Expand All @@ -124,23 +122,20 @@ func (d *NvidiaVgpuDevice) SetConfig(c *base.Config) (err error) {
}

if config.Vgpus <= 0 {
return fmt.Errorf("invalid value for vgpus %q: %v", config.Vgpus, errors.New("must be >= 1"))
}

if err = d.NvidiaDevice.SetConfig(c); err != nil {
return err
return fmt.Errorf("invalid value for vgpus %q: %w", config.Vgpus, errors.New("must be >= 1"))
}
d.vgpus = config.Vgpus

return nil
return d.NvidiaDevicePlugin.SetConfig(c)
}

// Fingerprint streams detected devices.
// Messages should be emitted to the returned channel when there are changes
// to the devices or their health.
func (d *NvidiaVgpuDevice) Fingerprint(ctx context.Context) (<-chan *device.FingerprintResponse, error) {
func (d *NvidiaVgpuPlugin) Fingerprint(ctx context.Context) (<-chan *device.FingerprintResponse, error) {
// Fingerprint returns a channel. The recommended way of organizing a plugin
// is to pass that into a long-running goroutine and return the channel immediately.
nvOut, err := d.NvidiaDevice.Fingerprint(ctx)
nvOut, err := d.NvidiaDevicePlugin.Fingerprint(ctx)
if err != nil {
return nil, err
}
Expand All @@ -151,11 +146,11 @@ func (d *NvidiaVgpuDevice) Fingerprint(ctx context.Context) (<-chan *device.Fing

// Stats streams statistics for the detected devices.
// Messages should be emitted to the returned channel on the specified interval.
func (d *NvidiaVgpuDevice) Stats(ctx context.Context, interval time.Duration) (<-chan *device.StatsResponse, error) {
func (d *NvidiaVgpuPlugin) Stats(ctx context.Context, interval time.Duration) (<-chan *device.StatsResponse, error) {
// Similar to Fingerprint, Stats returns a channel. The recommended way of
// organizing a plugin is to pass that into a long-running goroutine and
// return the channel immediately.
nvOut, err := d.NvidiaDevice.Stats(ctx, interval)
nvOut, err := d.NvidiaDevicePlugin.Stats(ctx, interval)
if err != nil {
return nil, err
}
Expand All @@ -175,7 +170,7 @@ func (e *reservationError) Error() string {
// Reserve returns information to the task driver on how to mount the given devices.
// It may also perform any device-specific orchestration necessary to prepare the device
// for use. This is called in a pre-start hook on the client, before starting the workload.
func (d *NvidiaVgpuDevice) Reserve(deviceIDs []string) (*device.ContainerReservation, error) {
func (d *NvidiaVgpuPlugin) Reserve(deviceIDs []string) (*device.ContainerReservation, error) {
if len(deviceIDs) == 0 {
return &device.ContainerReservation{}, nil
}
Expand Down
3 changes: 2 additions & 1 deletion examples/config.hcl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
config {
some_required_boolean = true
enabled = true
vgpus = 8
}
12 changes: 6 additions & 6 deletions fingerprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

// doFingerprint is the long-running goroutine that detects device changes
func (d *NvidiaVgpuDevice) doFingerprint(ctx context.Context, nvDevices <-chan *device.FingerprintResponse, virtDevices chan *device.FingerprintResponse) {
func (d *NvidiaVgpuPlugin) doFingerprint(ctx context.Context, nvDevices <-chan *device.FingerprintResponse, virtDevices chan *device.FingerprintResponse) {
defer close(virtDevices)

for {
Expand All @@ -21,15 +21,15 @@ func (d *NvidiaVgpuDevice) doFingerprint(ctx context.Context, nvDevices <-chan *
}
}

func (d *NvidiaVgpuDevice) nvDeviceToVirtDevices(ctx context.Context, nvFpr *device.FingerprintResponse) *device.FingerprintResponse {
func (d *NvidiaVgpuPlugin) nvDeviceToVirtDevices(ctx context.Context, nvFpr *device.FingerprintResponse) *device.FingerprintResponse {
if nvFpr.Error != nil {
return nvFpr
}
var fpr device.FingerprintResponse

d.deviceLock.Lock()
defer d.deviceLock.Unlock()

var devices []*device.DeviceGroup

for _, nvDeviceGroup := range nvFpr.Devices {
devGroup := &device.DeviceGroup{
Name: nvDeviceGroup.Name,
Expand All @@ -51,8 +51,8 @@ func (d *NvidiaVgpuDevice) nvDeviceToVirtDevices(ctx context.Context, nvFpr *dev
}
}

fpr.Devices = append(fpr.Devices, devGroup)
devices = append(devices, devGroup)
}

return &fpr
return device.NewFingerprint(devices...)
}
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878/go.mod h1:3AMJUQh
github.com/armon/go-metrics v0.3.4 h1:Xqf+7f2Vhl9tsqDYmXhnXInUdcrtgpRNpIA15/uldSc=
github.com/armon/go-metrics v0.3.4/go.mod h1:4O98XIr/9W0sxpJ8UaYkvjk10Iff7SnFrb4QAOwNTFc=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI=
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/aws/aws-sdk-go v1.15.11/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0=
github.com/aws/aws-sdk-go v1.15.78/go.mod h1:E3/ieXAlvM0XWO57iftYVDLLvQ824smPP3ATZkfNZeM=
Expand All @@ -106,6 +107,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d h1:xDfNPAt8lFiC1UJrqV3uuy861HCTo708pDMbjHHdCas=
github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d/go.mod h1:6QX/PXZ00z/TKoufEY6K/a0k6AhaJrQKdFe6OfVXsa4=
github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
Expand Down Expand Up @@ -522,6 +524,7 @@ github.com/miekg/dns v1.1.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3N
github.com/miekg/dns v1.1.26 h1:gPxPSwALAeHJSjarOs00QjVdV9QoBvc1D2ujQUr5BzU=
github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
github.com/mitchellh/cli v1.1.0 h1:tEElEatulEHDeedTxwckzyYMA5c86fbmNIUL1hBIiTg=
github.com/mitchellh/cli v1.1.0/go.mod h1:xcISNoH86gajksDmfB23e/pu+B+GeFRMYmoHXxx3xhI=
github.com/mitchellh/colorstring v0.0.0-20150917214807-8631ce90f286/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw=
github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ=
Expand Down Expand Up @@ -616,6 +619,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
github.com/posener/complete v1.2.3 h1:NP0eAhjcjImqslEwo/1hq7gpajME0fTLTezBKDqfXqo=
github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSgv7Sy7s/s=
github.com/prometheus/client_golang v0.0.0-20180209125602-c332b6f63c06/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
Expand Down
4 changes: 2 additions & 2 deletions stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

// doStats is the long-running goroutine that streams device statistics
func (d *NvidiaVgpuDevice) doStats(ctx context.Context, nvStats <-chan *device.StatsResponse, stats chan<- *device.StatsResponse) {
func (d *NvidiaVgpuPlugin) doStats(ctx context.Context, nvStats <-chan *device.StatsResponse, stats chan<- *device.StatsResponse) {
defer close(stats)

for {
Expand All @@ -21,7 +21,7 @@ func (d *NvidiaVgpuDevice) doStats(ctx context.Context, nvStats <-chan *device.S
}
}

func (d *NvidiaVgpuDevice) nvStatsToVirtstats(nvStats *device.StatsResponse) *device.StatsResponse {
func (d *NvidiaVgpuPlugin) nvStatsToVirtstats(nvStats *device.StatsResponse) *device.StatsResponse {
if nvStats.Error != nil {
return nvStats
}
Expand Down

0 comments on commit 2fa1764

Please sign in to comment.