diff --git a/bus/client/webhooks.go b/bus/client/webhooks.go index 769d1cf57..833c7c7e1 100644 --- a/bus/client/webhooks.go +++ b/bus/client/webhooks.go @@ -13,16 +13,12 @@ func (c *Client) BroadcastAction(ctx context.Context, action webhooks.Event) err return err } -// DeleteWebhook deletes the webhook with the given ID. -func (c *Client) DeleteWebhook(ctx context.Context, url, module, event string) error { - return c.c.POST("/webhook/delete", webhooks.Webhook{ - URL: url, - Module: module, - Event: event, - }, nil) +// UnregisterWebhook unregisters the given webhook. +func (c *Client) UnregisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { + return c.c.POST("/webhook/delete", webhook, nil) } -// RegisterWebhook registers a new webhook for the given URL. +// RegisterWebhook registers the given webhook. func (c *Client) RegisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { err := c.c.WithContext(ctx).POST("/webhooks", webhook, nil) return err diff --git a/internal/worker/cache.go b/internal/worker/cache.go index f0d59efe4..dc8ae6357 100644 --- a/internal/worker/cache.go +++ b/internal/worker/cache.go @@ -75,14 +75,13 @@ type ( Bus interface { Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) GougingParams(ctx context.Context) (api.GougingParams, error) - RegisterWebhook(ctx context.Context, wh webhooks.Webhook) error } WorkerCache interface { DownloadContracts(ctx context.Context) ([]api.ContractMetadata, error) GougingParams(ctx context.Context) (api.GougingParams, error) HandleEvent(event webhooks.Event) error - Initialize(ctx context.Context, workerAPI string, opts ...webhooks.HeaderOption) error + Subscribe(e EventSubscriber) error } ) @@ -92,8 +91,8 @@ type cache struct { cache *memoryCache logger *zap.SugaredLogger - mu sync.Mutex - ready bool + mu sync.Mutex + readyChan chan struct{} } func NewCache(b Bus, logger *zap.Logger) WorkerCache { @@ -197,33 +196,27 @@ func (c *cache) HandleEvent(event webhooks.Event) (err error) { return } -func (c *cache) Initialize(ctx context.Context, workerAPI string, webhookOpts ...webhooks.HeaderOption) error { - eventsURL := fmt.Sprintf("%s/events", workerAPI) - headers := make(map[string]string) - for _, opt := range webhookOpts { - opt(headers) +func (c *cache) Subscribe(e EventSubscriber) (err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.readyChan != nil { + return fmt.Errorf("already subscribed") } - for _, wh := range []webhooks.Webhook{ - api.WebhookConsensusUpdate(eventsURL, headers), - api.WebhookContractArchive(eventsURL, headers), - api.WebhookContractRenew(eventsURL, headers), - api.WebhookHostUpdate(eventsURL, headers), - api.WebhookSettingUpdate(eventsURL, headers), - } { - if err := c.b.RegisterWebhook(ctx, wh); err != nil { - return fmt.Errorf("failed to register webhook '%s', err: %v", wh, err) - } + + c.readyChan, err = e.AddEventHandler(c.logger.Desugar().Name(), c) + if err != nil { + return fmt.Errorf("failed to subscribe the worker cache, error: %v", err) } - c.mu.Lock() - c.ready = true - c.mu.Unlock() return nil } func (c *cache) isReady() bool { - c.mu.Lock() - defer c.mu.Unlock() - return c.ready + select { + case <-c.readyChan: + return true + default: + } + return false } func (c *cache) handleConsensusUpdate(event api.EventConsensusUpdate) { diff --git a/internal/worker/cache_test.go b/internal/worker/cache_test.go index 1f3261d56..9bc8d682d 100644 --- a/internal/worker/cache_test.go +++ b/internal/worker/cache_test.go @@ -26,7 +26,22 @@ func (m *mockBus) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api. func (m *mockBus) GougingParams(ctx context.Context) (api.GougingParams, error) { return m.gougingParams, nil } -func (m *mockBus) RegisterWebhook(ctx context.Context, wh webhooks.Webhook) error { + +type mockEventSubscriber struct { + readyChan chan struct{} +} + +func (m *mockEventSubscriber) AddEventHandler(id string, h EventHandler) (chan struct{}, error) { + return m.readyChan, nil +} + +func (m *mockEventSubscriber) ProcessEvent(event webhooks.Event) {} + +func (m *mockEventSubscriber) Register(ctx context.Context, eventURL string, opts ...webhooks.HeaderOption) error { + return nil +} + +func (m *mockEventSubscriber) Shutdown(ctx context.Context) error { return nil } @@ -57,7 +72,13 @@ func TestWorkerCache(t *testing.T) { // create mock bus and cache c, b, mc := newTestCache(zap.New(observedZapCore)) - // assert using cache before it's initialized prints a warning + // create mock event subscriber + m := &mockEventSubscriber{readyChan: make(chan struct{})} + + // subscribe cache to event subscriber + c.Subscribe(m) + + // assert using cache before it's ready prints a warning contracts, err := c.DownloadContracts(context.Background()) if err != nil { t.Fatal(err) @@ -84,10 +105,8 @@ func TestWorkerCache(t *testing.T) { t.Fatal("expected error message to contain 'cache is not ready yet', got", lines[0].Message) } - // initialize the cache - if err := c.Initialize(context.Background(), ""); err != nil { - t.Fatal(err) - } + // close the ready channel + close(m.readyChan) // fetch contracts & gouging params so they're cached _, err = c.DownloadContracts(context.Background()) diff --git a/internal/worker/events.go b/internal/worker/events.go new file mode 100644 index 000000000..eb98018d3 --- /dev/null +++ b/internal/worker/events.go @@ -0,0 +1,194 @@ +package worker + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "go.sia.tech/renterd/alerts" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/webhooks" + "go.uber.org/zap" +) + +var ( + alertWebhookRegistrationFailedID = alerts.RandomAlertID() // constant until restarted +) + +type ( + EventSubscriber interface { + AddEventHandler(id string, h EventHandler) (chan struct{}, error) + ProcessEvent(event webhooks.Event) + Register(ctx context.Context, eventURL string, opts ...webhooks.HeaderOption) error + Shutdown(context.Context) error + } + + EventHandler interface { + HandleEvent(event webhooks.Event) error + Subscribe(e EventSubscriber) error + } + + WebhookManager interface { + RegisterWebhook(ctx context.Context, wh webhooks.Webhook) error + UnregisterWebhook(ctx context.Context, wh webhooks.Webhook) error + } +) + +type ( + eventSubscriber struct { + alerts alerts.Alerter + webhooks WebhookManager + logger *zap.SugaredLogger + + registerInterval time.Duration + + mu sync.Mutex + handlers map[string]EventHandler + registered []webhooks.Webhook + registeredChan chan struct{} + } +) + +func NewEventSubscriber(a alerts.Alerter, w WebhookManager, l *zap.Logger, registerInterval time.Duration) EventSubscriber { + return &eventSubscriber{ + alerts: a, + webhooks: w, + logger: l.Sugar().Named("events"), + + registeredChan: make(chan struct{}), + + handlers: make(map[string]EventHandler), + registerInterval: registerInterval, + } +} + +func (e *eventSubscriber) AddEventHandler(id string, h EventHandler) (chan struct{}, error) { + e.mu.Lock() + defer e.mu.Unlock() + _, ok := e.handlers[id] + if ok { + return nil, fmt.Errorf("subscriber with id %v already exists", id) + } + e.handlers[id] = h + + return e.registeredChan, nil +} + +func (e *eventSubscriber) ProcessEvent(event webhooks.Event) { + log := e.logger.With( + zap.String("module", event.Module), + zap.String("event", event.Event), + ) + + for id, s := range e.handlers { + if err := s.HandleEvent(event); err != nil { + log.Errorw("failed to handle event", + zap.Error(err), + zap.String("subscriber", id), + ) + } else { + log.Debugw("handled event", + zap.String("subscriber", id), + ) + } + } +} + +func (e *eventSubscriber) Register(ctx context.Context, eventsURL string, opts ...webhooks.HeaderOption) error { + select { + case <-e.registeredChan: + return fmt.Errorf("already registered") // developer error + default: + } + + // prepare headers + headers := make(map[string]string) + for _, opt := range opts { + opt(headers) + } + + // prepare webhooks + webhooks := []webhooks.Webhook{ + api.WebhookConsensusUpdate(eventsURL, headers), + api.WebhookContractArchive(eventsURL, headers), + api.WebhookContractRenew(eventsURL, headers), + api.WebhookHostUpdate(eventsURL, headers), + api.WebhookSettingUpdate(eventsURL, headers), + } + + // try and register the webhooks in a loop + for { + err := e.registerWebhooks(ctx, webhooks) + if err == nil { + e.alerts.DismissAlerts(ctx, alertWebhookRegistrationFailedID) + break + } + + // alert on failure + e.alerts.RegisterAlert(ctx, newWebhookRegistrationFailedAlert(err)) + e.logger.Warnf("failed to register webhooks, retrying in %v", e.registerInterval) + + // sleep for a bit before trying again + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(e.registerInterval): + } + } + + return nil +} + +func (e *eventSubscriber) Shutdown(ctx context.Context) error { + e.mu.Lock() + defer e.mu.Unlock() + + // unregister webhooks + var errs []error + for _, wh := range e.registered { + if err := e.webhooks.UnregisterWebhook(ctx, wh); err != nil { + e.logger.Errorw("failed to unregister webhook", + zap.Error(err), + zap.Stringer("webhook", wh), + ) + errs = append(errs, err) + } + } + + return errors.Join(errs...) +} + +func (e *eventSubscriber) registerWebhooks(ctx context.Context, webhooks []webhooks.Webhook) error { + for _, wh := range webhooks { + if err := e.webhooks.RegisterWebhook(ctx, wh); err != nil { + e.logger.Errorw("failed to register webhook", + zap.Error(err), + zap.Stringer("webhook", wh), + ) + return err + } + } + + // save webhooks so we can unregister them on shutdown + e.mu.Lock() + e.registered = webhooks + e.mu.Unlock() + + // signal that we're registered + close(e.registeredChan) + return nil +} + +func newWebhookRegistrationFailedAlert(err error) alerts.Alert { + return alerts.Alert{ + ID: alertWebhookRegistrationFailedID, + Severity: alerts.SeverityCritical, + Message: "Worker failed to register webhooks", + Data: map[string]any{ + "error": err.Error(), + }, + Timestamp: time.Now(), + } +} diff --git a/internal/worker/events_test.go b/internal/worker/events_test.go new file mode 100644 index 000000000..3de028acd --- /dev/null +++ b/internal/worker/events_test.go @@ -0,0 +1,219 @@ +package worker + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/jape" + "go.sia.tech/renterd/alerts" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/webhooks" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +const testRegisterInterval = 100 * time.Millisecond + +type mockAlerter struct{} + +func (a *mockAlerter) Alerts(ctx context.Context, opts alerts.AlertsOpts) (alerts.AlertsResponse, error) { + return alerts.AlertsResponse{}, nil +} +func (a *mockAlerter) RegisterAlert(ctx context.Context, alert alerts.Alert) error { return nil } +func (a *mockAlerter) DismissAlerts(ctx context.Context, ids ...types.Hash256) error { return nil } + +type mockEventHandler struct { + id string + readyChan chan struct{} + + mu sync.Mutex + events []webhooks.Event +} + +func (s *mockEventHandler) Events() []webhooks.Event { + s.mu.Lock() + defer s.mu.Unlock() + return s.events +} + +func (s *mockEventHandler) HandleEvent(event webhooks.Event) error { + s.mu.Lock() + defer s.mu.Unlock() + + select { + case <-s.readyChan: + default: + return fmt.Errorf("subscriber not ready") + } + + s.events = append(s.events, event) + return nil +} + +func (s *mockEventHandler) Subscribe(e EventSubscriber) error { + s.readyChan, _ = e.AddEventHandler(s.id, s) + return nil +} + +type mockWebhookManager struct { + blockChan chan struct{} + + mu sync.Mutex + registered []webhooks.Webhook +} + +func (m *mockWebhookManager) RegisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { + <-m.blockChan + + m.mu.Lock() + defer m.mu.Unlock() + m.registered = append(m.registered, webhook) + return nil +} + +func (m *mockWebhookManager) UnregisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { + m.mu.Lock() + defer m.mu.Unlock() + + for i, wh := range m.registered { + if wh.String() == webhook.String() { + m.registered = append(m.registered[:i], m.registered[i+1:]...) + return nil + } + } + return nil +} + +func (m *mockWebhookManager) Webhooks() []webhooks.Webhook { + m.mu.Lock() + defer m.mu.Unlock() + return m.registered +} + +func TestEventSubscriber(t *testing.T) { + // observe logs + observedZapCore, observedLogs := observer.New(zap.DebugLevel) + + // create mocks + a := &mockAlerter{} + w := &mockWebhookManager{blockChan: make(chan struct{})} + h := &mockEventHandler{id: t.Name()} + + // create event subscriber + s := NewEventSubscriber(a, w, zap.New(observedZapCore), testRegisterInterval) + + // subscribe the event handler + if err := h.Subscribe(s); err != nil { + t.Fatal(err) + } + + // setup a server + mux := jape.Mux(map[string]jape.Handler{"POST /events": func(jc jape.Context) { + var event webhooks.Event + if jc.Decode(&event) != nil { + return + } else if event.Event == webhooks.WebhookEventPing { + jc.ResponseWriter.WriteHeader(http.StatusOK) + return + } else { + s.ProcessEvent(event) + } + }}) + srv := httptest.NewServer(mux) + defer srv.Close() + + // register the subscriber + eventsURL := fmt.Sprintf("http://%v/events", srv.Listener.Addr().String()) + go func() { + if err := s.Register(context.Background(), eventsURL); err != nil { + t.Error(err) + } + }() + + // send an event before unblocking webhooks registration + err := sendEvent(eventsURL, webhooks.Event{Module: api.ModuleConsensus, Event: api.EventUpdate}) + if err != nil { + t.Fatal(err) + } + logs := observedLogs.TakeAll() + if len(logs) != 1 { + t.Fatal("expected 1 log, got", len(logs)) + } else if entry := logs[0]; entry.Message != "failed to handle event" || entry.ContextMap()["error"] != "subscriber not ready" { + t.Fatal("expected different log entry, got", entry) + } + + // unblock the webhooks registration + close(w.blockChan) + time.Sleep(testRegisterInterval) + + // assert webhook was registered + if webhooks := w.Webhooks(); len(webhooks) != 5 { + t.Fatal("expected 5 webhooks, got", len(webhooks)) + } + + // send the same event again + err = sendEvent(eventsURL, webhooks.Event{Module: api.ModuleConsensus, Event: api.EventUpdate}) + if err != nil { + t.Fatal(err) + } + logs = observedLogs.TakeAll() + if len(logs) != 1 { + t.Fatal("expected 1 log, got", len(logs)) + } else if entry := logs[0]; entry.Message != "handled event" || entry.ContextMap()["subscriber"] != t.Name() { + t.Fatal("expected different log entry, got", entry) + } + + // assert the subscriber handled the event + if events := h.Events(); len(events) != 1 { + t.Fatal("expected 1 event, got", len(events)) + } + + // shutdown event subscriber + err = s.Shutdown(context.Background()) + if err != nil { + t.Fatal(err) + } + + // assert webhook was unregistered + if webhooks := w.Webhooks(); len(webhooks) != 0 { + t.Fatal("expected 0 webhooks, got", len(webhooks)) + } +} + +func sendEvent(url string, event webhooks.Event) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + body, err := json.Marshal(event) + if err != nil { + return err + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return err + } + defer io.ReadAll(req.Body) // always drain body + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + errStr, err := io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + return fmt.Errorf("Webhook returned unexpected status %v: %v", resp.StatusCode, string(errStr)) + } + return nil +} diff --git a/worker/client/client.go b/worker/client/client.go index c1ab8a70e..9abac4d0e 100644 --- a/worker/client/client.go +++ b/worker/client/client.go @@ -269,8 +269,8 @@ func (c *Client) UploadStats() (resp api.UploadStatsResponse, err error) { return } -// RegisterEvent register an event. -func (c *Client) RegisterEvent(ctx context.Context, e webhooks.Event) (err error) { +// NotifyEvent notifies the worker of an event. +func (c *Client) NotifyEvent(ctx context.Context, e webhooks.Event) (err error) { err = c.c.WithContext(ctx).POST("/events", e, nil) return } diff --git a/worker/mocks_test.go b/worker/mocks_test.go index 192f4c169..4c3929205 100644 --- a/worker/mocks_test.go +++ b/worker/mocks_test.go @@ -697,3 +697,7 @@ type webhookStoreMock struct{} func (*webhookStoreMock) RegisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { return nil } + +func (*webhookStoreMock) UnregisterWebhook(ctx context.Context, webhook webhooks.Webhook) error { + return nil +} diff --git a/worker/worker.go b/worker/worker.go index a847af894..a6bfefbd4 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -157,6 +157,7 @@ type ( WebhookStore interface { RegisterWebhook(ctx context.Context, webhook webhooks.Webhook) error + UnregisterWebhook(ctx context.Context, webhook webhooks.Webhook) error } ConsensusState interface { @@ -208,6 +209,7 @@ type worker struct { masterKey [32]byte startTime time.Time + eventSubscriber iworker.EventSubscriber downloadManager *downloadManager uploadManager *uploadManager @@ -1234,25 +1236,6 @@ func (w *worker) idHandlerGET(jc jape.Context) { jc.Encode(w.id) } -func (w *worker) eventsHandler(jc jape.Context) { - var event webhooks.Event - if jc.Decode(&event) != nil { - return - } else if event.Event == webhooks.WebhookEventPing { - jc.ResponseWriter.WriteHeader(http.StatusOK) - return - } - - err := w.cache.HandleEvent(event) - if errors.Is(err, api.ErrUnknownEvent) { - jc.ResponseWriter.WriteHeader(http.StatusAccepted) - return - } else if err != nil { - jc.Error(err, http.StatusBadRequest) - return - } -} - func (w *worker) memoryGET(jc jape.Context) { jc.Encode(api.MemoryResponse{ Download: w.downloadManager.mm.Status(), @@ -1269,6 +1252,17 @@ func (w *worker) accountHandlerGET(jc jape.Context) { jc.Encode(account) } +func (w *worker) eventsHandlerPOST(jc jape.Context) { + var event webhooks.Event + if jc.Decode(&event) != nil { + return + } else if event.Event == webhooks.WebhookEventPing { + jc.ResponseWriter.WriteHeader(http.StatusOK) + } else { + w.eventSubscriber.ProcessEvent(event) + } +} + func (w *worker) stateHandlerGET(jc jape.Context) { jc.Encode(api.WorkerStateResponse{ ID: w.id, @@ -1304,14 +1298,15 @@ func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlush return nil, errors.New("uploadMaxMemory cannot be 0") } + a := alerts.WithOrigin(b, fmt.Sprintf("worker.%s", id)) l = l.Named("worker").Named(id) - cache := iworker.NewCache(b, l) shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) w := &worker{ - alerts: alerts.WithOrigin(b, fmt.Sprintf("worker.%s", id)), + alerts: a, allowPrivateIPs: allowPrivateIPs, contractLockingDuration: contractLockingDuration, - cache: cache, + cache: iworker.NewCache(b, l), + eventSubscriber: iworker.NewEventSubscriber(a, b, l, 10*time.Second), id: id, bus: b, masterKey: masterKey, @@ -1339,7 +1334,7 @@ func (w *worker) Handler() http.Handler { "GET /account/:hostkey": w.accountHandlerGET, "GET /id": w.idHandlerGET, - "POST /events": w.eventsHandler, + "POST /events": w.eventsHandlerPOST, "GET /memory": w.memoryGET, @@ -1369,10 +1364,17 @@ func (w *worker) Handler() http.Handler { }) } -// Setup initializes the worker cache. +// Setup register event webhooks that enable the worker cache. func (w *worker) Setup(ctx context.Context, apiURL, apiPassword string) error { - webhookOpts := []webhooks.HeaderOption{webhooks.WithBasicAuth("", apiPassword)} - return w.cache.Initialize(ctx, apiURL, webhookOpts...) + go func() { + eventsURL := fmt.Sprintf("%s/events", apiURL) + webhookOpts := []webhooks.HeaderOption{webhooks.WithBasicAuth("", apiPassword)} + if err := w.eventSubscriber.Register(w.shutdownCtx, eventsURL, webhookOpts...); err != nil { + w.logger.Errorw("failed to register webhooks", zap.Error(err)) + } + }() + + return w.cache.Subscribe(w.eventSubscriber) } // Shutdown shuts down the worker. @@ -1386,7 +1388,9 @@ func (w *worker) Shutdown(ctx context.Context) error { // stop recorders w.contractSpendingRecorder.Stop(ctx) - return nil + + // shutdown the subscriber + return w.eventSubscriber.Shutdown(ctx) } func (w *worker) scanHost(ctx context.Context, timeout time.Duration, hostKey types.PublicKey, hostIP string) (rhpv2.HostSettings, rhpv3.HostPriceTable, time.Duration, error) {