diff --git a/api/api.go b/api/api.go index 9997c714..f06be990 100644 --- a/api/api.go +++ b/api/api.go @@ -120,7 +120,7 @@ type ( RegisterWebHook(callbackURL string, scopes []string) (webhooks.WebHook, error) UpdateWebHook(id int64, callbackURL string, scopes []string) (webhooks.WebHook, error) RemoveWebHook(id int64) error - BroadcastEvent(event string, scope string, data any) error + BroadcastToWebhook(id int64, event, scope string, data interface{}) error } // A RHPSessionReporter reports on RHP session lifecycle events diff --git a/api/endpoints.go b/api/endpoints.go index 672eb6ea..2605522c 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -567,7 +567,12 @@ func (a *api) handlePUTWebhooks(c jape.Context) { } func (a *api) handlePOSTWebhooksTest(c jape.Context) { - if err := a.webhooks.BroadcastEvent("test", webhooks.ScopeTest, nil); err != nil { + var id int64 + if err := c.DecodeParam("id", &id); err != nil { + return + } + + if err := a.webhooks.BroadcastToWebhook(id, "test", webhooks.ScopeTest, nil); err != nil { c.Error(err, http.StatusInternalServerError) return } diff --git a/go.mod b/go.mod index 0c8d4dac..1e3632da 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.17 gitlab.com/NebulousLabs/bolt v1.4.4 gitlab.com/NebulousLabs/encoding v0.0.0-20200604091946-456c3dc907fe - go.sia.tech/core v0.1.12-0.20231209150840-62eed6d74fd4 + go.sia.tech/core v0.1.12-0.20231211182757-77190f04f90b go.sia.tech/jape v0.10.0 go.sia.tech/renterd v0.6.0 go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca diff --git a/go.sum b/go.sum index ef4ff7fb..eb98f946 100644 --- a/go.sum +++ b/go.sum @@ -226,6 +226,8 @@ gitlab.com/NebulousLabs/writeaheadlog v0.0.0-20200618142844-c59a90f49130/go.mod go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.sia.tech/core v0.1.12-0.20231209150840-62eed6d74fd4 h1:fKgxWoT/Mo4rPWRvymyNyh6YEBKU22usu9hrzWLK7Tc= go.sia.tech/core v0.1.12-0.20231209150840-62eed6d74fd4/go.mod h1:3EoY+rR78w1/uGoXXVqcYdwSjSJKuEMI5bL7WROA27Q= +go.sia.tech/core v0.1.12-0.20231211182757-77190f04f90b h1:xJSxYN2kZD3NAijHIwjXhG5+7GoPyjDNIJPEoD3b72g= +go.sia.tech/core v0.1.12-0.20231211182757-77190f04f90b/go.mod h1:3EoY+rR78w1/uGoXXVqcYdwSjSJKuEMI5bL7WROA27Q= go.sia.tech/jape v0.10.0 h1:wsIURirNV29fvqxhvvbd0yhKh+9JeNZvz4haJUL/+yI= go.sia.tech/jape v0.10.0/go.mod h1:4QqmBB+t3W7cNplXPj++ZqpoUb2PeiS66RLpXmEGap4= go.sia.tech/mux v1.2.0 h1:ofa1Us9mdymBbGMY2XH/lSpY8itFsKIo/Aq8zwe+GHU= diff --git a/webhooks/webhooks.go b/webhooks/webhooks.go index 531ded93..cddb6b47 100644 --- a/webhooks/webhooks.go +++ b/webhooks/webhooks.go @@ -255,6 +255,48 @@ func sendEventData(ctx context.Context, hook WebHook, buf []byte) error { return nil } +// BroadcastToWebhook sends an event to a specific WebHook subscriber. +func (m *Manager) BroadcastToWebhook(hookID int64, event string, scope string, data any) error { + done, err := m.tg.Add() + if err != nil { + return err + } + defer done() + + uid := UID(frand.Bytes(32)) + e := Event{ + ID: uid, + Event: event, + Scope: scope, + Data: data, + } + + buf, err := json.Marshal(e) + if err != nil { + return fmt.Errorf("failed to marshal event: %w", err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + hook, ok := m.hooks[hookID] + if !ok { + return fmt.Errorf("webhook not found") + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + log := m.log.With(zap.Int64("hook", hook.ID), zap.String("url", hook.CallbackURL), zap.String("scope", scope), zap.String("event", event)) + + start := time.Now() + if err := sendEventData(ctx, hook, buf); err != nil { + return fmt.Errorf("failed to send webhook event: %w", err) + } + log.Debug("sent webhook event", zap.Duration("elapsed", time.Since(start))) + return nil +} + // BroadcastEvent sends an event to all registered WebHooks that match the // event's scope. func (m *Manager) BroadcastEvent(event string, scope string, data any) error { diff --git a/webhooks/webhooks_test.go b/webhooks/webhooks_test.go index f30ca37b..32e37721 100644 --- a/webhooks/webhooks_test.go +++ b/webhooks/webhooks_test.go @@ -20,34 +20,23 @@ type jsonEvent struct { Event string `json:"event"` Scope string `json:"scope"` Data json.RawMessage `json:"data"` + Error error `json:"-"` } -func TestWebHooks(t *testing.T) { - log := zaptest.NewLogger(t) - - db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "hostd.db"), log.Named("sqlite")) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - wr, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - +func registerWebhook(t testing.TB, wr *webhooks.Manager, scopes []string) (webhooks.WebHook, <-chan jsonEvent, error) { // create a listener for the webhook l, err := net.Listen("tcp", ":0") if err != nil { - t.Fatal(err) + return webhooks.WebHook{}, nil, fmt.Errorf("failed to create listener: %w", err) } - defer l.Close() + t.Cleanup(func() { + l.Close() + }) // add a webhook - scopes := []string{"tld", "scope/subscope"} hook, err := wr.RegisterWebHook("http://"+l.Addr().String(), scopes) if err != nil { - t.Fatal(err) + return webhooks.WebHook{}, nil, fmt.Errorf("failed to register webhook: %w", err) } // create an http server to listen for the webhook @@ -56,25 +45,51 @@ func TestWebHooks(t *testing.T) { http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, password, ok := r.BasicAuth() if !ok || password != hook.SecretKey { - t.Error("bad auth") + w.WriteHeader(http.StatusUnauthorized) + recv <- jsonEvent{Error: errors.New("bad auth")} + return } // handle the webhook var event jsonEvent if err := json.NewDecoder(r.Body).Decode(&event); err != nil { - t.Error(err) + w.WriteHeader(http.StatusBadRequest) + recv <- jsonEvent{Error: fmt.Errorf("failed to decode webhook: %w", err)} + return } w.WriteHeader(http.StatusNoContent) recv <- event })) }() + return hook, recv, nil +} + +func TestWebHooks(t *testing.T) { + log := zaptest.NewLogger(t) + + db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "hostd.db"), log.Named("sqlite")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + wr, err := webhooks.NewManager(db, log.Named("webhooks")) + if err != nil { + t.Fatal(err) + } + + // add a webhook + hook, hook1Ch, err := registerWebhook(t, wr, []string{"tld", "scope/subscope"}) + if err != nil { + t.Fatal(err) + } checkEvent := func(event, scope, data string) error { select { case <-time.After(time.Second): return errors.New("timed out") - case ev := <-recv: + case ev := <-hook1Ch: switch { case ev.Event != event: return fmt.Errorf("expected event %q, got %q", event, ev.Event) @@ -111,7 +126,7 @@ func TestWebHooks(t *testing.T) { } // update the webhook to have the "all scope" - hook, err = wr.UpdateWebHook(hook.ID, "http://"+l.Addr().String(), []string{"all"}) + hook, err = wr.UpdateWebHook(hook.ID, hook.CallbackURL, []string{"all"}) if err != nil { t.Fatal(err) } else if hooks, err := wr.WebHooks(); err != nil { @@ -147,3 +162,60 @@ func TestWebHooks(t *testing.T) { } } } + +func TestBroadcastToWebhook(t *testing.T) { + log := zaptest.NewLogger(t) + + db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "hostd.db"), log.Named("sqlite")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + wr, err := webhooks.NewManager(db, log.Named("webhooks")) + if err != nil { + t.Fatal(err) + } + + checkEvent := func(recv <-chan jsonEvent, event, scope, data string) error { + select { + case <-time.After(time.Second): + return errors.New("timed out") + case ev := <-recv: + switch { + case ev.Event != event: + return fmt.Errorf("expected event %q, got %q", event, ev.Event) + case ev.Scope != scope: + return fmt.Errorf("expected scope %q, got %q", scope, ev.Scope) + case string(ev.Data) != data: + return fmt.Errorf("expected data %q, got %q", data, ev.Data) + } + } + return nil + } + + hook1, hook1Ch, err := registerWebhook(t, wr, []string{"all"}) + if err != nil { + t.Fatal(err) + } + + _, hook2Ch, err := registerWebhook(t, wr, []string{"all"}) + if err != nil { + t.Fatal(err) + } + + // broadcast to hook1 + if err := wr.BroadcastToWebhook(hook1.ID, "test", "test", "hello, world!"); err != nil { + t.Fatal(err) + } + + // check that hook 2 did not receive the event + if err := checkEvent(hook2Ch, "test", "test", `"hello, world!"`); err == nil { + t.Fatal("expected no event") + } + + // check that hook 1 did receive the event + if err := checkEvent(hook1Ch, "test", "test", `"hello, world!"`); err != nil { + t.Fatal(err) + } +}