diff --git a/device.go b/device.go index c0226ef..bec5bc0 100644 --- a/device.go +++ b/device.go @@ -45,13 +45,6 @@ var ( Factory: func(ctx context.Context, l log.Logger) interface{} { return NewPlugin(ctx, l) }, } - // configSpec is the specification of the schema for this plugin's config. - // this is used to validate the HCL for the plugin provided - // as part of the client config: - // https://www.nomadproject.io/docs/configuration/plugin.html - // options are here: - // https://github.com/hashicorp/nomad/blob/v0.10.0/plugins/shared/hclspec/hcl_spec.proto - // configSpec is the specification of the plugin's configuration configSpec = hclspec.NewObject(map[string]*hclspec.Spec{ "enabled": hclspec.NewDefault( hclspec.NewAttr("enabled", "bool", false), @@ -72,29 +65,34 @@ var ( }) ) +type NvidiaDevicePlugin = device.DevicePlugin + // Config contains configuration information for the plugin. type Config struct { Vgpus int `codec:"vgpus"` } -// NvidiaVgpuDevice contains a skeleton for most of the implementation of a -// device plugin. -type NvidiaVgpuDevice struct { - *nvidia.NvidiaDevice +// NvidiaVgpuPlugin is a wrapper for NvidiaDevicePlugin +// It handles fingerprinting, stats and allocation of virtual devices +type NvidiaVgpuPlugin struct { + NvidiaDevicePlugin vgpus int devices map[string]struct{} deviceLock sync.RWMutex + + log log.Logger } // NewPlugin returns a device plugin, used primarily by the main wrapper // // Plugin configuration isn't available yet, so there will typically be // a limit to the initialization that can be performed at this point. -func NewPlugin(ctx context.Context, log log.Logger) *NvidiaVgpuDevice { - return &NvidiaVgpuDevice{ - NvidiaDevice: nvidia.NewNvidiaDevice(ctx, log), - devices: map[string]struct{}{}, +func NewPlugin(ctx context.Context, log log.Logger) *NvidiaVgpuPlugin { + return &NvidiaVgpuPlugin{ + NvidiaDevicePlugin: nvidia.NewNvidiaDevice(ctx, log), + devices: map[string]struct{}{}, + log: log, } } @@ -102,7 +100,7 @@ func NewPlugin(ctx context.Context, log log.Logger) *NvidiaVgpuDevice { // // This is called during Nomad client startup, while discovering and loading // plugins. -func (d *NvidiaVgpuDevice) PluginInfo() (*base.PluginInfoResponse, error) { +func (d *NvidiaVgpuPlugin) PluginInfo() (*base.PluginInfoResponse, error) { return pluginInfo, nil } @@ -110,12 +108,12 @@ func (d *NvidiaVgpuDevice) PluginInfo() (*base.PluginInfoResponse, error) { // // This is called during Nomad client startup, immediately before parsing // plugin config and calling SetConfig -func (d *NvidiaVgpuDevice) ConfigSchema() (*hclspec.Spec, error) { +func (d *NvidiaVgpuPlugin) ConfigSchema() (*hclspec.Spec, error) { return configSpec, nil } // SetConfig is called by the client to pass the configuration for the plugin. -func (d *NvidiaVgpuDevice) SetConfig(c *base.Config) (err error) { +func (d *NvidiaVgpuPlugin) SetConfig(c *base.Config) (err error) { var config Config // decode the plugin config @@ -124,23 +122,20 @@ func (d *NvidiaVgpuDevice) SetConfig(c *base.Config) (err error) { } if config.Vgpus <= 0 { - return fmt.Errorf("invalid value for vgpus %q: %v", config.Vgpus, errors.New("must be >= 1")) - } - - if err = d.NvidiaDevice.SetConfig(c); err != nil { - return err + return fmt.Errorf("invalid value for vgpus %q: %w", config.Vgpus, errors.New("must be >= 1")) } + d.vgpus = config.Vgpus - return nil + return d.NvidiaDevicePlugin.SetConfig(c) } // Fingerprint streams detected devices. // Messages should be emitted to the returned channel when there are changes // to the devices or their health. -func (d *NvidiaVgpuDevice) Fingerprint(ctx context.Context) (<-chan *device.FingerprintResponse, error) { +func (d *NvidiaVgpuPlugin) Fingerprint(ctx context.Context) (<-chan *device.FingerprintResponse, error) { // Fingerprint returns a channel. The recommended way of organizing a plugin // is to pass that into a long-running goroutine and return the channel immediately. - nvOut, err := d.NvidiaDevice.Fingerprint(ctx) + nvOut, err := d.NvidiaDevicePlugin.Fingerprint(ctx) if err != nil { return nil, err } @@ -151,11 +146,11 @@ func (d *NvidiaVgpuDevice) Fingerprint(ctx context.Context) (<-chan *device.Fing // Stats streams statistics for the detected devices. // Messages should be emitted to the returned channel on the specified interval. -func (d *NvidiaVgpuDevice) Stats(ctx context.Context, interval time.Duration) (<-chan *device.StatsResponse, error) { +func (d *NvidiaVgpuPlugin) Stats(ctx context.Context, interval time.Duration) (<-chan *device.StatsResponse, error) { // Similar to Fingerprint, Stats returns a channel. The recommended way of // organizing a plugin is to pass that into a long-running goroutine and // return the channel immediately. - nvOut, err := d.NvidiaDevice.Stats(ctx, interval) + nvOut, err := d.NvidiaDevicePlugin.Stats(ctx, interval) if err != nil { return nil, err } @@ -175,7 +170,7 @@ func (e *reservationError) Error() string { // Reserve returns information to the task driver on on how to mount the given devices. // It may also perform any device-specific orchestration necessary to prepare the device // for use. This is called in a pre-start hook on the client, before starting the workload. -func (d *NvidiaVgpuDevice) Reserve(deviceIDs []string) (*device.ContainerReservation, error) { +func (d *NvidiaVgpuPlugin) Reserve(deviceIDs []string) (*device.ContainerReservation, error) { if len(deviceIDs) == 0 { return &device.ContainerReservation{}, nil } diff --git a/examples/config.hcl b/examples/config.hcl index be88c76..8282754 100644 --- a/examples/config.hcl +++ b/examples/config.hcl @@ -1,3 +1,4 @@ config { - some_required_boolean = true + enabled = true + vgpus = 8 } diff --git a/fingerprint.go b/fingerprint.go index 27b15c9..d88d2d5 100644 --- a/fingerprint.go +++ b/fingerprint.go @@ -8,7 +8,7 @@ import ( ) // doFingerprint is the long-running goroutine that detects device changes -func (d *NvidiaVgpuDevice) doFingerprint(ctx context.Context, nvDevices <-chan *device.FingerprintResponse, virtDevices chan *device.FingerprintResponse) { +func (d *NvidiaVgpuPlugin) doFingerprint(ctx context.Context, nvDevices <-chan *device.FingerprintResponse, virtDevices chan *device.FingerprintResponse) { defer close(virtDevices) for { @@ -21,15 +21,15 @@ func (d *NvidiaVgpuDevice) doFingerprint(ctx context.Context, nvDevices <-chan * } } -func (d *NvidiaVgpuDevice) nvDeviceToVirtDevices(ctx context.Context, nvFpr *device.FingerprintResponse) *device.FingerprintResponse { +func (d *NvidiaVgpuPlugin) nvDeviceToVirtDevices(ctx context.Context, nvFpr *device.FingerprintResponse) *device.FingerprintResponse { if nvFpr.Error != nil { return nvFpr } - var fpr device.FingerprintResponse - d.deviceLock.Lock() defer d.deviceLock.Unlock() + var devices []*device.DeviceGroup + for _, nvDeviceGroup := range nvFpr.Devices { devGroup := &device.DeviceGroup{ Name: nvDeviceGroup.Name, @@ -51,8 +51,8 @@ func (d *NvidiaVgpuDevice) nvDeviceToVirtDevices(ctx context.Context, nvFpr *dev } } - fpr.Devices = append(fpr.Devices, devGroup) + devices = append(devices, devGroup) } - return &fpr + return device.NewFingerprint(devices...) } diff --git a/go.sum b/go.sum index 966f265..0319ee2 100644 --- a/go.sum +++ b/go.sum @@ -93,6 +93,7 @@ github.com/armon/go-metrics v0.0.0-20190430140413-ec5e00d3c878/go.mod h1:3AMJUQh github.com/armon/go-metrics v0.3.4 h1:Xqf+7f2Vhl9tsqDYmXhnXInUdcrtgpRNpIA15/uldSc= github.com/armon/go-metrics v0.3.4/go.mod h1:4O98XIr/9W0sxpJ8UaYkvjk10Iff7SnFrb4QAOwNTFc= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/armon/go-radix v1.0.0 h1:F4z6KzEeeQIMeLFa97iZU6vupzoecKdU5TX24SNppXI= github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/aws/aws-sdk-go v1.15.11/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= github.com/aws/aws-sdk-go v1.15.78/go.mod h1:E3/ieXAlvM0XWO57iftYVDLLvQ824smPP3ATZkfNZeM= @@ -106,6 +107,7 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d h1:xDfNPAt8lFiC1UJrqV3uuy861HCTo708pDMbjHHdCas= github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d/go.mod h1:6QX/PXZ00z/TKoufEY6K/a0k6AhaJrQKdFe6OfVXsa4= +github.com/bgentry/speakeasy v0.1.0 h1:ByYyxL9InA1OWqxJqqp2A5pYHUrCiAL6K3J+LKSsQkY= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA= github.com/blang/semver v3.1.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= @@ -522,6 +524,7 @@ github.com/miekg/dns v1.1.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3N github.com/miekg/dns v1.1.26 h1:gPxPSwALAeHJSjarOs00QjVdV9QoBvc1D2ujQUr5BzU= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/cli v1.1.0 h1:tEElEatulEHDeedTxwckzyYMA5c86fbmNIUL1hBIiTg= github.com/mitchellh/cli v1.1.0/go.mod h1:xcISNoH86gajksDmfB23e/pu+B+GeFRMYmoHXxx3xhI= github.com/mitchellh/colorstring v0.0.0-20150917214807-8631ce90f286/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= @@ -616,6 +619,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/posener/complete v1.2.3 h1:NP0eAhjcjImqslEwo/1hq7gpajME0fTLTezBKDqfXqo= github.com/posener/complete v1.2.3/go.mod h1:WZIdtGGp+qx0sLrYKtIRAruyNpv6hFCicSgv7Sy7s/s= github.com/prometheus/client_golang v0.0.0-20180209125602-c332b6f63c06/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= diff --git a/stats.go b/stats.go index c8b4e51..83a5ba3 100644 --- a/stats.go +++ b/stats.go @@ -8,7 +8,7 @@ import ( ) // doStats is the long running goroutine that streams device statistics -func (d *NvidiaVgpuDevice) doStats(ctx context.Context, nvStats <-chan *device.StatsResponse, stats chan<- *device.StatsResponse) { +func (d *NvidiaVgpuPlugin) doStats(ctx context.Context, nvStats <-chan *device.StatsResponse, stats chan<- *device.StatsResponse) { defer close(stats) for { @@ -21,7 +21,7 @@ func (d *NvidiaVgpuDevice) doStats(ctx context.Context, nvStats <-chan *device.S } } -func (d *NvidiaVgpuDevice) nvStatsToVirtstats(nvStats *device.StatsResponse) *device.StatsResponse { +func (d *NvidiaVgpuPlugin) nvStatsToVirtstats(nvStats *device.StatsResponse) *device.StatsResponse { if nvStats.Error != nil { return nvStats }