From 6edb944e4aefa1d0cfb88faa10180b4404d2dc9f Mon Sep 17 00:00:00 2001 From: Carlos Nihelton Date: Tue, 3 Sep 2024 15:46:42 -0300 Subject: [PATCH] Adds package netwatcher to monitor network devices It's built on top of a thinner wrapper around windows registry. When discussing WS035 we agreed not to expose the registry, as we might be required to find another, more specific, approach. Thus the API methods don't resemble the registry, but a device manangement API. --- windows-agent/internal/daemon/daemon.go | 29 ++- .../internal/netwatcher/export_test.go | 8 + .../netwatcher/net_adapters_api_linux.go | 6 + .../netwatcher/net_adapters_api_windows.go | 49 ++++ .../internal/netwatcher/netwatcher.go | 215 ++++++++++++++++++ .../internal/netwatcher/netwatcher_test.go | 136 +++++++++++ .../net_adapters_mock_api.go | 72 ++++++ .../netwatchertestutils.go | 24 ++ 8 files changed, 538 insertions(+), 1 deletion(-) create mode 100644 windows-agent/internal/netwatcher/export_test.go create mode 100644 windows-agent/internal/netwatcher/net_adapters_api_linux.go create mode 100644 windows-agent/internal/netwatcher/net_adapters_api_windows.go create mode 100644 windows-agent/internal/netwatcher/netwatcher.go create mode 100644 windows-agent/internal/netwatcher/netwatcher_test.go create mode 100644 windows-agent/internal/netwatcher/netwatchertestutils/net_adapters_mock_api.go create mode 100644 windows-agent/internal/netwatcher/netwatchertestutils/netwatchertestutils.go diff --git a/windows-agent/internal/daemon/daemon.go b/windows-agent/internal/daemon/daemon.go index aa951a798..9bc0af749 100644 --- a/windows-agent/internal/daemon/daemon.go +++ b/windows-agent/internal/daemon/daemon.go @@ -7,10 +7,12 @@ import ( "net" "os" "path/filepath" + "strings" "github.com/canonical/ubuntu-pro-for-wsl/common" log "github.com/canonical/ubuntu-pro-for-wsl/common/grpc/logstreamer" "github.com/canonical/ubuntu-pro-for-wsl/common/i18n" + "github.com/canonical/ubuntu-pro-for-wsl/windows-agent/internal/netwatcher" "github.com/ubuntu/decorate" "google.golang.org/grpc" ) @@ -36,6 +38,8 @@ type Daemon struct { registerer GRPCServiceRegisterer grpcServer *grpc.Server + + netSubs *netwatcher.NetWatcher } // New returns an new, initialized daemon server that is ready to register GRPC services. @@ -111,6 +115,9 @@ func (d *Daemon) cleaup() { d.grpcServer = nil close(d.err) close(d.stopped) + if d.netSubs != nil { + d.netSubs.Stop() + } } // Quit gracefully quits listening loop and stops the grpc server. @@ -182,9 +189,29 @@ func (d *Daemon) serve(ctx context.Context, args ...Option) (err error) { wslNetAvailable := true wslIP, err := getWslIP(ctx, args...) if err != nil { - log.Warningf(ctx, "could not get the WSL adapter IP: %v", err) wslNetAvailable = false wslIP = net.IPv4(127, 0, 0, 1) + + log.Warningf(ctx, "Daemon: could not get the WSL adapter IP: %v. Starting network monitoring", err) + n, err := netwatcher.Subscribe(ctx, func(added []string) bool { + for _, adapter := range added { + if strings.Contains(adapter, "(WSL") { + log.Warningf(ctx, "Daemon: new adapter detected: %s", adapter) + d.restart(ctx) + return false + } + } + + // Not found yet, let's keep monitoring. + return true + }) + + if err != nil { + log.Errorf(ctx, "Daemon: could not start network monitoring: %v", err) + // should we return (and not proceed with serving) instead? + } else { + d.netSubs = n + } } var cfg net.ListenConfig diff --git a/windows-agent/internal/netwatcher/export_test.go b/windows-agent/internal/netwatcher/export_test.go new file mode 100644 index 000000000..a2309827f --- /dev/null +++ b/windows-agent/internal/netwatcher/export_test.go @@ -0,0 +1,8 @@ +package netwatcher + +// WithNetAdaptersAPIProvider sets the NetAdaptersAPIProvider to be used by the netWatcher.Subscribe(). +func WithNetAdaptersAPIProvider(p NetAdaptersAPIProvider) Option { + return func(o *options) { + o.p = p + } +} diff --git a/windows-agent/internal/netwatcher/net_adapters_api_linux.go b/windows-agent/internal/netwatcher/net_adapters_api_linux.go new file mode 100644 index 000000000..47b325ee2 --- /dev/null +++ b/windows-agent/internal/netwatcher/net_adapters_api_linux.go @@ -0,0 +1,6 @@ +package netwatcher + +// defaultRepositoryFactory on Linux must delegate to the mock implementation. +func defaultNetAdaptersAPIProvider() (NetAdaptersAPI, error) { + panic("defaultNetAdaptersAPIProvider is not implemented on Linux without a mock") +} diff --git a/windows-agent/internal/netwatcher/net_adapters_api_windows.go b/windows-agent/internal/netwatcher/net_adapters_api_windows.go new file mode 100644 index 000000000..491a0c9fb --- /dev/null +++ b/windows-agent/internal/netwatcher/net_adapters_api_windows.go @@ -0,0 +1,49 @@ +package netwatcher + +import ( + "fmt" + "path/filepath" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" +) + +// An implementaiton of the NetAdaptersAPI interface relying on the well-known registry path `HKLM:SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}` provided by the OS. +type NetAdaptersAPIWindows struct { + k registry.Key +} + +func (h NetAdaptersAPIWindows) Close() { + _ = h.k.Close() +} + +func defaultNetAdaptersAPIProvider() (NetAdaptersAPI, error) { + k, err := registry.OpenKey(windows.HKEY_LOCAL_MACHINE, `SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}`, registry.READ) + return NetAdaptersAPIWindows{k: k}, err +} + +func (h NetAdaptersAPIWindows) ListDevices() ([]string, error) { + return h.k.ReadSubKeyNames(-1) // This could potentially be implemented in terms of `GetInterfacesInfo`. +} + +func (h NetAdaptersAPIWindows) GetDeviceConnectionName(guid string) (string, error) { + // This could be implemented in terms of GetAdaptersAddresses. All other APIs considered would depend on the registry anyway. + sk, err := registry.OpenKey(h.k, filepath.Join(guid, "Connection"), registry.READ) + if err != nil { + return "", fmt.Errorf("could not read the connection info from adapter GUID %s: %v", guid, err) + } + defer sk.Close() + + // Ignoring the registry value type trusting the OS will never create non-string values for this key. + v, _, err := sk.GetStringValue("Name") + if err != nil { + return "", fmt.Errorf("could not read the connection name from adapter GUID %s: %v", guid, err) + } + return v, nil +} + +func (h NetAdaptersAPIWindows) WaitForDeviceChanges() error { + // This part could be implemented in terms of CM_Register_Notification, if we find a way to set a Win32 callback without relying on CGo. + // Wait synchronuosly on notifications if a subkey is added or deleted, or changes to a value of the key, including adding or deleting a value, or changing an existing value. + return windows.RegNotifyChangeKeyValue(windows.Handle(h.k), true, windows.REG_NOTIFY_CHANGE_NAME|windows.REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(0), false) +} diff --git a/windows-agent/internal/netwatcher/netwatcher.go b/windows-agent/internal/netwatcher/netwatcher.go new file mode 100644 index 000000000..f642060d9 --- /dev/null +++ b/windows-agent/internal/netwatcher/netwatcher.go @@ -0,0 +1,215 @@ +// Package netwatcher allows susbcribing to the addition of network adapters on the host. +package netwatcher + +import ( + "context" + "fmt" + "slices" + "sort" + "strings" + + log "github.com/canonical/ubuntu-pro-for-wsl/common/grpc/logstreamer" + "github.com/google/uuid" +) + +// NewAdapterCallback is called when new network adapters are added on the host. +// It must return true to continue receiving notifications or false to stop the subscription. +type NewAdapterCallback func(adapterNames []string) bool + +// NetWatcher represents a subscription to events of network adapters added on the host. +type NetWatcher struct { + ctx context.Context + cancel context.CancelFunc + callback NewAdapterCallback + api NetAdaptersAPI + + cache []string + + // err is a channel through which we join the current waiting goroutine. + err chan error +} + +// NetAdaptersAPI is an interface for interacting with the network adapters on the host. +type NetAdaptersAPI interface { + // Close releases the resources associated with this object and cancels any outstanding wait operation. + Close() + + // ListDevices returns the GUIDs of the network adapters on the host. + ListDevices() ([]string, error) + + // GetDeviceConnectionName returns the connection name of the network adapter with the given GUID. + GetDeviceConnectionName(guid string) (string, error) + + // WaitForChanges blocks the caller until the system triggers a notification of changes to the network adapters. + // It returns nil if the notification is triggered or an error if the context is cancelled or an error occurs. + // The wait is cancellable by calling Close(). + WaitForDeviceChanges() error +} + +type NetAdaptersAPIProvider func() (NetAdaptersAPI, error) +type options struct { + p NetAdaptersAPIProvider +} + +type Option func(*options) + +var defaultOptions = options{ + p: defaultNetAdaptersAPIProvider, +} + +// Subscribe subscribes to the addition of network adapters on the host, calling the provided callback +// with a slice of new adapter names discovered by the time the OS triggers the notification. +func Subscribe(ctx context.Context, callback NewAdapterCallback, o ...Option) (*NetWatcher, error) { + opts := defaultOptions + for _, opt := range o { + opt(&opts) + } + + api, err := opts.p() + if err != nil { + return nil, fmt.Errorf("could not initialize the network adapter API: %v", err) + } + + current, err := listAdapters(api) + if err != nil { + return nil, fmt.Errorf("could not get the current list of network adapters: %v", err) + } + + nctx, cancel := context.WithCancel(ctx) + // Ensures that the network adapter repository is closed when the context is cancelled so we don't need to do it explicitely. + context.AfterFunc(nctx, api.Close) + n := &NetWatcher{ + api: api, + ctx: nctx, + cancel: cancel, + callback: callback, + err: make(chan error, 1), + cache: current, + } + + go func() { + defer close(n.err) + + err := n.start() + n.err <- err + log.Debugf(context.Background(), "stopped monitoring network adapters: %v", err) + }() + return n, nil +} + +// Stop blocks the caller until the subscription to the addition of network adapters on the host is stopped. +func (n *NetWatcher) Stop() error { + n.cancel() + + // joins the goroutine that is waiting for network adapter changes. + return <-n.err +} + +// notify notifies the subscriber of the new network adapters added on the host. +// It returns true if the subscription should continue. +func (n *NetWatcher) notify() bool { + // reloads the list of network adapters and their connection names from the registry + current, err := listAdapters(n.api) + if err != nil { + log.Errorf(n.ctx, "could not get the current list of network adapters: %v", err) + return true + } + // detects which network adapter was added, i.e. are in the current list but not in the cached list. + added := difference(current, n.cache) + if len(added) == 0 { + return true + } + // updates the cache with the current list of network adapters. + n.cache = current + + // finally calls the subscriber with the names of the new network adapters. + return n.callback(added) +} + +// start blocks a new goroutine on system notifications about network adapters on the host and notifies the subscriber, +// while ensuring that this object's context cancellation is respected. +func (n *NetWatcher) start() error { + // Intentionally not closed to prevent potential panics due sending to a closed channel. + waitCh := make(chan error) + + for { + go func() { + if err := n.api.WaitForDeviceChanges(); err != nil { + waitCh <- fmt.Errorf("could not wait for network devices changes: %v", err) + return + } + waitCh <- nil + }() + + select { + case <-n.ctx.Done(): + return n.ctx.Err() + case err := <-waitCh: + if err != nil { + return err + } + if !n.notify() { + return nil + } + } + + } +} + +// Provides the current list of network adapters by their connection names as seen in the output of commands such as `ipconfig /all`. +func listAdapters(api NetAdaptersAPI) ([]string, error) { + guids, err := api.ListDevices() + if err != nil { + return nil, fmt.Errorf("could not list network adapter GUIDs: %v", err) + } + + // Filter out the entries that are not valid UUIDs. + // When using the registry, there is at least one additional subkey named "Descriptions", which is not useful for this purpose. + adapterGuids := filter(guids, func(guid string) bool { + _, err := uuid.Parse(guid) + return err == nil + }) + + adapterNames := make([]string, 0, len(adapterGuids)) + for _, guid := range adapterGuids { + // Retrieves the connection name of the network adapter with the given GUID, which matches the device's Friendly Name. + name, err := api.GetDeviceConnectionName(guid) + if err != nil { + return nil, err + } + adapterNames = append(adapterNames, name) + } + + slices.Sort(adapterNames) + return adapterNames, nil +} + +// Given two sorted slices of strings, returns the elements that are in the first slice but not in the second. +func difference(a, b []string) []string { + len := len(b) + if len == 0 { + return a + } + + diff := make([]string, 0) + for _, v := range a { + pos, found := sort.Find(len, func(i int) int { + return strings.Compare(v, b[i]) + }) + if !found || pos == len { + diff = append(diff, v) + } + } + return diff +} + +// Given a slice of strings, returns a new slice containing only the elements for which the predicate returns true. +func filter(s []string, predicate func(string) bool) []string { + res := make([]string, 0, len(s)) + for _, v := range s { + if predicate(v) { + res = append(res, v) + } + } + return res +} diff --git a/windows-agent/internal/netwatcher/netwatcher_test.go b/windows-agent/internal/netwatcher/netwatcher_test.go new file mode 100644 index 000000000..112659458 --- /dev/null +++ b/windows-agent/internal/netwatcher/netwatcher_test.go @@ -0,0 +1,136 @@ +package netwatcher_test + +import ( + "context" + "errors" + "maps" + "testing" + "time" + + "github.com/canonical/ubuntu-pro-for-wsl/windows-agent/internal/netwatcher" + "github.com/canonical/ubuntu-pro-for-wsl/windows-agent/internal/netwatcher/netwatchertestutils" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func init() { + netwatchertestutils.DefaultNetAdaptersAPIProviderToMock() +} + +func TestSubscribe(t *testing.T) { + t.Parallel() + + mockError := errors.New("mock error") + + before := map[string]string{ + uuid.New().String(): "conn0", + uuid.New().String(): "conn1", + uuid.New().String(): "conn2", + } + + after := map[string]string{ + uuid.New().String(): "new", + "not a guid": "yet_another_new", + } + + maps.Copy(after, before) + + testcases := map[string]struct { + initError error + listDevicesError error + getConnNameError error + waitError error + + ctxCancel bool + devicesUnchanged bool + + wantErr bool + wantNoCallback bool + wantName string + }{ + "Success": {}, + + "When waiting for changes fails": {waitError: mockError, wantNoCallback: true}, + "When the context is cancelled while waiting": {ctxCancel: true, wantNoCallback: true}, + "When OS triggers a notification without device changes": {devicesUnchanged: true, wantNoCallback: true}, + + "When initializing the API fails": {initError: mockError, wantErr: true}, + "When listing devices fails": {listDevicesError: mockError, wantErr: true}, + "When getting connection name fails": {getConnNameError: mockError, wantErr: true}, + } + + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + switch tc.wantName { + case "": + tc.wantName = "new" + case "-": + tc.wantName = "" + + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + initAPI := func() (netwatcher.NetAdaptersAPI, error) { + if tc.initError != nil { + return nil, tc.initError + } + + a := after + if tc.devicesUnchanged { + a = before + } + return &netwatchertestutils.NetAdapterMockAPI{ + Before: before, + After: a, + ListDevicesError: tc.listDevicesError, + GetDeviceConnectionNameError: tc.getConnNameError, + WaitForDeviceChangesImpl: func() error { + if tc.ctxCancel { + cancel() + <-ctx.Done() + return ctx.Err() + } + // Introduces some asynchrony to the test. + <-time.After(50 * time.Millisecond) + return tc.waitError + }, + }, nil + } + + added := make(chan string, 1) + defer close(added) + callback := func(adapterNames []string) bool { + added <- adapterNames[0] + return false + } + + n, err := netwatcher.Subscribe(ctx, callback, netwatcher.WithNetAdaptersAPIProvider(initAPI)) + + if tc.wantErr { + require.Error(t, err, "Subscribe should have failed") + return + } + require.NoError(t, err, "Subscribe should have succeeded") + + select { + case res := <-added: + require.Equal(t, tc.wantName, res, "unexpected new network adapter") + case <-time.After(200 * time.Millisecond): + if !tc.wantNoCallback { + require.Fail(t, "timeout waiting for new network adapter") + } + } + + // Collect the error reported by the wait operation. + err = n.Stop() + if tc.waitError != nil || tc.wantNoCallback || tc.ctxCancel { + require.Error(t, err, "Stop should have failed") + return + } + require.NoError(t, err, "Stop should have succeeded") + }) + } +} diff --git a/windows-agent/internal/netwatcher/netwatchertestutils/net_adapters_mock_api.go b/windows-agent/internal/netwatcher/netwatchertestutils/net_adapters_mock_api.go new file mode 100644 index 000000000..e1d61de5b --- /dev/null +++ b/windows-agent/internal/netwatcher/netwatchertestutils/net_adapters_mock_api.go @@ -0,0 +1,72 @@ +package netwatchertestutils + +import ( + "errors" + "fmt" + "sync/atomic" + + "github.com/canonical/ubuntu-pro-for-wsl/windows-agent/internal/netwatcher" + "golang.org/x/exp/maps" // When migrate to Go 1.23 use "maps" instead. +) + +// Implements the NetworkAdapterRepository interface for testing purposes. +type NetAdapterMockAPI struct { + Before, After map[string]string + ListDevicesError error + + GetDeviceConnectionNameError error + WaitForDeviceChangesImpl func() error + + listDevicesCalledFirstTime atomic.Bool +} + +// Close releases the resources associated with this object and cancels any outstanding wait operation. +func (m *NetAdapterMockAPI) Close() {} + +// ListAdapterGUIDs returns the GUIDs of the network adapters on the host. +func (m *NetAdapterMockAPI) ListDevices() ([]string, error) { + if m.ListDevicesError != nil { + return nil, m.ListDevicesError + } + if !m.listDevicesCalledFirstTime.Load() { + m.listDevicesCalledFirstTime.Store(true) + return maps.Keys(m.Before), nil + } + return maps.Keys(m.After), nil +} + +// ConnectionName returns the connection name of the network adapter with the given GUID. +func (m *NetAdapterMockAPI) GetDeviceConnectionName(guid string) (string, error) { + if m.GetDeviceConnectionNameError != nil { + return "", m.GetDeviceConnectionNameError + } + + if m.Before == nil || m.After == nil { + return "", errors.New("not implemented") + } + if !m.listDevicesCalledFirstTime.Load() { + if name, ok := m.Before[guid]; ok { + return name, nil + } + return "", fmt.Errorf("device %s not found", guid) + } + + if name, ok := m.After[guid]; ok { + return name, nil + } + return "", fmt.Errorf("device %s not found", guid) +} + +// WaitForChanges blocks the caller until the system triggers a notification of changes to the network adapters. +func (m *NetAdapterMockAPI) WaitForDeviceChanges() error { + if m.WaitForDeviceChangesImpl != nil { + return m.WaitForDeviceChangesImpl() + } + + return errors.New("not implemented") +} + +// NewNetAdapterRepositoryMock creates a new instance of the netAdapterRepositoryMock. +func newNetAdapterMockAPI() (netwatcher.NetAdaptersAPI, error) { + return &NetAdapterMockAPI{}, nil +} diff --git a/windows-agent/internal/netwatcher/netwatchertestutils/netwatchertestutils.go b/windows-agent/internal/netwatcher/netwatchertestutils/netwatchertestutils.go new file mode 100644 index 000000000..eb4bde8de --- /dev/null +++ b/windows-agent/internal/netwatcher/netwatchertestutils/netwatchertestutils.go @@ -0,0 +1,24 @@ +// Package netwatchertestutils exports test helpers to be used in other packages that need to change internal behaviors of the netwatcher. +package netwatchertestutils + +import ( + //nolint:revive,nolintlint // needed for go:linkname, but only used in tests. nolintlint as false positive then. + _ "unsafe" + + "github.com/canonical/ubuntu-pro-for-wsl/common/testdetection" + "github.com/canonical/ubuntu-pro-for-wsl/windows-agent/internal/netwatcher" +) + +var ( + //go:linkname defaultOptions github.com/canonical/ubuntu-pro-for-wsl/windows-agent/internal/netwatcher.defaultOptions + defaultOptions struct { + p netwatcher.NetAdaptersAPIProvider + } +) + +// DefaultNetworkDetectionToMock sets the default options for the daemon package with mocks for success of upper level tests. +func DefaultNetAdaptersAPIProviderToMock() { + testdetection.MustBeTesting() + + defaultOptions.p = newNetAdapterMockAPI +}