From df7e4fae47d9a220ec98f602ef87d0388649ee41 Mon Sep 17 00:00:00 2001 From: Christoph Ostarek Date: Fri, 8 Sep 2023 14:47:05 +0200 Subject: [PATCH] pillar: introduce usbmanager a new microservice that dynamically passes through usb devices to qemu vms Signed-off-by: Christoph Ostarek --- pkg/pillar/cmd/domainmgr/domainmgr.go | 1 + pkg/pillar/cmd/usbmanager/ruleengine.go | 70 +++++ pkg/pillar/cmd/usbmanager/ruleengine_test.go | 262 ++++++++++++++++ pkg/pillar/cmd/usbmanager/rules.go | 232 ++++++++++++++ pkg/pillar/cmd/usbmanager/rules_test.go | 39 +++ pkg/pillar/cmd/usbmanager/scanusb.go | 195 ++++++++++++ pkg/pillar/cmd/usbmanager/scanusb_test.go | 45 +++ pkg/pillar/cmd/usbmanager/subscriptions.go | 228 ++++++++++++++ pkg/pillar/cmd/usbmanager/usbcontroller.go | 255 ++++++++++++++++ .../cmd/usbmanager/usbcontroller_test.go | 288 ++++++++++++++++++ pkg/pillar/cmd/usbmanager/usbdevice.go | 58 ++++ pkg/pillar/cmd/usbmanager/usbevent.go | 82 +++++ pkg/pillar/cmd/usbmanager/usbmanager.go | 101 ++++++ pkg/pillar/cmd/usbmanager/usbmanager_test.go | 11 + pkg/pillar/cmd/usbmanager/usbpassthrough.go | 113 +++++++ pkg/pillar/hypervisor/kvm.go | 16 +- pkg/pillar/hypervisor/qmp.go | 14 + pkg/pillar/pubsub/util.go | 1 + pkg/pillar/scripts/device-steps.sh | 2 +- pkg/pillar/zedbox/zedbox.go | 2 + 20 files changed, 2013 insertions(+), 2 deletions(-) create mode 100644 pkg/pillar/cmd/usbmanager/ruleengine.go create mode 100644 pkg/pillar/cmd/usbmanager/ruleengine_test.go create mode 100644 pkg/pillar/cmd/usbmanager/rules.go create mode 100644 pkg/pillar/cmd/usbmanager/rules_test.go create mode 100644 pkg/pillar/cmd/usbmanager/scanusb.go create mode 100644 pkg/pillar/cmd/usbmanager/scanusb_test.go create mode 100644 pkg/pillar/cmd/usbmanager/subscriptions.go create mode 100644 pkg/pillar/cmd/usbmanager/usbcontroller.go create mode 100644 pkg/pillar/cmd/usbmanager/usbcontroller_test.go create mode 100644 pkg/pillar/cmd/usbmanager/usbdevice.go create mode 100644 pkg/pillar/cmd/usbmanager/usbevent.go create mode 100644 pkg/pillar/cmd/usbmanager/usbmanager.go create mode 100644 pkg/pillar/cmd/usbmanager/usbmanager_test.go create mode 100644 pkg/pillar/cmd/usbmanager/usbpassthrough.go diff --git a/pkg/pillar/cmd/domainmgr/domainmgr.go b/pkg/pillar/cmd/domainmgr/domainmgr.go index 5fdff4a6760..dff759a41fb 100644 --- a/pkg/pillar/cmd/domainmgr/domainmgr.go +++ b/pkg/pillar/cmd/domainmgr/domainmgr.go @@ -138,6 +138,7 @@ var currentHypervisorMutex sync.Mutex var logger *logrus.Logger var log *base.LogObject +// CurrentHypervisor returns the current hypervisor func CurrentHypervisor() hypervisor.Hypervisor { currentHypervisorMutex.Lock() hv := currentHypervisor diff --git a/pkg/pillar/cmd/usbmanager/ruleengine.go b/pkg/pillar/cmd/usbmanager/ruleengine.go new file mode 100644 index 00000000000..718e152ba97 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/ruleengine.go @@ -0,0 +1,70 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "fmt" +) + +type nullObjectPassthroughRule struct { + passthroughRuleVMBase +} + +func (pr *nullObjectPassthroughRule) priority() uint8 { + return 0 +} +func (pr *nullObjectPassthroughRule) evaluate(_ usbdevice) passthroughAction { + return passthroughNo +} +func (pr *nullObjectPassthroughRule) String() string { + return "" +} + +type ruleEngine struct { + rules map[string]passthroughRule +} + +func newRuleEngine() *ruleEngine { + var re ruleEngine + + re.rules = make(map[string]passthroughRule) + + return &re +} + +func (re *ruleEngine) delRule(pr passthroughRule) { + delete(re.rules, pr.String()) +} + +func (re *ruleEngine) addRule(pr passthroughRule) { + re.rules[pr.String()] = pr +} + +func (re *ruleEngine) apply(ud usbdevice) *virtualmachine { + var maxRule passthroughRule + maxRule = &nullObjectPassthroughRule{} + + for _, r := range re.rules { + if r.evaluate(ud) == passthroughForbid { + return nil + } + if r.evaluate(ud) == passthroughDo { + if r.priority() > maxRule.priority() { + maxRule = r + } + } + } + + return maxRule.virtualMachine() +} + +func (re *ruleEngine) String() string { + var ret string + + ret = fmt.Sprintf("Rule Engine Rules (%d): |", len(re.rules)) + for _, rule := range re.rules { + ret += fmt.Sprintf("%s|", rule) + } + + return ret +} diff --git a/pkg/pillar/cmd/usbmanager/ruleengine_test.go b/pkg/pillar/cmd/usbmanager/ruleengine_test.go new file mode 100644 index 00000000000..10a66f6f8b8 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/ruleengine_test.go @@ -0,0 +1,262 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "testing" +) + +func TestOverwriteRule(t *testing.T) { + re := newRuleEngine() + + pci1 := pciPassthroughRule{pciAddress: "00:02.0"} + re.addRule(&pci1) + + pci2 := pciPassthroughRule{pciAddress: "00:02.0"} + re.addRule(&pci2) + + if len(re.rules) != 1 { + t.Fatalf("rule overwriting failed") + } +} + +func TestBlockedByPCIPassthrough(t *testing.T) { + re := newRuleEngine() + + pci := pciPassthroughRule{pciAddress: "00:02.0"} + re.addRule(&pci) + + ud := usbdevice{ + usbControllerPCIAddress: pci.pciAddress, // conflicts with pci rule + busnum: 01, + devnum: 02, + portnum: "2", + } + vm := virtualmachine{} + usb := usbPortPassthroughRule{ud: ud} + usb.vm = &vm + + re.addRule(&usb) + + connectVM := re.apply(ud) + + if connectVM != nil { + t.Fatalf("usb passthrough should be blocked by pci passthrough, but got connected vm") + } +} + +func TestPortOverDevPrecedence(t *testing.T) { + re := newRuleEngine() + + ud := usbdevice{ + usbControllerPCIAddress: "00:02.0", + busnum: 01, + devnum: 3, + portnum: "3.1", + vendorID: 5, + productID: 6, + } + + usbPortRule := usbPortPassthroughRule{ud: ud} + usbPortRule.vm = &virtualmachine{} + + usbDevRule := usbDevicePassthroughRule{ud: ud} + + re.addRule(&usbPortRule) + re.addRule(&usbDevRule) + + connectVM := re.apply(ud) + + if connectVM == nil { + t.Fatalf("usb passthrough should work, but got no connected vm") + } +} + +func TestUSBWithoutPCICard(t *testing.T) { + re := newRuleEngine() + + ud := usbdevice{ + busnum: 01, + devnum: 2, + portnum: "2", + vendorID: 5, + productID: 6, + } + + usbPortRule := usbPortPassthroughRule{ud: ud} + usbPortRule.vm = &virtualmachine{} + + re.addRule(&usbPortRule) + + connectVM := re.apply(ud) + + if connectVM == nil { + t.Fatalf("pci-less usb passthrough fails") + } + +} + +func TestPluginWrongPCICard(t *testing.T) { + re := newRuleEngine() + + ud := usbdevice{ + usbControllerPCIAddress: "00:02.0", + busnum: 01, + portnum: "2", + devnum: 2, + vendorID: 5, + productID: 6, + } + + usbRule := usbPortPassthroughRule{ + ud: usbdevice{ + busnum: 01, + devnum: 02, + usbControllerPCIAddress: "00:03.0", + }, + } + usbRule.vm = &virtualmachine{qmpSocketPath: "/vm/with/usb/passthrough"} + re.addRule(&usbRule) + + vm := re.apply(ud) + if vm != nil { + t.Fatal("ud should not be passed as parent pci addresses are different") + } + + t.Log(re.String()) +} + +func TestEmptyParentPCIAddress(t *testing.T) { + re := newRuleEngine() + + ud1 := usbdevice{ + usbControllerPCIAddress: "00:02.0", + busnum: 01, + devnum: 02, + portnum: "2", + vendorID: 5, + productID: 6, + } + ud2 := usbdevice{ + usbControllerPCIAddress: "00:03.0", + busnum: 02, + devnum: 02, + portnum: "3", + vendorID: 5, + productID: 6, + } + + pciRule := pciPassthroughRule{ + pciAddress: "00:02.0", + } + pciRule.vm = &virtualmachine{qmpSocketPath: "/vm/with/pci/passthrough"} + re.addRule(&pciRule) + + usbRule := usbPortPassthroughRule{ + ud: usbdevice{ + busnum: 02, + devnum: 02, + portnum: "3", + usbControllerPCIAddress: "", + }, + } + usbRule.vm = &virtualmachine{qmpSocketPath: "/vm/with/usb/passthrough"} + re.addRule(&usbRule) + + ud1VM := re.apply(ud1) + if ud1VM != nil { + t.Fatal("ud1 should not be passed as underlying PCI device is passed through") + } + + ud2VM := re.apply(ud2) + if ud2VM == nil { + t.Fatal("ud2 should be passed through") + } + + t.Log(re.String()) +} + +func FuzzRuleEngine(f *testing.F) { + + f.Fuzz(func(t *testing.T, + // usb device passthrough rule + parentPCIAddressRule1 string, + busnumRule1 uint16, + devnumRule1 uint16, + vendorIdRule1 uint32, + productIdRule1 uint32, + // usb plug passthrough rule + parentPCIAddressRule2 string, + busnumRule2 uint16, + devnumRule2 uint16, + vendorIdRule2 uint32, + productIdRule2 uint32, + // pci passthrough rule + parentPCIAddressRule3 string, + // actual usb device + parentPCIAddress string, + busnum uint16, + devnum uint16, + vendorId uint32, + productId uint32, + ) { + re := newRuleEngine() + udRule1 := usbdevice{ + busnum: busnumRule1, + devnum: devnumRule1, + vendorID: vendorIdRule1, + productID: productIdRule1, + usbControllerPCIAddress: parentPCIAddressRule1, + } + rule1 := usbDevicePassthroughRule{ud: udRule1} + rule1.vm = &virtualmachine{ + qmpSocketPath: "/vm1", + } + + udRule2 := usbdevice{ + busnum: busnumRule2, + devnum: devnumRule2, + vendorID: vendorIdRule2, + productID: productIdRule2, + usbControllerPCIAddress: parentPCIAddressRule2, + } + rule2 := usbPortPassthroughRule{ud: udRule2} + rule2.vm = &virtualmachine{ + qmpSocketPath: "/vm2", + } + + rule3 := pciPassthroughRule{ + pciAddress: parentPCIAddressRule3, + } + + ud := usbdevice{ + busnum: busnum, + devnum: devnum, + vendorID: vendorId, + productID: productId, + usbControllerPCIAddress: parentPCIAddress, + } + re.addRule(&rule1) + re.addRule(&rule2) + re.addRule(&rule3) + + connectVM := re.apply(ud) + if connectVM == nil { + return + } + if rule3.pciAddress == ud.usbControllerPCIAddress { + t.Fatal("passthrough should not work as it is blocked by pci passthrough") + } + // check that if udRule1 and udRule2 apply, we get the one with the higher priority, i.e. udRule2 + // which means, as long as udRule2 applies, we should get udRule2.vm + reUdRule2 := newRuleEngine() + reUdRule2.addRule(&rule2) + connectVMUdRule := reUdRule2.apply(ud) + if connectVMUdRule != nil { + if connectVMUdRule.qmpSocketPath != "/vm2" { + t.Fatal("usb plug rule applies, but rule with higher precedence has been found") + } + } + + }) +} diff --git a/pkg/pillar/cmd/usbmanager/rules.go b/pkg/pillar/cmd/usbmanager/rules.go new file mode 100644 index 00000000000..0381d6132e2 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/rules.go @@ -0,0 +1,232 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +type passthroughAction uint8 + +const ( + // this rule applies + passthroughDo = 0 + // this rule does not apply + passthroughNo = iota + // this rule forbids passthrough even if other rules apply + passthroughForbid = iota +) + +type virtualmachine struct { + qmpSocketPath string + adapters []string +} + +func newVirtualmachine(qmpSocketPath string, adapters []string) virtualmachine { + vm := virtualmachine{ + qmpSocketPath: qmpSocketPath, + adapters: adapters, + } + + if vm.adapters == nil { + vm.adapters = make([]string, 0) + } + + return vm +} + +func (vm *virtualmachine) addAdapter(adapter string) { + vm.adapters = append(vm.adapters, adapter) +} + +func (vm virtualmachine) String() string { + return fmt.Sprintf("vm-qmp: %s adapters: '%s'", vm.qmpSocketPath, strings.Join(vm.adapters, ", ")) +} + +type passthroughRule interface { + evaluate(ud usbdevice) passthroughAction + priority() uint8 + virtualMachine() *virtualmachine + setVirtualMachine(vm *virtualmachine) + String() string +} + +func (pr passthroughAction) String() string { + if pr == passthroughDo { + return "do passthrough" + } else if pr == passthroughNo { + return "no passthrough" + } else if pr == passthroughForbid { + return "forbid passthrough" + } + + return "" +} + +type passthroughRuleVMBase struct { + vm *virtualmachine +} + +func (pr *passthroughRuleVMBase) setVirtualMachine(vm *virtualmachine) { + pr.vm = vm +} + +func (pr *passthroughRuleVMBase) virtualMachine() *virtualmachine { + return pr.vm +} + +type pciPassthroughRule struct { + pciAddress string + passthroughRuleVMBase +} + +func (pr *pciPassthroughRule) String() string { + return fmt.Sprintf("PCI Passthrough Rule %s", pr.pciAddress) +} + +func (pr *pciPassthroughRule) evaluate(ud usbdevice) passthroughAction { + if ud.usbControllerPCIAddress == pr.pciAddress { + return passthroughForbid + } + + return passthroughNo +} + +func (pr *pciPassthroughRule) priority() uint8 { + return 0 +} + +type usbDevicePassthroughRule struct { + ud usbdevice + passthroughRuleVMBase +} + +func (udpr *usbDevicePassthroughRule) String() string { + return fmt.Sprintf("USB Device Passthrough Rule %s on pci %s", udpr.ud.vendorAndproductIDString(), udpr.ud.usbControllerPCIAddress) +} + +func (udpr *usbDevicePassthroughRule) priority() uint8 { + return 10 +} + +func (udpr *usbDevicePassthroughRule) evaluate(ud usbdevice) passthroughAction { + if udpr.ud.usbControllerPCIAddress != "" && udpr.ud.usbControllerPCIAddress != ud.usbControllerPCIAddress { + return passthroughNo + } + if udpr.ud.vendorID != ud.vendorID || + udpr.ud.productID != ud.productID { + return passthroughNo + } + + return passthroughDo +} + +type usbPortPassthroughRule struct { + ud usbdevice + passthroughRuleVMBase +} + +func (uppr *usbPortPassthroughRule) String() string { + return fmt.Sprintf("USB Port Passthrough Rule %s on pci %s", uppr.ud.busnumAndDevnumString(), uppr.ud.usbControllerPCIAddress) +} + +func (uppr *usbPortPassthroughRule) priority() uint8 { + return 20 +} + +func (uppr *usbPortPassthroughRule) evaluate(ud usbdevice) passthroughAction { + if uppr.ud.usbControllerPCIAddress != "" && uppr.ud.usbControllerPCIAddress != ud.usbControllerPCIAddress { + return passthroughNo + } + if uppr.ud.portnum != ud.portnum || + uppr.ud.busnum != ud.busnum { + return passthroughNo + } + + return passthroughDo +} + +type usbHubForbidPassthroughRule struct { + passthroughRuleVMBase +} + +func (uhfpr *usbHubForbidPassthroughRule) String() string { + return "usbHubForbidPassthroughRule" +} + +func (uhfpr *usbHubForbidPassthroughRule) priority() uint8 { + return 0 +} + +func (uhfpr *usbHubForbidPassthroughRule) evaluate(ud usbdevice) passthroughAction { + if strings.HasPrefix(ud.devicetype, "9/") { + log.Tracef("usb hub forwarding is forbidden - %+v", ud) + return passthroughForbid + } + + return passthroughNo +} + +func newUsbNetworkAdapterForbidPassthroughRule() usbNetworkAdapterForbidPassthroughRule { + unafpr := usbNetworkAdapterForbidPassthroughRule{} + unafpr.netDevPaths = unafpr.netDevPathsImpl + + return unafpr +} + +type usbNetworkAdapterForbidPassthroughRule struct { + netDevPaths func() []string + passthroughRuleVMBase +} + +func (unafpr *usbNetworkAdapterForbidPassthroughRule) String() string { + return "usbNetworkAdapterForbidPassthroughRule" +} + +func (unafpr *usbNetworkAdapterForbidPassthroughRule) priority() uint8 { + return 0 +} + +func (unafpr *usbNetworkAdapterForbidPassthroughRule) evaluate(ud usbdevice) passthroughAction { + netDevPaths := unafpr.netDevPaths() + + ueventDirname := filepath.Dir(ud.ueventFilePath) + "/" + for _, path := range netDevPaths { + if strings.HasPrefix(path, ueventDirname) { + log.Tracef("usb network adapter forwarding is forbidden - %+v", ud) + return passthroughForbid + } + } + + return passthroughNo +} + +func (*usbNetworkAdapterForbidPassthroughRule) netDevPathsImpl() []string { + netDir := filepath.Join(sysFSPath, "class", "net") + netDevfiles, err := os.ReadDir(netDir) + if err != nil { + panic(err) + } + + netDevPaths := make([]string, 0) + + for _, file := range netDevfiles { + // e.g. ../../devices/pci0000:00/0000:00:14.0/usb4/4-2/4-2.1/4-2.1:1.0/net/enp0s20f0u2u1/ + relPath, err := os.Readlink(filepath.Join(netDir, file.Name())) + if err != nil { + panic(err) + } + + // remove net/enp0s20f0u2u1/ and prefix with syfs dir + absPath, err := filepath.Abs(filepath.Join(netDir, relPath, "..", "..")) + if err != nil { + panic(err) + } + + netDevPaths = append(netDevPaths, absPath) + } + return netDevPaths +} diff --git a/pkg/pillar/cmd/usbmanager/rules_test.go b/pkg/pillar/cmd/usbmanager/rules_test.go new file mode 100644 index 00000000000..04ada521d3b --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/rules_test.go @@ -0,0 +1,39 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "testing" +) + +func TestUsbNetworkAdapterForbidPassthroughRule(t *testing.T) { + usbNetworkAdapterForbidPassthroughRule := usbNetworkAdapterForbidPassthroughRule{} + + usbNetworkAdapterForbidPassthroughRule.netDevPaths = func() []string { + return []string{"/sys/devices/pci0000:00/0000:00:14.0/usb4/4-2/4-2.1/4-2.1:1.0"} + } + + ud := usbdevice{} + + ud.ueventFilePath = "/sys/devices/pci0000:00/0000:00:14.0/usb4/4-2/4-2.1/" + + if usbNetworkAdapterForbidPassthroughRule.evaluate(ud) != passthroughForbid { + t.Fatalf("passthrough should be forbidden, but isn't") + } +} + +func TestUsbNetworkAdapterAllowPassthroughRule(t *testing.T) { + usbNetworkAdapterForbidPassthroughRule := usbNetworkAdapterForbidPassthroughRule{} + + usbNetworkAdapterForbidPassthroughRule.netDevPaths = func() []string { + return []string{"/sys/devices/pci0000:00/0000:00:14.0/usb4/4-2/4-2.11/4-2.1:1.0"} + } + + ud := usbdevice{} + + ud.ueventFilePath = "/sys/devices/pci0000:00/0000:00:14.0/usb4/4-2/4-2.1/" + + if usbNetworkAdapterForbidPassthroughRule.evaluate(ud) == passthroughForbid { + t.Fatalf("passthrough should not be forbidden (port 1 versus port 11), but it is") + } +} diff --git a/pkg/pillar/cmd/usbmanager/scanusb.go b/pkg/pillar/cmd/usbmanager/scanusb.go new file mode 100644 index 00000000000..527bcc11dd3 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/scanusb.go @@ -0,0 +1,195 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + + "github.com/lf-edge/eve/pkg/pillar/types" +) + +func trimSysPath(path string) string { + relPath := strings.TrimLeft(path, sysFSPath) + + return relPath +} + +func extractPCIaddress(path string) string { + re := regexp.MustCompile(`\/?devices\/pci[\d:.]*\/(\d{4}:[a-f\d:\.]+)\/`) + + relPath := trimSysPath(path) + matches := re.FindStringSubmatch(relPath) + if len(matches) != 2 { + return "" + } + pciAddress := matches[1] + + return pciAddress +} + +func extractUSBPort(path string) string { + _, port, err := types.ExtractUSBBusnumPort(path) + + if err != nil { + log.Warn(err) + } + + return port +} + +func walkUSBPorts() []*usbdevice { + uds := make([]*usbdevice, 0) + + inSysPath := filepath.Join("bus", "usb", "devices") + usbDevicesPath := filepath.Join(sysFSPath, inSysPath) + + files, err := os.ReadDir(usbDevicesPath) + if err != nil { + log.Fatal(err) + } + + re := regexp.MustCompile(`^\d+-\d+`) + for _, file := range files { + if len(file.Name()) == 0 { + continue + } + + if !re.Match([]byte(file.Name())) { + continue + } + + relPortPath := filepath.Join(usbDevicesPath, file.Name()) + //fmt.Printf("%s -> %s\n", vals[0], vals[1]) + ud := usbDeviceFromSysPath(relPortPath) + if ud != nil { + uds = append(uds, ud) + } + } + + return uds +} + +func usbDeviceFromSysPath(relPortPath string) *usbdevice { + portPath, err := os.Readlink(relPortPath) + if err != nil { + fmt.Printf("err: %+v\n", err) + return nil + } + + inSysPath := filepath.Join("bus", "usb", "devices") + usbDevicesPath := filepath.Join(sysFSPath, inSysPath) + portPath = filepath.Join(usbDevicesPath, portPath) + + ueventFilePath := filepath.Join(portPath, "uevent") + return ueventFile2usbDevice(ueventFilePath) +} + +func ueventFile2usbDevice(ueventFilePath string) *usbdevice { + ueventFp, err := os.Open(ueventFilePath) + if err != nil { + return nil + } + defer ueventFp.Close() + + var busnum uint16 + var devnum uint16 + var vendorID uint32 + var productID uint32 + var product string + var devicetype string + + busnumSet := false + devnumSet := false + productSet := false + + sc := bufio.NewScanner(ueventFp) + for sc.Scan() { + vals := strings.SplitN(sc.Text(), "=", 2) + if len(vals) != 2 { + continue + } + + if vals[1] == "" { + continue + } + + if vals[0] == "BUSNUM" { + val64, err := strconv.ParseUint(vals[1], 10, 16) + if err != nil { + panic(err) + } + busnum = uint16(val64) + busnumSet = true + } + if vals[0] == "DEVNUM" { + val64, err := strconv.ParseUint(vals[1], 10, 16) + if err != nil { + panic(err) + } + devnum = uint16(val64) + devnumSet = true + } + if vals[0] == "PRODUCT" { + product = vals[1] + vendorID, productID = parseProductString(product) + if vendorID != 0 || productID != 0 { + productSet = true + } + } + if vals[0] == "TYPE" { + devicetype = vals[1] + } + } + + if !busnumSet || !devnumSet || !productSet { + return nil + } + + pciAddress := extractPCIaddress(ueventFilePath) + if pciAddress == "" { + return nil + } + + portnum := extractUSBPort(ueventFilePath) + + ud := usbdevice{ + busnum: busnum, + devnum: devnum, + portnum: portnum, + vendorID: vendorID, + productID: productID, + devicetype: devicetype, + usbControllerPCIAddress: pciAddress, + ueventFilePath: filepath.Clean(ueventFilePath), + } + + return &ud +} + +func parseProductString(product string) (uint32, uint32) { + var vendorID uint32 + var productID uint32 + + vals := strings.SplitN(product, "/", 3) + if len(vals) < 2 { + return 0, 0 + } + + for i, v := range []*uint32{&vendorID, &productID} { + + val64, err := strconv.ParseUint(vals[i], 16, 32) + if err != nil { + return 0, 0 + } + + *v = uint32(val64) + } + + return vendorID, productID +} diff --git a/pkg/pillar/cmd/usbmanager/scanusb_test.go b/pkg/pillar/cmd/usbmanager/scanusb_test.go new file mode 100644 index 00000000000..b1bf1023cd7 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/scanusb_test.go @@ -0,0 +1,45 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "testing" +) + +func TestExtractUSBPort(t *testing.T) { + t.Parallel() + + table := []struct { + path string + port string + }{ + {"/sys/devices/pci0000:00/0000:00:14.0/usb3/3-6/uevent", "6"}, + {"/sys/devices/pci0000:00/0000:00:14.0/usb3/3-3/3-3.1/uevent", "3.1"}, + } + + for _, test := range table { + port := extractUSBPort(test.path) + if port != test.port { + t.Fatalf("expected port %s but got %s, path is %s", test.port, port, test.path) + } + } +} + +func TestExtractPCIAddress(t *testing.T) { + // /sys/devices/platform/soc@0/32f10108.usb/38200000.dwc3/xhci-hcd.1.auto/usb3/3-1/3-1.4/3-1.4:1.0/uevent + table := []struct { + path string + pciAddress string + }{ + {"/sys/devices/pci0000:00/0000:00:14.0/usb3/3-6/uevent", "0000:00:14.0"}, + {"/sys/devices/pci0000:00/0000:00:14.0/usb3/3-3/3-3.1/uevent", "0000:00:14.0"}, + {"/sys/devices/platform/soc@0/32f10108.usb/38200000.dwc3/xhci-hcd.1.auto/usb3/3-1/3-1.4/3-1.4:1.0/uevent", ""}, + } + + for _, test := range table { + port := extractPCIaddress(test.path) + if port != test.pciAddress { + t.Fatalf("expected port %s but got %s, path is %s", test.pciAddress, port, test.path) + } + } +} diff --git a/pkg/pillar/cmd/usbmanager/subscriptions.go b/pkg/pillar/cmd/usbmanager/subscriptions.go new file mode 100644 index 00000000000..4502454cf20 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/subscriptions.go @@ -0,0 +1,228 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "reflect" + "time" + + "github.com/lf-edge/eve/pkg/pillar/hypervisor" + "github.com/lf-edge/eve/pkg/pillar/pubsub" + "github.com/lf-edge/eve/pkg/pillar/types" +) + +func (usbCtx *usbmanagerContext) process(ps *pubsub.PubSub, subs []pubsub.Subscription) { + stillRunning := time.NewTicker(stillRunningInterval) + + cases := make([]reflect.SelectCase, 0) + for _, subscription := range subs { + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(subscription.MsgChan()), + }) + } + + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(stillRunning.C), + }) + + for { + index, value, _ := reflect.Select(cases) + if index >= len(subs) { + ps.StillRunning(agentName, warningTime, errorTime) + continue + } + change, ok := value.Interface().(pubsub.Change) + if !ok { + continue + } + subs[index].ProcessChange(change) + + } +} + +func (usbCtx *usbmanagerContext) subscribe(ps *pubsub.PubSub) { + subAssignableAdapters, err := ps.NewSubscription(pubsub.SubscriptionOptions{ + AgentName: "domainmgr", + MyAgentName: agentName, + TopicImpl: types.AssignableAdapters{}, + Activate: false, + CreateHandler: usbCtx.handleAssignableAdaptersCreate, + ModifyHandler: usbCtx.handleAssignableAdaptersModify, + DeleteHandler: usbCtx.handleAssignableAdaptersDelete, + WarningTime: warningTime, + ErrorTime: errorTime, + }) + if err != nil { + log.Fatal(err) + } + + usbCtx.SubAssignableAdapters = subAssignableAdapters + + subDomainStatus, err := ps.NewSubscription(pubsub.SubscriptionOptions{ + AgentName: "domainmgr", + MyAgentName: agentName, + TopicImpl: types.DomainStatus{}, + Activate: false, + CreateHandler: usbCtx.handleDomainStatusCreate, + ModifyHandler: usbCtx.handleDomainStatusModify, + DeleteHandler: usbCtx.handleDomainStatusDelete, + WarningTime: warningTime, + ErrorTime: errorTime, + }) + if err != nil { + log.Fatal(err) + } + usbCtx.SubDomainStatus = subDomainStatus + + for _, sub := range pubsub.SubscriptionsByReflection(usbCtx) { + err := sub.Activate() + if err != nil { + log.Fatalf("cannot subscribe to %+v: %+v", sub, err) + } + } + + usbCtx.controller.listenUSBPorts() +} + +func (usbCtx *usbmanagerContext) handleDomainStatusModify(_ interface{}, _ string, + statusArg interface{}, _ interface{}) { + newstatus, ok := statusArg.(types.DomainStatus) + if !ok { + log.Warnf("newstatus not OK, got %+v type %T\n", statusArg, statusArg) + return + } + + _, isRunningDomain := usbCtx.runningDomains[newstatus.DomainName] + + if newstatus.State == types.RUNNING && !isRunningDomain { + usbCtx.runningDomains[newstatus.DomainName] = struct{}{} + usbCtx.handleDomainStatusRunning(newstatus) + } + + if newstatus.State != types.RUNNING && isRunningDomain { + delete(usbCtx.runningDomains, newstatus.DomainName) + usbCtx.handleDomainStatusNotRunning(newstatus) + } + +} + +func (usbCtx *usbmanagerContext) handleDomainStatusNotRunning(status types.DomainStatus) { + qmp := hypervisor.GetQmpExecutorSocket(status.DomainName) + vm := newVirtualmachine(qmp, nil) + + usbCtx.controller.removeVirtualmachine(vm) +} + +func (usbCtx *usbmanagerContext) handleDomainStatusRunning(status types.DomainStatus) { + qmp := hypervisor.GetQmpExecutorSocket(status.DomainName) + vm := newVirtualmachine(qmp, nil) + + log.Tracef("display name: %+v\n", status.DisplayName) + for _, io := range status.IoAdapterList { + vm.addAdapter(io.Name) + } + + usbCtx.controller.addVirtualmachine(vm) + usbCtx.controller.Lock() + log.Tracef("rule engine: %s\n", usbCtx.controller.ruleEngine) + usbCtx.controller.Unlock() +} + +func (usbCtx *usbmanagerContext) handleDomainStatusDelete(_ interface{}, _ string, + statusArg interface{}) { + + status, ok := statusArg.(types.DomainStatus) + if !ok { + log.Warnf("status not OK, got %+v type %T\n", statusArg, statusArg) + return + } + log.Tracef("display name: %+v\n", status.DisplayName) + + usbCtx.handleDomainStatusNotRunning(status) +} + +func (usbCtx *usbmanagerContext) handleDomainStatusCreate(_ interface{}, _ string, + statusArg interface{}) { + + status, ok := statusArg.(types.DomainStatus) + if !ok { + log.Warnf("status not OK, got %+v type %T\n", statusArg, statusArg) + return + } + + log.Tracef("display name: %+v\n", status.DisplayName) + + if status.State == types.RUNNING { + usbCtx.handleDomainStatusRunning(status) + } +} + +func (usbCtx *usbmanagerContext) handleAssignableAdaptersCreate(_ interface{}, _ string, + statusArg interface{}) { + assignableAdapters, ok := statusArg.(types.AssignableAdapters) + if !ok { + log.Warnf("status not OK, got %+v type %T\n", statusArg, statusArg) + return + } + + for _, adapter := range assignableAdapters.IoBundleList { + usbCtx.controller.addIOBundle(adapter) + } +} + +func (usbCtx *usbmanagerContext) handleAssignableAdaptersModify(_ interface{}, _ string, + statusArg interface{}, oldStatusArg interface{}) { + + oldAssignableAdapters, ok := oldStatusArg.(types.AssignableAdapters) + if !ok { + log.Warnf("oldstatus not OK, got %+v type %T\n", oldStatusArg, oldStatusArg) + return + } + newAssignableAdapters, ok := statusArg.(types.AssignableAdapters) + if !ok { + log.Warnf("newstatus not OK, got %+v type %T\n", statusArg, statusArg) + return + } + + oldAssignableAdaptersMap := make(map[string]types.IoBundle) + + for _, adapter := range oldAssignableAdapters.IoBundleList { + oldAssignableAdaptersMap[adapter.Phylabel] = adapter + } + + newAssignableAdaptersMap := make(map[string]types.IoBundle) + + for _, adapter := range newAssignableAdapters.IoBundleList { + newAssignableAdaptersMap[adapter.Phylabel] = adapter + } + + for adapterName, adapter := range oldAssignableAdaptersMap { + _, ok := newAssignableAdaptersMap[adapterName] + if !ok { + usbCtx.controller.addIOBundle(adapter) + } + } + + for adapterName, adapter := range newAssignableAdaptersMap { + _, ok := oldAssignableAdaptersMap[adapterName] + if !ok { + usbCtx.controller.removeIOBundle(adapter) + } + } +} + +func (usbCtx *usbmanagerContext) handleAssignableAdaptersDelete(_ interface{}, _ string, + statusArg interface{}) { + + assignableAdapters, ok := statusArg.(types.AssignableAdapters) + if !ok { + log.Warnf("status not OK, got %+v type %T\n", statusArg, statusArg) + return + } + + for _, adapter := range assignableAdapters.IoBundleList { + usbCtx.controller.removeIOBundle(adapter) + } +} diff --git a/pkg/pillar/cmd/usbmanager/usbcontroller.go b/pkg/pillar/cmd/usbmanager/usbcontroller.go new file mode 100644 index 00000000000..286ccc25711 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/usbcontroller.go @@ -0,0 +1,255 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "strconv" + "strings" + "sync" + + "github.com/lf-edge/eve/pkg/pillar/hypervisor" + "github.com/lf-edge/eve/pkg/pillar/types" +) + +const sysFSPath = "/sys" + +type usbmanagerController struct { + ruleEngine *ruleEngine + name2deviceRule map[string]passthroughRule + + usbpassthroughs usbpassthroughs + + connectUSBDeviceToQemu func(up usbpassthrough) + disconnectUSBDeviceFromQemu func(up usbpassthrough) + + listenUSBStopChan chan struct{} + + sync.Mutex +} + +func (uc *usbmanagerController) init() { + uc.Lock() + uc.ruleEngine = newRuleEngine() + + usbNetworkAdapterForbidPassthroughRule := newUsbNetworkAdapterForbidPassthroughRule() + uc.ruleEngine.addRule(&usbNetworkAdapterForbidPassthroughRule) + uc.ruleEngine.addRule(&usbHubForbidPassthroughRule{}) + + uc.name2deviceRule = make(map[string]passthroughRule) + + uc.usbpassthroughs = newUsbpassthroughs() + + uc.connectUSBDeviceToQemu = uc.connectUSBDeviceToQemuImpl + uc.disconnectUSBDeviceFromQemu = uc.disconnectUSBDeviceFromQemuImpl + + uc.Unlock() +} + +// prevents trying to connect a usb device twice +func (uc *usbmanagerController) connectUSBDeviceToQemuIdempotent(up usbpassthrough) { + if uc.usbpassthroughs.hasUsbpassthrough(up) { + log.Warnf("%+v is already passed through\n", up) + return + } + uc.usbpassthroughs.addUsbpassthrough(&up) + uc.connectUSBDeviceToQemu(up) +} + +// prevents trying to disconnect a usb device twice +func (uc *usbmanagerController) disconnectUSBDeviceFromQemuIdempotent(up usbpassthrough) { + if !uc.usbpassthroughs.hasUsbpassthrough(up) { + return + } + uc.usbpassthroughs.delUsbpassthrough(&up) + uc.disconnectUSBDeviceFromQemu(up) +} + +func (uc *usbmanagerController) connectUSBDeviceToQemuImpl(up usbpassthrough) { + log.Tracef("connect usb passthrough %+v to %s\n", up, up.vm.qmpSocketPath) + + err := hypervisor.QmpExecDeviceAdd(up.vm.qmpSocketPath, up.usbdevice.qemuDeviceName(), up.usbdevice.busnum, up.usbdevice.devnum) + if err != nil { + log.Warnf("connect qmp failed: %+v\n", err) + } +} + +func (uc *usbmanagerController) disconnectUSBDeviceFromQemuImpl(up usbpassthrough) { + log.Tracef("disconnect usb passthrough %+v to %s\n", up, up.vm.qmpSocketPath) + + err := hypervisor.QmpExecDeviceDelete(up.vm.qmpSocketPath, up.usbdevice.qemuDeviceName()) + if err != nil { + log.Warnf("disconnect qmp failed: %+v\n", err) + } +} + +func (uc *usbmanagerController) addUSBDevice(ud usbdevice) { + uc.Lock() + defer uc.Unlock() + + uc.usbpassthroughs.addUsbdevice(&ud) + vm := uc.ruleEngine.apply(ud) + log.Tracef("add usb device %+v vm=%v; rules: %s\n", ud, vm, uc.ruleEngine.String()) + if vm != nil { + uc.connectUSBDeviceToQemuIdempotent(usbpassthrough{ + usbdevice: &ud, + vm: vm, + }) + } +} + +func (uc *usbmanagerController) removeUSBDevice(ud usbdevice) { + uc.Lock() + defer uc.Unlock() + vm := uc.ruleEngine.apply(ud) + log.Tracef("remove usb device %+v vm=%v; rules: %s\n", ud, vm, uc.ruleEngine.String()) + if vm != nil { + uc.disconnectUSBDeviceFromQemuIdempotent(usbpassthrough{ + usbdevice: &ud, + vm: vm, + }) + uc.usbpassthroughs.delUsbdevice(&ud) + + } +} + +func (uc *usbmanagerController) addVirtualmachine(vm virtualmachine) { + uc.Lock() + defer uc.Unlock() + log.Tracef("add vm %+v", vm) + uc.usbpassthroughs.addVM(&vm) + + // add rules + for _, phyLabel := range vm.adapters { + ioBundle := uc.usbpassthroughs.ioBundles[phyLabel] + if ioBundle == nil { + continue + } + + pr := ioBundle2PassthroughRule(*ioBundle) + pr.setVirtualMachine(&vm) + + uc.ruleEngine.addRule(pr) + } + + // find and connect usb device + for _, ud := range uc.usbpassthroughs.listUsbdevices() { + vm := uc.ruleEngine.apply(*ud) + if vm == nil { + continue + } + uc.connectUSBDeviceToQemuIdempotent(usbpassthrough{ + usbdevice: ud, + vm: vm, + }) + } +} + +func (uc *usbmanagerController) removeVirtualmachine(vm virtualmachine) { + uc.Lock() + defer uc.Unlock() + ups := uc.usbpassthroughs.usbpassthroughsOfVM(vm) + for _, up := range ups { + uc.disconnectUSBDeviceFromQemuIdempotent(*up) + uc.usbpassthroughs.delUsbpassthrough(up) + } + for _, phyLabel := range vm.adapters { + ioBundle := uc.usbpassthroughs.ioBundles[phyLabel] + if ioBundle == nil { + continue + } + + pr := ioBundle2PassthroughRule(*ioBundle) + uc.ruleEngine.delRule(pr) + } + + uc.usbpassthroughs.delVM(&vm) +} + +func (uc *usbmanagerController) removeIOBundle(ioBundle types.IoBundle) { + uc.Lock() + defer uc.Unlock() + + uc.usbpassthroughs.delIoBundle(&ioBundle) + + pr := ioBundle2PassthroughRule(ioBundle) + vm := uc.usbpassthroughs.vmsByIoBundlePhyLabel[ioBundle.Phylabel] + pr.setVirtualMachine(vm) + + uc.ruleEngine.delRule(pr) + + for _, ud := range uc.usbpassthroughs.listUsbdevices() { + vm := uc.ruleEngine.apply(*ud) + if vm == nil { + continue + } + uc.disconnectUSBDeviceFromQemuIdempotent(usbpassthrough{ + usbdevice: ud, + vm: vm, + }) + } +} + +func (uc *usbmanagerController) addIOBundle(ioBundle types.IoBundle) { + uc.Lock() + defer uc.Unlock() + log.Tracef("add iobundle %s: %s/%s\n", ioBundle.Phylabel, ioBundle.UsbAddr, ioBundle.PciLong) + uc.usbpassthroughs.addIoBundle(&ioBundle) + + pr := ioBundle2PassthroughRule(ioBundle) + if pr == nil { + log.Tracef("unusable iobundle %s: %s/%s\n", ioBundle.Phylabel, ioBundle.UsbAddr, ioBundle.PciLong) + return + } + vm := uc.usbpassthroughs.vmsByIoBundlePhyLabel[ioBundle.Phylabel] + pr.setVirtualMachine(vm) + + uc.ruleEngine.addRule(pr) + + for _, ud := range uc.usbpassthroughs.listUsbdevices() { + vm := uc.ruleEngine.apply(*ud) + if vm == nil { + continue + } + uc.connectUSBDeviceToQemuIdempotent(usbpassthrough{ + usbdevice: ud, + vm: vm, + }) + } +} + +func (uc *usbmanagerController) cancel() { + uc.listenUSBStopChan <- struct{}{} +} + +func ioBundle2PassthroughRule(adapter types.IoBundle) passthroughRule { + var pr passthroughRule + + if adapter.UsbAddr == "" && adapter.PciLong != "" { + pci := pciPassthroughRule{pciAddress: adapter.PciLong} + + pr = &pci + } else if adapter.UsbAddr != "" { + usbParts := strings.SplitN(adapter.UsbAddr, ":", 2) + if len(usbParts) != 2 { + log.Warnf("usbaddr %s not parseable", adapter.UsbAddr) + return nil + } + busnum, err := strconv.ParseUint(usbParts[0], 10, 16) + if err != nil { + panic(err) + } + portnum := usbParts[1] + ud := usbdevice{ + busnum: uint16(busnum), + portnum: portnum, + usbControllerPCIAddress: adapter.PciLong, + } + usb := usbPortPassthroughRule{ud: ud} + + pr = &usb + } else { + pr = nil + } + + return pr +} diff --git a/pkg/pillar/cmd/usbmanager/usbcontroller_test.go b/pkg/pillar/cmd/usbmanager/usbcontroller_test.go new file mode 100644 index 00000000000..aad5372d571 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/usbcontroller_test.go @@ -0,0 +1,288 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "fmt" + "sync/atomic" + "testing" + + "github.com/lf-edge/eve/pkg/pillar/types" +) + +type testingEvent int + +const ( + ioBundleTestEvent testingEvent = iota + usbTestEvent + vmTestEvent +) + +type testEventTableGenerator struct { + testEventTable [][]testingEvent +} + +func (testEvent testingEvent) String() string { + if testEvent == ioBundleTestEvent { + return "IOBundle Event" + } else if testEvent == usbTestEvent { + return "USB Event" + } else if testEvent == vmTestEvent { + return "VM Event" + } + + return "" +} + +func (tetg *testEventTableGenerator) generate(k int, testEvents []testingEvent) { + // Heap's algorithm + if k == 1 { + testEventsCopy := make([]testingEvent, len(testEvents)) + copy(testEventsCopy, testEvents) + + tetg.testEventTable = append(tetg.testEventTable, testEventsCopy) + return + } + + tetg.generate(k-1, testEvents) + + for i := 0; i < k-1; i++ { + var swapIndex int + if k%2 == 0 { + swapIndex = i + } else { + swapIndex = 0 + } + t := testEvents[swapIndex] + testEvents[swapIndex] = testEvents[k-1] + testEvents[k-1] = t + + tetg.generate(k-1, testEvents) + } +} + +func testEventTable() [][]testingEvent { + testEvents := []testingEvent{ioBundleTestEvent, usbTestEvent, vmTestEvent} + + var tetg testEventTableGenerator + + tetg.generate(len(testEvents), testEvents) + + return tetg.testEventTable +} + +func TestRemovingVm(t *testing.T) { + usbEventBusnum := uint16(1) + usbEventDevnum := uint16(2) + usbEventPortnum := "3.1" + ioBundleUsbAddr := fmt.Sprintf("%d:%s", usbEventBusnum, usbEventPortnum) + ioBundlePciLong := "00:02.0" + ioBundleLabel := "TOUCH" + vmAdapter := ioBundleLabel + usbEventPCIAddress := ioBundlePciLong + qmpSocketPath := "/vm/qemu.sock" + + ioBundle, ud, vm := newTestVirtualPassthroughEnv(ioBundleLabel, ioBundleUsbAddr, ioBundlePciLong, + usbEventBusnum, usbEventDevnum, usbEventPortnum, usbEventPCIAddress, + qmpSocketPath, vmAdapter) + + uc := newTestUsbmanagerController() + uc.connectUSBDeviceToQemu = func(up usbpassthrough) { + t.Logf("connecting usbpassthrough: %+v", up.String()) + } + uc.disconnectUSBDeviceFromQemu = func(up usbpassthrough) { + t.Logf("disconnecting usbpassthrough: %+v", up.String()) + } + + uc.addIOBundle(ioBundle) + uc.addUSBDevice(ud) + uc.addVirtualmachine(vm) + if len(uc.usbpassthroughs.usbpassthroughs) != 1 { + t.Fatalf("invalid amount of usbpassthroughs registered") + } + if len(uc.usbpassthroughs.vms) != 1 || len(uc.usbpassthroughs.vmsByIoBundlePhyLabel) != 1 { + t.Fatalf("invalid amount of vms registered") + } + uc.removeVirtualmachine(vm) + if len(uc.usbpassthroughs.usbpassthroughs) != 0 { + t.Fatalf("invalid amount of usbpassthroughs registered") + } + if len(uc.usbpassthroughs.vms) != 0 || len(uc.usbpassthroughs.vmsByIoBundlePhyLabel) != 0 { + t.Fatalf("invalid amount of vms registered") + } +} + +func TestNoConnectWrongPCIUSBDevicesToQemu(t *testing.T) { + usbEventBusnum := uint16(1) + usbEventDevnum := uint16(2) + usbEventPortnum := "3.1" + ioBundleUsbAddr := fmt.Sprintf("%d:%s", usbEventBusnum, usbEventPortnum) + ioBundlePciLong := "00:02.0" + ioBundleLabel := "TOUCH" + vmAdapter := ioBundleLabel + usbEventPCIAddress := "" + qmpSocketPath := "/vm/qemu.sock" + + ioBundle, usbdevice, vm := newTestVirtualPassthroughEnv(ioBundleLabel, ioBundleUsbAddr, ioBundlePciLong, + usbEventBusnum, usbEventDevnum, usbEventPortnum, usbEventPCIAddress, + qmpSocketPath, vmAdapter) + + tet := testEventTable() + countUSBConnections := testRunConnectingUsbDevicesOrderCombinations(tet, qmpSocketPath, ioBundle, usbdevice, vm) + + if countUSBConnections != 0 { + t.Fatalf("expected 0 connection attempts to qemu, but got %d", countUSBConnections) + } +} + +func TestNoConnectUSBDevicesToQemu(t *testing.T) { + usbEventBusnum := uint16(1) + usbEventDevnum := uint16(2) + usbEventPortnum := "3.1" + ioBundleUsbAddr := fmt.Sprintf("%d:%s-1", usbEventBusnum, usbEventPortnum) // usb port different from usb device + ioBundlePciLong := "00:02.0" + ioBundleLabel := "TOUCH" + vmAdapter := ioBundleLabel + usbEventPCIAddress := ioBundlePciLong + qmpSocketPath := "/vm/qemu.sock" + + ioBundle, usbdevice, vm := newTestVirtualPassthroughEnv(ioBundleLabel, ioBundleUsbAddr, ioBundlePciLong, + usbEventBusnum, usbEventDevnum, usbEventPortnum, usbEventPCIAddress, + qmpSocketPath, vmAdapter) + + tet := testEventTable() + countUSBConnections := testRunConnectingUsbDevicesOrderCombinations(tet, qmpSocketPath, ioBundle, usbdevice, vm) + + if countUSBConnections != 0 { + t.Fatalf("expected 0 connection attempts to qemu, but got %d", countUSBConnections) + } +} + +func TestReconnectUSBDevicesToQemu(t *testing.T) { + usbEventBusnum := uint16(1) + usbEventDevnum := uint16(2) + usbEventPortnum := "3.1" + ioBundleUsbAddr := fmt.Sprintf("%d:%s", usbEventBusnum, usbEventPortnum) + ioBundlePciLong := "00:02.0" + ioBundleLabel := "TOUCH" + vmAdapter := ioBundleLabel + usbEventPCIAddress := ioBundlePciLong + qmpSocketPath := "/vm/qemu.sock" + + ioBundle, ud, vm := newTestVirtualPassthroughEnv(ioBundleLabel, ioBundleUsbAddr, ioBundlePciLong, + usbEventBusnum, usbEventDevnum, usbEventPortnum, usbEventPCIAddress, + qmpSocketPath, vmAdapter) + + uc := newTestUsbmanagerController() + var countCurrentUSBPassthroughs atomic.Int32 + countCurrentUSBPassthroughs.Store(0) + uc.connectUSBDeviceToQemu = func(up usbpassthrough) { + countCurrentUSBPassthroughs.Add(1) + } + uc.disconnectUSBDeviceFromQemu = func(up usbpassthrough) { + countCurrentUSBPassthroughs.Add(-1) + } + + uc.addIOBundle(ioBundle) + uc.addUSBDevice(ud) + uc.addUSBDevice(ud) + uc.addVirtualmachine(vm) + uc.addUSBDevice(ud) + if countCurrentUSBPassthroughs.Load() != 1 { + t.Fatalf("expected current usb passthrough count to be 1, but got %d", countCurrentUSBPassthroughs.Load()) + } + uc.removeUSBDevice(ud) + uc.removeUSBDevice(ud) + if countCurrentUSBPassthroughs.Load() != 0 { + t.Fatalf("expected current usb passthrough count to be 0, but got %d", countCurrentUSBPassthroughs.Load()) + } + + uc.addUSBDevice(ud) + if countCurrentUSBPassthroughs.Load() != 1 { + t.Fatalf("expected current usb passthrough count to be 1, but got %d", countCurrentUSBPassthroughs.Load()) + } + uc.addUSBDevice(ud) + + if countCurrentUSBPassthroughs.Load() != 1 { + t.Fatalf("expected current usb passthrough count to be 1, but got %d", countCurrentUSBPassthroughs.Load()) + } +} + +func TestConnectUSBDevicesToQemu(t *testing.T) { + usbEventBusnum := uint16(1) + usbEventDevnum := uint16(2) + usbEventPortnum := "3.1" + ioBundleUsbAddr := fmt.Sprintf("%d:%s", usbEventBusnum, usbEventPortnum) + ioBundlePciLong := "00:02.0" + ioBundleLabel := "TOUCH" + vmAdapter := ioBundleLabel + usbEventPCIAddress := ioBundlePciLong + qmpSocketPath := "/vm/qemu.sock" + + ioBundle, usbdevice, vm := newTestVirtualPassthroughEnv(ioBundleLabel, ioBundleUsbAddr, ioBundlePciLong, + usbEventBusnum, usbEventDevnum, usbEventPortnum, usbEventPCIAddress, + qmpSocketPath, vmAdapter) + + tet := testEventTable() + countUSBConnections := testRunConnectingUsbDevicesOrderCombinations(tet, qmpSocketPath, ioBundle, usbdevice, vm) + + if len(tet) != countUSBConnections { + t.Fatalf("expected %d connection attempts to qemu, but got %d", len(tet), countUSBConnections) + } +} + +func testRunConnectingUsbDevicesOrderCombinations(tet [][]testingEvent, expectedQmpSocketPath string, ioBundle types.IoBundle, ud usbdevice, vm virtualmachine) int { + countUSBConnections := 0 + for _, testEvents := range tet { + uc := newTestUsbmanagerController() + + uc.connectUSBDeviceToQemu = func(up usbpassthrough) { + if up.vm.qmpSocketPath != expectedQmpSocketPath { + err := fmt.Errorf("vm connecting to should have qmp path %s, but has %s", expectedQmpSocketPath, up.vm.qmpSocketPath) + panic(err) + } + countUSBConnections++ + } + + for _, testEvent := range testEvents { + if testEvent == ioBundleTestEvent { + uc.addIOBundle(ioBundle) + } else if testEvent == usbTestEvent { + uc.addUSBDevice(ud) + } else if testEvent == vmTestEvent { + uc.addVirtualmachine(vm) + } + } + + } + return countUSBConnections +} + +func newTestVirtualPassthroughEnv(ioBundleLabel, ioBundleUsbAddr, ioBundlePciLong string, + usbEventBusnum, usbEventDevnum uint16, usbEventPortnum string, usbEventPCIAddress, + qmpSocketPath, vmAdapter string) (types.IoBundle, usbdevice, virtualmachine) { + + ioBundle := types.IoBundle{Phylabel: ioBundleLabel, UsbAddr: ioBundleUsbAddr, PciLong: ioBundlePciLong} + + ud := usbdevice{ + busnum: usbEventBusnum, + devnum: usbEventDevnum, + portnum: usbEventPortnum, + vendorID: 05, + productID: 06, + usbControllerPCIAddress: usbEventPCIAddress, + } + vm := virtualmachine{ + qmpSocketPath: qmpSocketPath, + adapters: []string{vmAdapter}, + } + + return ioBundle, ud, vm +} + +func newTestUsbmanagerController() *usbmanagerController { + uc := usbmanagerController{} + uc.init() + + return &uc +} diff --git a/pkg/pillar/cmd/usbmanager/usbdevice.go b/pkg/pillar/cmd/usbmanager/usbdevice.go new file mode 100644 index 00000000000..69ae64dd647 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/usbdevice.go @@ -0,0 +1,58 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "fmt" + "net/url" + "strings" +) + +type usbType int32 + +const ( + other usbType = -1 + device usbType = 0 + audio usbType = 1 + hid usbType = 3 + hub usbType = 9 +) + +type usbAction uint8 + +func (ud usbdevice) String() string { + return fmt.Sprintf("busnum: %s devnum: %s product: %s/%s parentPCIAddress: %s ueventFilePath: %s", ud.busnumString(), ud.devnumString(), ud.vendorIDString(), ud.productIDString(), ud.usbControllerPCIAddress, ud.ueventFilePath) +} + +func (ud usbdevice) vendorIDString() string { + return fmt.Sprintf("%04x", ud.vendorID) +} + +func (ud usbdevice) productIDString() string { + return fmt.Sprintf("%04x", ud.productID) +} + +func (ud usbdevice) vendorAndproductIDString() string { + return fmt.Sprintf("%s:%s", ud.vendorIDString(), ud.productIDString()) +} + +func (ud usbdevice) busnumAndDevnumString() string { + return fmt.Sprintf("%s:%s", ud.busnumString(), ud.devnumString()) +} + +func (ud usbdevice) busnumString() string { + return fmt.Sprintf("%03x", ud.busnum) +} + +func (ud usbdevice) devnumString() string { + return fmt.Sprintf("%03x", ud.devnum) +} + +func (ud usbdevice) qemuDeviceName() string { + id := fmt.Sprintf("USB%s@%d/%d", ud.usbControllerPCIAddress, ud.busnum, ud.devnum) + id = url.QueryEscape(id) + id = strings.ReplaceAll(id, "%", "") + id = strings.ReplaceAll(id, "-", "") + id = strings.ReplaceAll(id, ".", "") + return id +} diff --git a/pkg/pillar/cmd/usbmanager/usbevent.go b/pkg/pillar/cmd/usbmanager/usbevent.go new file mode 100644 index 00000000000..8b737827c80 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/usbevent.go @@ -0,0 +1,82 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "errors" + "io" + "path/filepath" + "sync/atomic" + + "github.com/eshard/uevent" +) + +func (uc *usbmanagerController) listenUSBPorts() { + var isCancelled atomic.Bool + + usbdevices := make(map[string]*usbdevice) + isCancelled.Store(false) + + r, err := uevent.NewReader() + if err != nil { + log.Fatal(err) + } + + // goroutine for cancelling uevent reader + uc.listenUSBStopChan = make(chan struct{}) + go func() { + <-uc.listenUSBStopChan + isCancelled.Store(true) + r.Close() + return + }() + + go func() { + for _, ud := range walkUSBPorts() { + usbdevices[ud.ueventFilePath] = ud + log.Tracef("usb device from walking: %+v", ud) + uc.addUSBDevice(*ud) + } + + defer r.Close() + + dec := uevent.NewDecoder(r) + + for { + evt, err := dec.Decode() + if errors.Is(err, io.EOF) && isCancelled.Load() { + return + } else if err != nil { + log.Fatal(err) + } + + ueventFilePath := filepath.Join(sysFSPath, evt.Devpath, "uevent") + ud := ueventFile2usbDevice(ueventFilePath) + if ud == nil { + ud = usbdevices[ueventFilePath] + } + + if ud == nil { + continue + } + + if evt.Action == "bind" { + // bind, not add: https://github.com/olavmrk/usb-libvirt-hotplug/issues/4 + _, ok := usbdevices[ud.ueventFilePath] + if ok { + continue + } + + usbdevices[ud.ueventFilePath] = ud + uc.addUSBDevice(*ud) + } else if evt.Action == "remove" { + ud, ok := usbdevices[ueventFilePath] + if ok { + uc.removeUSBDevice(*ud) + delete(usbdevices, ueventFilePath) + } + } + } + }() + +} diff --git a/pkg/pillar/cmd/usbmanager/usbmanager.go b/pkg/pillar/cmd/usbmanager/usbmanager.go new file mode 100644 index 00000000000..7ff3b262414 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/usbmanager.go @@ -0,0 +1,101 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package usbmanager + +import ( + "time" + + "github.com/lf-edge/eve/pkg/pillar/agentbase" + "github.com/lf-edge/eve/pkg/pillar/base" + "github.com/lf-edge/eve/pkg/pillar/cmd/domainmgr" + "github.com/lf-edge/eve/pkg/pillar/hypervisor" + "github.com/lf-edge/eve/pkg/pillar/pubsub" + "github.com/lf-edge/eve/pkg/pillar/utils" + "github.com/sirupsen/logrus" +) + +const ( + agentName = "usbmanager" + // Time limits for event loop handlers + errorTime = 3 * time.Minute + warningTime = 40 * time.Second + stillRunningInterval = 25 * time.Second +) + +var ( + logger *logrus.Logger + log *base.LogObject +) + +type usbdevice struct { + busnum uint16 + portnum string + devnum uint16 + vendorID uint32 + productID uint32 + devicetype string + usbControllerPCIAddress string + ueventFilePath string +} + +type usbmanagerContext struct { + agentbase.AgentBase + SubAssignableAdapters pubsub.Subscription + SubDomainStatus pubsub.Subscription + controller usbmanagerController + runningDomains map[string]struct{} // mapping DomainStatus.DomainName +} + +func newUsbmanagerContext() *usbmanagerContext { + usbCtx := &usbmanagerContext{} + usbCtx.controller.init() + + usbCtx.runningDomains = make(map[string]struct{}) + + return usbCtx +} + +// Run - Main function - invoked from zedbox.go +func Run(ps *pubsub.PubSub, loggerArg *logrus.Logger, logArg *base.LogObject, arguments []string) int { + logger = loggerArg + log = logArg + + usbCtx := newUsbmanagerContext() + + agentbase.Init(usbCtx, logger, log, agentName, + agentbase.WithPidFile(), + agentbase.WithWatchdog(ps, warningTime, errorTime), + agentbase.WithArguments(arguments)) + + // Wait until we have been onboarded aka know our own UUID, but we don't use the UUID + err := utils.WaitForOnboarded(ps, log, agentName, warningTime, errorTime) + if err != nil { + log.Fatal(err) + } + log.Functionf("processed onboarded") + + if err := utils.WaitForVault(ps, log, agentName, warningTime, errorTime); err != nil { + log.Fatal(err) + } + log.Functionf("processed Vault Status") + + var subs []pubsub.Subscription + + currentHypervisor := domainmgr.CurrentHypervisor() + fmt.Fprintf(os.Stderr, "BBBBB hv: %+v (%T)\n", currentHypervisor, currentHypervisor) + _, ok := currentHypervisor.(hypervisor.KvmContext) + if ok { + usbCtx.subscribe(ps) + subs = pubsub.SubscriptionsByReflection(usbCtx) + } else { + subs = make([]pubsub.Subscription, 0) + log.Warnf("usbmanager is disabled as hypervisor %s is used\n", currentHypervisor.Name()) + } + + usbCtx.process(ps, subs) + + return 0 +} diff --git a/pkg/pillar/cmd/usbmanager/usbmanager_test.go b/pkg/pillar/cmd/usbmanager/usbmanager_test.go new file mode 100644 index 00000000000..37a534ec4c9 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/usbmanager_test.go @@ -0,0 +1,11 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "github.com/lf-edge/eve/pkg/pillar/agentlog" +) + +func init() { + logger, log = agentlog.Init(agentName) +} diff --git a/pkg/pillar/cmd/usbmanager/usbpassthrough.go b/pkg/pillar/cmd/usbmanager/usbpassthrough.go new file mode 100644 index 00000000000..74be16985b7 --- /dev/null +++ b/pkg/pillar/cmd/usbmanager/usbpassthrough.go @@ -0,0 +1,113 @@ +// Copyright (c) 2023 Zededa, Inc. +// SPDX-License-Identifier: Apache-2.0 +package usbmanager + +import ( + "strings" + + "github.com/lf-edge/eve/pkg/pillar/types" +) + +// the holy trinity +type usbpassthrough struct { + usbdevice *usbdevice + vm *virtualmachine +} + +func (up usbpassthrough) String() string { + return strings.Join([]string{ + up.usbdevice.String(), + up.vm.qmpSocketPath, + }, "||") +} + +type usbpassthroughs struct { + ioBundles map[string]*types.IoBundle // PhyLabel is key + usbdevices map[string]*usbdevice // by ueventFilePath + vms map[string]*virtualmachine // by qmp path + vmsByIoBundlePhyLabel map[string]*virtualmachine + + usbpassthroughs map[string]*usbpassthrough + usbpassthroughsByVM map[string]map[string]*usbpassthrough +} + +func newUsbpassthroughs() usbpassthroughs { + var up usbpassthroughs + + up.ioBundles = make(map[string]*types.IoBundle) + up.usbdevices = make(map[string]*usbdevice) + up.vms = make(map[string]*virtualmachine) + up.vmsByIoBundlePhyLabel = make(map[string]*virtualmachine) + + up.usbpassthroughs = make(map[string]*usbpassthrough) + up.usbpassthroughsByVM = make(map[string]map[string]*usbpassthrough) + + return up +} + +func (ups *usbpassthroughs) delIoBundle(ioBundle *types.IoBundle) { + delete(ups.ioBundles, ioBundle.Phylabel) +} + +func (ups *usbpassthroughs) addIoBundle(ioBundle *types.IoBundle) { + ups.ioBundles[ioBundle.Phylabel] = ioBundle +} + +func (ups *usbpassthroughs) listUsbdevices() []*usbdevice { + usbdevices := make([]*usbdevice, 0) + + for _, ud := range ups.usbdevices { + usbdevices = append(usbdevices, ud) + } + + return usbdevices +} + +func (ups *usbpassthroughs) addUsbdevice(ud *usbdevice) { + ups.usbdevices[ud.ueventFilePath] = ud +} + +func (ups *usbpassthroughs) delUsbdevice(ud *usbdevice) { + delete(ups.usbdevices, ud.ueventFilePath) +} + +func (ups *usbpassthroughs) addVM(vm *virtualmachine) { + ups.vms[vm.qmpSocketPath] = vm + for _, phyLabel := range vm.adapters { + ups.vmsByIoBundlePhyLabel[phyLabel] = vm + } +} + +func (ups *usbpassthroughs) delVM(vm *virtualmachine) { + vmDel := ups.vms[vm.qmpSocketPath] + if vmDel != nil && vmDel.adapters != nil { + for _, phyLabel := range vmDel.adapters { + delete(ups.vmsByIoBundlePhyLabel, phyLabel) + } + } + delete(ups.vms, vm.qmpSocketPath) +} + +func (ups usbpassthroughs) hasUsbpassthrough(up usbpassthrough) bool { + _, ok := ups.usbpassthroughs[up.String()] + return ok +} + +func (ups *usbpassthroughs) addUsbpassthrough(up *usbpassthrough) { + ups.usbpassthroughs[up.String()] = up + if ups.usbpassthroughsByVM[up.vm.qmpSocketPath] == nil { + ups.usbpassthroughsByVM[up.vm.qmpSocketPath] = make(map[string]*usbpassthrough) + } + ups.usbpassthroughsByVM[up.vm.qmpSocketPath][up.usbdevice.ueventFilePath] = up +} + +func (ups *usbpassthroughs) delUsbpassthrough(up *usbpassthrough) { + delete(ups.usbpassthroughs, up.String()) + delete(ups.usbpassthroughsByVM, up.vm.qmpSocketPath) + + ups.delVM(up.vm) +} + +func (ups usbpassthroughs) usbpassthroughsOfVM(vm virtualmachine) map[string]*usbpassthrough { + return ups.usbpassthroughsByVM[vm.qmpSocketPath] +} diff --git a/pkg/pillar/hypervisor/kvm.go b/pkg/pillar/hypervisor/kvm.go index 8e888f998c6..65c896fdb9f 100644 --- a/pkg/pillar/hypervisor/kvm.go +++ b/pkg/pillar/hypervisor/kvm.go @@ -356,7 +356,7 @@ const sysfsVfioPciBind = "/sys/bus/pci/drivers/vfio-pci/bind" const sysfsPciDriversProbe = "/sys/bus/pci/drivers_probe" const vfioDriverPath = "/sys/bus/pci/drivers/vfio-pci" -// KVM domains map 1-1 to anchor device model UNIX processes (qemu or firecracker) +// KvmContext is a KVM domains map 0-1 to anchor device model UNIX processes (qemu or firecracker) // For every anchor process we maintain the following entry points in the // /run/hypervisor/kvm/DOMAIN_NAME: // @@ -412,6 +412,7 @@ func newKvm() Hypervisor { return nil } +// GetCapabilities returns capabilities of the kvm hypervisor func (ctx KvmContext) GetCapabilities() (*types.Capabilities, error) { if ctx.capabilities != nil { return ctx.capabilities, nil @@ -443,10 +444,12 @@ func (ctx KvmContext) checkIOVirtualisation() (bool, error) { return false, err } +// Name returns the name of the kvm hypervisor func (ctx KvmContext) Name() string { return KVMHypervisorName } +// Task returns either the kvm context or the containerd context depending on the domain status func (ctx KvmContext) Task(status *types.DomainStatus) types.Task { if status.VirtualizationMode == types.NOHYPER { return ctx.ctrdContext @@ -626,6 +629,7 @@ func vmmOverhead(domainName string, config types.DomainConfig, return overhead, nil } +// Setup sets up kvm func (ctx KvmContext) Setup(status types.DomainStatus, config types.DomainConfig, aa *types.AssignableAdapters, globalConfig *types.ConfigItemValueMap, file *os.File) error { @@ -682,6 +686,7 @@ func (ctx KvmContext) Setup(status types.DomainStatus, config types.DomainConfig return nil } +// CreateDomConfig creates a domain config (a qemu config file, typically named something like xen-%d.cfg) func (ctx KvmContext) CreateDomConfig(domainName string, config types.DomainConfig, status types.DomainStatus, diskStatusList []types.DiskStatus, aa *types.AssignableAdapters, file *os.File) error { tmplCtx := struct { @@ -870,6 +875,7 @@ func waitForQmp(domainName string, available bool) error { } } +// Start starts a domain func (ctx KvmContext) Start(domainName string) error { logrus.Infof("starting KVM domain %s", domainName) if err := ctx.ctrdContext.Start(domainName); err != nil { @@ -911,6 +917,7 @@ func (ctx KvmContext) Start(domainName string) error { return nil } +// Stop stops a domain func (ctx KvmContext) Stop(domainName string, _ bool) error { if err := execShutdown(GetQmpExecutorSocket(domainName)); err != nil { return logError("Stop: failed to execute shutdown command %v", err) @@ -918,6 +925,7 @@ func (ctx KvmContext) Stop(domainName string, _ bool) error { return nil } +// Delete deletes a domain func (ctx KvmContext) Delete(domainName string) (result error) { //Sending a stop signal to then domain before quitting. This is done to freeze the domain before quitting it. execStop(GetQmpExecutorSocket(domainName)) @@ -932,6 +940,7 @@ func (ctx KvmContext) Delete(domainName string) (result error) { return nil } +// Info returns information of a domain func (ctx KvmContext) Info(domainName string) (int, types.SwState, error) { // first we ask for the task status effectiveDomainID, effectiveDomainState, err := ctx.ctrdContext.Info(domainName) @@ -968,6 +977,7 @@ func (ctx KvmContext) Info(domainName string) (int, types.SwState, error) { } } +// Cleanup cleans up a domain func (ctx KvmContext) Cleanup(domainName string) error { if err := ctx.ctrdContext.Cleanup(domainName); err != nil { return fmt.Errorf("couldn't cleanup task %s: %v", domainName, err) @@ -979,6 +989,7 @@ func (ctx KvmContext) Cleanup(domainName string) error { return nil } +// PCIReserve reserves a PCI device func (ctx KvmContext) PCIReserve(long string) error { logrus.Infof("PCIReserve long addr is %s", long) @@ -1017,6 +1028,7 @@ func (ctx KvmContext) PCIReserve(long string) error { return nil } +// PCIRelease releases the PCI device reservation func (ctx KvmContext) PCIRelease(long string) error { logrus.Infof("PCIRelease long addr is %s", long) @@ -1047,6 +1059,7 @@ func (ctx KvmContext) PCIRelease(long string) error { return nil } +// PCISameController checks if two PCI controllers are the same func (ctx KvmContext) PCISameController(id1 string, id2 string) bool { tag1, err := types.PCIGetIOMMUGroup(id1) if err != nil { @@ -1069,6 +1082,7 @@ func usbBusPort(USBAddr string) (string, string) { return "", "" } +// GetQmpExecutorSocket returns the path to the qmp socket of a domain func GetQmpExecutorSocket(domainName string) string { return filepath.Join(kvmStateDir, domainName, "qmp") } diff --git a/pkg/pillar/hypervisor/qmp.go b/pkg/pillar/hypervisor/qmp.go index 9939fcd80c5..cf4128f1e38 100644 --- a/pkg/pillar/hypervisor/qmp.go +++ b/pkg/pillar/hypervisor/qmp.go @@ -67,6 +67,20 @@ func execVNCPassword(socket string, password string) error { return err } +// QmpExecDeviceDelete removes a device +func QmpExecDeviceDelete(socket, id string) error { + qmpString := fmt.Sprintf(`{ "execute": "device_del", "arguments": { "id": "%s"}}`, id) + _, err := execRawCmd(socket, qmpString) + return err +} + +// QmpExecDeviceAdd adds a usb device via busnum/devnum +func QmpExecDeviceAdd(socket, id string, busnum, devnum uint16) error { + qmpString := fmt.Sprintf(`{ "execute": "device_add", "arguments": { "driver": "usb-host", "hostbus": %d, "hostaddr": %d, "id": "%s"} }`, busnum, devnum, id) + _, err := execRawCmd(socket, qmpString) + return err +} + func getQemuStatus(socket string) (string, error) { if raw, err := execRawCmd(socket, `{ "execute": "query-status" }`); err == nil { var result struct { diff --git a/pkg/pillar/pubsub/util.go b/pkg/pillar/pubsub/util.go index 22e08ece802..bb9a3bf74e0 100644 --- a/pkg/pillar/pubsub/util.go +++ b/pkg/pillar/pubsub/util.go @@ -107,6 +107,7 @@ func ConnReadCheck(conn net.Conn) error { return sysErr } +// SubscriptionsByReflection returns all (exported) members of ctxStruct struct that are of type Subscription func SubscriptionsByReflection(ctxStruct interface{}) []Subscription { value := reflect.ValueOf(ctxStruct).Elem() valueType := value.Type() diff --git a/pkg/pillar/scripts/device-steps.sh b/pkg/pillar/scripts/device-steps.sh index d5ca62f2a44..132e518a7ef 100755 --- a/pkg/pillar/scripts/device-steps.sh +++ b/pkg/pillar/scripts/device-steps.sh @@ -18,7 +18,7 @@ ZTMPDIR=/run/global DPCDIR=$ZTMPDIR/DevicePortConfig FIRSTBOOTFILE=$ZTMPDIR/first-boot FIRSTBOOT= -AGENTS="diag zedagent ledmanager nim nodeagent domainmgr loguploader tpmmgr vaultmgr zedmanager zedrouter downloader verifier baseosmgr wstunnelclient volumemgr watcher zfsmanager" +AGENTS="diag zedagent ledmanager nim nodeagent domainmgr loguploader tpmmgr vaultmgr zedmanager zedrouter downloader verifier baseosmgr wstunnelclient volumemgr watcher zfsmanager usbmanager" TPM_DEVICE_PATH="/dev/tpmrm0" PATH=$BINDIR:$PATH TPMINFOTEMPFILE=/var/tmp/tpminfo.txt diff --git a/pkg/pillar/zedbox/zedbox.go b/pkg/pillar/zedbox/zedbox.go index bc2f421fa67..cd74bced519 100644 --- a/pkg/pillar/zedbox/zedbox.go +++ b/pkg/pillar/zedbox/zedbox.go @@ -34,6 +34,7 @@ import ( "github.com/lf-edge/eve/pkg/pillar/cmd/pbuf" "github.com/lf-edge/eve/pkg/pillar/cmd/tpmmgr" "github.com/lf-edge/eve/pkg/pillar/cmd/upgradeconverter" + "github.com/lf-edge/eve/pkg/pillar/cmd/usbmanager" "github.com/lf-edge/eve/pkg/pillar/cmd/vaultmgr" "github.com/lf-edge/eve/pkg/pillar/cmd/verifier" "github.com/lf-edge/eve/pkg/pillar/cmd/volumemgr" @@ -103,6 +104,7 @@ var ( "upgradeconverter": {f: upgradeconverter.Run, inline: inlineAlways}, "watcher": {f: watcher.Run}, "zfsmanager": {f: zfsmanager.Run}, + "usbmanager": {f: usbmanager.Run}, } logger *logrus.Logger log *base.LogObject