From 0b14c802fb34fef037151749cfb741d5cf76feeb Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 12 Dec 2023 08:09:51 -0800 Subject: [PATCH] webhooks: add broadcast to hook --- webhooks/webhooks.go | 42 ++++++++++++++ webhooks/webhooks_test.go | 116 ++++++++++++++++++++++++++++++-------- 2 files changed, 136 insertions(+), 22 deletions(-) diff --git a/webhooks/webhooks.go b/webhooks/webhooks.go index 531ded93..cddb6b47 100644 --- a/webhooks/webhooks.go +++ b/webhooks/webhooks.go @@ -255,6 +255,48 @@ func sendEventData(ctx context.Context, hook WebHook, buf []byte) error { return nil } +// BroadcastToWebhook sends an event to a specific WebHook subscriber. +func (m *Manager) BroadcastToWebhook(hookID int64, event string, scope string, data any) error { + done, err := m.tg.Add() + if err != nil { + return err + } + defer done() + + uid := UID(frand.Bytes(32)) + e := Event{ + ID: uid, + Event: event, + Scope: scope, + Data: data, + } + + buf, err := json.Marshal(e) + if err != nil { + return fmt.Errorf("failed to marshal event: %w", err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + hook, ok := m.hooks[hookID] + if !ok { + return fmt.Errorf("webhook not found") + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + log := m.log.With(zap.Int64("hook", hook.ID), zap.String("url", hook.CallbackURL), zap.String("scope", scope), zap.String("event", event)) + + start := time.Now() + if err := sendEventData(ctx, hook, buf); err != nil { + return fmt.Errorf("failed to send webhook event: %w", err) + } + log.Debug("sent webhook event", zap.Duration("elapsed", time.Since(start))) + return nil +} + // BroadcastEvent sends an event to all registered WebHooks that match the // event's scope. func (m *Manager) BroadcastEvent(event string, scope string, data any) error { diff --git a/webhooks/webhooks_test.go b/webhooks/webhooks_test.go index f30ca37b..32e37721 100644 --- a/webhooks/webhooks_test.go +++ b/webhooks/webhooks_test.go @@ -20,34 +20,23 @@ type jsonEvent struct { Event string `json:"event"` Scope string `json:"scope"` Data json.RawMessage `json:"data"` + Error error `json:"-"` } -func TestWebHooks(t *testing.T) { - log := zaptest.NewLogger(t) - - db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "hostd.db"), log.Named("sqlite")) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - wr, err := webhooks.NewManager(db, log.Named("webhooks")) - if err != nil { - t.Fatal(err) - } - +func registerWebhook(t testing.TB, wr *webhooks.Manager, scopes []string) (webhooks.WebHook, <-chan jsonEvent, error) { // create a listener for the webhook l, err := net.Listen("tcp", ":0") if err != nil { - t.Fatal(err) + return webhooks.WebHook{}, nil, fmt.Errorf("failed to create listener: %w", err) } - defer l.Close() + t.Cleanup(func() { + l.Close() + }) // add a webhook - scopes := []string{"tld", "scope/subscope"} hook, err := wr.RegisterWebHook("http://"+l.Addr().String(), scopes) if err != nil { - t.Fatal(err) + return webhooks.WebHook{}, nil, fmt.Errorf("failed to register webhook: %w", err) } // create an http server to listen for the webhook @@ -56,25 +45,51 @@ func TestWebHooks(t *testing.T) { http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, password, ok := r.BasicAuth() if !ok || password != hook.SecretKey { - t.Error("bad auth") + w.WriteHeader(http.StatusUnauthorized) + recv <- jsonEvent{Error: errors.New("bad auth")} + return } // handle the webhook var event jsonEvent if err := json.NewDecoder(r.Body).Decode(&event); err != nil { - t.Error(err) + w.WriteHeader(http.StatusBadRequest) + recv <- jsonEvent{Error: fmt.Errorf("failed to decode webhook: %w", err)} + return } w.WriteHeader(http.StatusNoContent) recv <- event })) }() + return hook, recv, nil +} + +func TestWebHooks(t *testing.T) { + log := zaptest.NewLogger(t) + + db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "hostd.db"), log.Named("sqlite")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + wr, err := webhooks.NewManager(db, log.Named("webhooks")) + if err != nil { + t.Fatal(err) + } + + // add a webhook + hook, hook1Ch, err := registerWebhook(t, wr, []string{"tld", "scope/subscope"}) + if err != nil { + t.Fatal(err) + } checkEvent := func(event, scope, data string) error { select { case <-time.After(time.Second): return errors.New("timed out") - case ev := <-recv: + case ev := <-hook1Ch: switch { case ev.Event != event: return fmt.Errorf("expected event %q, got %q", event, ev.Event) @@ -111,7 +126,7 @@ func TestWebHooks(t *testing.T) { } // update the webhook to have the "all scope" - hook, err = wr.UpdateWebHook(hook.ID, "http://"+l.Addr().String(), []string{"all"}) + hook, err = wr.UpdateWebHook(hook.ID, hook.CallbackURL, []string{"all"}) if err != nil { t.Fatal(err) } else if hooks, err := wr.WebHooks(); err != nil { @@ -147,3 +162,60 @@ func TestWebHooks(t *testing.T) { } } } + +func TestBroadcastToWebhook(t *testing.T) { + log := zaptest.NewLogger(t) + + db, err := sqlite.OpenDatabase(filepath.Join(t.TempDir(), "hostd.db"), log.Named("sqlite")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + wr, err := webhooks.NewManager(db, log.Named("webhooks")) + if err != nil { + t.Fatal(err) + } + + checkEvent := func(recv <-chan jsonEvent, event, scope, data string) error { + select { + case <-time.After(time.Second): + return errors.New("timed out") + case ev := <-recv: + switch { + case ev.Event != event: + return fmt.Errorf("expected event %q, got %q", event, ev.Event) + case ev.Scope != scope: + return fmt.Errorf("expected scope %q, got %q", scope, ev.Scope) + case string(ev.Data) != data: + return fmt.Errorf("expected data %q, got %q", data, ev.Data) + } + } + return nil + } + + hook1, hook1Ch, err := registerWebhook(t, wr, []string{"all"}) + if err != nil { + t.Fatal(err) + } + + _, hook2Ch, err := registerWebhook(t, wr, []string{"all"}) + if err != nil { + t.Fatal(err) + } + + // broadcast to hook1 + if err := wr.BroadcastToWebhook(hook1.ID, "test", "test", "hello, world!"); err != nil { + t.Fatal(err) + } + + // check that hook 2 did not receive the event + if err := checkEvent(hook2Ch, "test", "test", `"hello, world!"`); err == nil { + t.Fatal("expected no event") + } + + // check that hook 1 did receive the event + if err := checkEvent(hook1Ch, "test", "test", `"hello, world!"`); err != nil { + t.Fatal(err) + } +}