-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds package netwatcher to monitor network devices
It's built on top of a thinner wrapper around windows registry. When discussing WS035 we agreed not to expose the registry, as we might be required to find another, more specific, approach. Thus the API methods don't resemble the registry, but a device manangement API.
- Loading branch information
1 parent
b9591f1
commit 6edb944
Showing
8 changed files
with
538 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
package netwatcher | ||
|
||
// WithNetAdaptersAPIProvider sets the NetAdaptersAPIProvider to be used by the netWatcher.Subscribe(). | ||
func WithNetAdaptersAPIProvider(p NetAdaptersAPIProvider) Option { | ||
return func(o *options) { | ||
o.p = p | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
package netwatcher | ||
|
||
// defaultRepositoryFactory on Linux must delegate to the mock implementation. | ||
func defaultNetAdaptersAPIProvider() (NetAdaptersAPI, error) { | ||
panic("defaultNetAdaptersAPIProvider is not implemented on Linux without a mock") | ||
} |
49 changes: 49 additions & 0 deletions
49
windows-agent/internal/netwatcher/net_adapters_api_windows.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package netwatcher | ||
|
||
import ( | ||
"fmt" | ||
"path/filepath" | ||
|
||
"golang.org/x/sys/windows" | ||
"golang.org/x/sys/windows/registry" | ||
) | ||
|
||
// An implementaiton of the NetAdaptersAPI interface relying on the well-known registry path `HKLM:SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}` provided by the OS. | ||
type NetAdaptersAPIWindows struct { | ||
k registry.Key | ||
} | ||
|
||
func (h NetAdaptersAPIWindows) Close() { | ||
_ = h.k.Close() | ||
} | ||
|
||
func defaultNetAdaptersAPIProvider() (NetAdaptersAPI, error) { | ||
k, err := registry.OpenKey(windows.HKEY_LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}`, registry.READ) | ||
return NetAdaptersAPIWindows{k: k}, err | ||
} | ||
|
||
func (h NetAdaptersAPIWindows) ListDevices() ([]string, error) { | ||
return h.k.ReadSubKeyNames(-1) // This could potentially be implemented in terms of `GetInterfacesInfo`. | ||
} | ||
|
||
func (h NetAdaptersAPIWindows) GetDeviceConnectionName(guid string) (string, error) { | ||
// This could be implemented in terms of GetAdaptersAddresses. All other APIs considered would depend on the registry anyway. | ||
sk, err := registry.OpenKey(h.k, filepath.Join(guid, "Connection"), registry.READ) | ||
if err != nil { | ||
return "", fmt.Errorf("could not read the connection info from adapter GUID %s: %v", guid, err) | ||
} | ||
defer sk.Close() | ||
|
||
// Ignoring the registry value type trusting the OS will never create non-string values for this key. | ||
v, _, err := sk.GetStringValue("Name") | ||
if err != nil { | ||
return "", fmt.Errorf("could not read the connection name from adapter GUID %s: %v", guid, err) | ||
} | ||
return v, nil | ||
} | ||
|
||
func (h NetAdaptersAPIWindows) WaitForDeviceChanges() error { | ||
// This part could be implemented in terms of CM_Register_Notification, if we find a way to set a Win32 callback without relying on CGo. | ||
// Wait synchronuosly on notifications if a subkey is added or deleted, or changes to a value of the key, including adding or deleting a value, or changing an existing value. | ||
return windows.RegNotifyChangeKeyValue(windows.Handle(h.k), true, windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(0), false) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
// Package netwatcher allows susbcribing to the addition of network adapters on the host. | ||
package netwatcher | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"slices" | ||
"sort" | ||
"strings" | ||
|
||
log "github.com/canonical/ubuntu-pro-for-wsl/common/grpc/logstreamer" | ||
"github.com/google/uuid" | ||
) | ||
|
||
// NewAdapterCallback is called when new network adapters are added on the host. | ||
// It must return true to continue receiving notifications or false to stop the subscription. | ||
type NewAdapterCallback func(adapterNames []string) bool | ||
|
||
// NetWatcher represents a subscription to events of network adapters added on the host. | ||
type NetWatcher struct { | ||
ctx context.Context | ||
cancel context.CancelFunc | ||
callback NewAdapterCallback | ||
api NetAdaptersAPI | ||
|
||
cache []string | ||
|
||
// err is a channel through which we join the current waiting goroutine. | ||
err chan error | ||
} | ||
|
||
// NetAdaptersAPI is an interface for interacting with the network adapters on the host. | ||
type NetAdaptersAPI interface { | ||
// Close releases the resources associated with this object and cancels any outstanding wait operation. | ||
Close() | ||
|
||
// ListDevices returns the GUIDs of the network adapters on the host. | ||
ListDevices() ([]string, error) | ||
|
||
// GetDeviceConnectionName returns the connection name of the network adapter with the given GUID. | ||
GetDeviceConnectionName(guid string) (string, error) | ||
|
||
// WaitForChanges blocks the caller until the system triggers a notification of changes to the network adapters. | ||
// It returns nil if the notification is triggered or an error if the context is cancelled or an error occurs. | ||
// The wait is cancellable by calling Close(). | ||
WaitForDeviceChanges() error | ||
} | ||
|
||
type NetAdaptersAPIProvider func() (NetAdaptersAPI, error) | ||
type options struct { | ||
p NetAdaptersAPIProvider | ||
} | ||
|
||
type Option func(*options) | ||
|
||
var defaultOptions = options{ | ||
p: defaultNetAdaptersAPIProvider, | ||
} | ||
|
||
// Subscribe subscribes to the addition of network adapters on the host, calling the provided callback | ||
// with a slice of new adapter names discovered by the time the OS triggers the notification. | ||
func Subscribe(ctx context.Context, callback NewAdapterCallback, o ...Option) (*NetWatcher, error) { | ||
opts := defaultOptions | ||
for _, opt := range o { | ||
opt(&opts) | ||
} | ||
|
||
api, err := opts.p() | ||
if err != nil { | ||
return nil, fmt.Errorf("could not initialize the network adapter API: %v", err) | ||
} | ||
|
||
current, err := listAdapters(api) | ||
if err != nil { | ||
return nil, fmt.Errorf("could not get the current list of network adapters: %v", err) | ||
} | ||
|
||
nctx, cancel := context.WithCancel(ctx) | ||
// Ensures that the network adapter repository is closed when the context is cancelled so we don't need to do it explicitely. | ||
Check failure on line 79 in windows-agent/internal/netwatcher/netwatcher.go GitHub Actions / Go Quality checks (ubuntu, windows-agent)
|
||
context.AfterFunc(nctx, api.Close) | ||
n := &NetWatcher{ | ||
api: api, | ||
ctx: nctx, | ||
cancel: cancel, | ||
callback: callback, | ||
err: make(chan error, 1), | ||
cache: current, | ||
} | ||
|
||
go func() { | ||
defer close(n.err) | ||
|
||
err := n.start() | ||
n.err <- err | ||
log.Debugf(context.Background(), "stopped monitoring network adapters: %v", err) | ||
}() | ||
return n, nil | ||
} | ||
|
||
// Stop blocks the caller until the subscription to the addition of network adapters on the host is stopped. | ||
func (n *NetWatcher) Stop() error { | ||
n.cancel() | ||
|
||
// joins the goroutine that is waiting for network adapter changes. | ||
return <-n.err | ||
} | ||
|
||
// notify notifies the subscriber of the new network adapters added on the host. | ||
// It returns true if the subscription should continue. | ||
func (n *NetWatcher) notify() bool { | ||
// reloads the list of network adapters and their connection names from the registry | ||
current, err := listAdapters(n.api) | ||
if err != nil { | ||
log.Errorf(n.ctx, "could not get the current list of network adapters: %v", err) | ||
return true | ||
} | ||
// detects which network adapter was added, i.e. are in the current list but not in the cached list. | ||
added := difference(current, n.cache) | ||
if len(added) == 0 { | ||
return true | ||
} | ||
// updates the cache with the current list of network adapters. | ||
n.cache = current | ||
|
||
// finally calls the subscriber with the names of the new network adapters. | ||
return n.callback(added) | ||
} | ||
|
||
// start blocks a new goroutine on system notifications about network adapters on the host and notifies the subscriber, | ||
// while ensuring that this object's context cancellation is respected. | ||
func (n *NetWatcher) start() error { | ||
// Intentionally not closed to prevent potential panics due sending to a closed channel. | ||
waitCh := make(chan error) | ||
|
||
for { | ||
go func() { | ||
if err := n.api.WaitForDeviceChanges(); err != nil { | ||
waitCh <- fmt.Errorf("could not wait for network devices changes: %v", err) | ||
return | ||
} | ||
waitCh <- nil | ||
}() | ||
|
||
select { | ||
case <-n.ctx.Done(): | ||
return n.ctx.Err() | ||
case err := <-waitCh: | ||
if err != nil { | ||
return err | ||
} | ||
if !n.notify() { | ||
return nil | ||
} | ||
} | ||
|
||
} | ||
} | ||
|
||
// Provides the current list of network adapters by their connection names as seen in the output of commands such as `ipconfig /all`. | ||
func listAdapters(api NetAdaptersAPI) ([]string, error) { | ||
guids, err := api.ListDevices() | ||
if err != nil { | ||
return nil, fmt.Errorf("could not list network adapter GUIDs: %v", err) | ||
} | ||
|
||
// Filter out the entries that are not valid UUIDs. | ||
// When using the registry, there is at least one additional subkey named "Descriptions", which is not useful for this purpose. | ||
adapterGuids := filter(guids, func(guid string) bool { | ||
_, err := uuid.Parse(guid) | ||
return err == nil | ||
}) | ||
|
||
adapterNames := make([]string, 0, len(adapterGuids)) | ||
for _, guid := range adapterGuids { | ||
// Retrieves the connection name of the network adapter with the given GUID, which matches the device's Friendly Name. | ||
name, err := api.GetDeviceConnectionName(guid) | ||
if err != nil { | ||
return nil, err | ||
} | ||
adapterNames = append(adapterNames, name) | ||
} | ||
|
||
slices.Sort(adapterNames) | ||
return adapterNames, nil | ||
} | ||
|
||
// Given two sorted slices of strings, returns the elements that are in the first slice but not in the second. | ||
func difference(a, b []string) []string { | ||
len := len(b) | ||
if len == 0 { | ||
return a | ||
} | ||
|
||
diff := make([]string, 0) | ||
for _, v := range a { | ||
pos, found := sort.Find(len, func(i int) int { | ||
return strings.Compare(v, b[i]) | ||
}) | ||
if !found || pos == len { | ||
diff = append(diff, v) | ||
} | ||
} | ||
return diff | ||
} | ||
|
||
// Given a slice of strings, returns a new slice containing only the elements for which the predicate returns true. | ||
func filter(s []string, predicate func(string) bool) []string { | ||
res := make([]string, 0, len(s)) | ||
for _, v := range s { | ||
if predicate(v) { | ||
res = append(res, v) | ||
} | ||
} | ||
return res | ||
} |
Oops, something went wrong.