diff --git a/api/v1/config.go b/api/v1/config.go index 83a6b4b..8262454 100644 --- a/api/v1/config.go +++ b/api/v1/config.go @@ -27,6 +27,7 @@ type AzureConfig struct { } type ClerkConfig struct { - SecretKey string `json:"secretKey" yaml:"secretKey"` - JWKSURL string `json:"jwks_url" yaml:"jwks_url"` + SecretKey string `json:"secretKey" yaml:"secretKey"` + JWKSURL string `json:"jwks_url" yaml:"jwks_url"` + WebhookSecret string `json:"webhook_secret" yaml:"webhook_secret"` } diff --git a/cmd/serve.go b/cmd/serve.go index cd52529..783b03b 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -49,6 +49,12 @@ func serve(configFile string) { echopprof.Wrap(e) } + var err error + tenant.ClerkTenantWebhook, err = tenant.NewWebhook(config.Config.Clerk.WebhookSecret) + if err != nil { + log.Fatalf("Error setting up webhook: %v", err) + } + e.GET("/health", func(c echo.Context) error { return c.JSON(200, map[string]string{"message": "ok"}) }) e.POST("/tenant", tenant.CreateTenant) diff --git a/examples/configs/config.yaml b/examples/configs/config.yaml index 94eb202..1a42118 100644 --- a/examples/configs/config.yaml +++ b/examples/configs/config.yaml @@ -7,6 +7,7 @@ azure: vault_uri: clerk: jwks_url: + webhook_secret: git: repository: https://github.com/flanksource/sass-dev user: flankbot diff --git a/pkg/tenant/controllers.go b/pkg/tenant/controllers.go index 2b18044..864aa6a 100644 --- a/pkg/tenant/controllers.go +++ b/pkg/tenant/controllers.go @@ -1,29 +1,42 @@ package tenant import ( + "encoding/json" "errors" "fmt" + "io" "net/http" - "github.com/flanksource/commons/logger" v1 "github.com/flanksource/tenant-controller/api/v1" "github.com/flanksource/tenant-controller/pkg/git" "github.com/flanksource/tenant-controller/pkg/secrets" "github.com/labstack/echo/v4" ) +var ClerkTenantWebhook *Webhook + func CreateTenant(c echo.Context) error { if c.Request().Body == nil { return errorResonse(c, errors.New("missing request body"), http.StatusBadRequest) } defer c.Request().Body.Close() + body, err := io.ReadAll(c.Request().Body) + if err != nil { + errorResonse(c, err, http.StatusBadRequest) + } + var reqBody v1.TenantRequestBody - if err := c.Bind(&reqBody); err != nil { - logger.Infof("Broken %v", err) + if err := json.Unmarshal(body, &reqBody); err != nil { return errorResonse(c, err, http.StatusBadRequest) } + // Ignoring timestamp since the tolerance is 5mins + // How to replay older message for whom tenant creation failed ? + if err := ClerkTenantWebhook.VerifyIgnoringTimestamp(body, c.Request().Header); err != nil { + return errorResonse(c, fmt.Errorf("webhook verification failed: %w", err), http.StatusBadRequest) + } + tenant, err := NewTenant(reqBody) if err != nil { return errorResonse(c, err, http.StatusBadRequest) diff --git a/pkg/tenant/webhook_verify.go b/pkg/tenant/webhook_verify.go new file mode 100644 index 0000000..ad92edb --- /dev/null +++ b/pkg/tenant/webhook_verify.go @@ -0,0 +1,151 @@ +package tenant + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +var base64enc = base64.StdEncoding + +type Webhook struct { + key []byte +} + +const webhookSecretPrefix = "whsec_" + +var tolerance time.Duration = 5 * time.Minute + +var ( + errRequiredHeaders = fmt.Errorf("Missing Required Headers") + errInvalidHeaders = fmt.Errorf("Invalid Signature Headers") + errNoMatchingSignature = fmt.Errorf("No matching signature found") + errMessageTooOld = fmt.Errorf("Message timestamp too old") + errMessageTooNew = fmt.Errorf("Message timestamp too new") +) + +func NewWebhook(secret string) (*Webhook, error) { + key, err := base64enc.DecodeString(strings.TrimPrefix(secret, webhookSecretPrefix)) + if err != nil { + return nil, err + } + return &Webhook{ + key: key, + }, nil +} + +func NewWebhookRaw(secret []byte) (*Webhook, error) { + return &Webhook{ + key: secret, + }, nil +} + +// Verify validates the payload against the svix signature headers +// using the webhooks signing secret. +// +// Returns an error if the body or headers are missing/unreadable +// or if the signature doesn't match. +func (wh *Webhook) Verify(payload []byte, headers http.Header) error { + return wh.verify(payload, headers, true) +} + +// VerifyIgnoringTimestamp validates the payload against the svix signature headers +// using the webhooks signing secret. +// +// Returns an error if the body or headers are missing/unreadable +// or if the signature doesn't match. +// +// WARNING: This function does not check the signature's timestamp. +// We recommend using the `Verify` function instead. +func (wh *Webhook) VerifyIgnoringTimestamp(payload []byte, headers http.Header) error { + return wh.verify(payload, headers, false) +} + +func (wh *Webhook) verify(payload []byte, headers http.Header, enforceTolerance bool) error { + msgId := headers.Get("svix-id") + msgSignature := headers.Get("svix-signature") + msgTimestamp := headers.Get("svix-timestamp") + + if msgId == "" || msgSignature == "" || msgTimestamp == "" { + msgId = headers.Get("webhook-id") + msgSignature = headers.Get("webhook-signature") + msgTimestamp = headers.Get("webhook-timestamp") + if msgId == "" || msgSignature == "" || msgTimestamp == "" { + return errRequiredHeaders + } + } + + timestamp, err := parseTimestampHeader(msgTimestamp) + if err != nil { + return err + } + + if enforceTolerance { + if err := verifyTimestamp(timestamp); err != nil { + return err + } + } + + computedSignature, err := wh.Sign(msgId, timestamp, payload) + if err != nil { + return err + } + expectedSignature := []byte(strings.Split(computedSignature, ",")[1]) + + passedSignatures := strings.Split(msgSignature, " ") + for _, versionedSignature := range passedSignatures { + sigParts := strings.Split(versionedSignature, ",") + if len(sigParts) < 2 { + continue + } + version := sigParts[0] + signature := []byte(sigParts[1]) + + if version != "v1" { + continue + } + + if hmac.Equal(signature, expectedSignature) { + return nil + } + } + return errNoMatchingSignature +} + +func (wh *Webhook) Sign(msgId string, timestamp time.Time, payload []byte) (string, error) { + toSign := fmt.Sprintf("%s.%d.%s", msgId, timestamp.Unix(), payload) + + h := hmac.New(sha256.New, wh.key) + h.Write([]byte(toSign)) + sig := make([]byte, base64enc.EncodedLen(h.Size())) + base64enc.Encode(sig, h.Sum(nil)) + return fmt.Sprintf("v1,%s", sig), nil + +} + +func parseTimestampHeader(timestampHeader string) (time.Time, error) { + timeInt, err := strconv.ParseInt(timestampHeader, 10, 64) + if err != nil { + return time.Time{}, errInvalidHeaders + } + timestamp := time.Unix(timeInt, 0) + return timestamp, nil +} + +func verifyTimestamp(timestamp time.Time) error { + now := time.Now() + + if now.Sub(timestamp) > tolerance { + return errMessageTooOld + } + if timestamp.Unix() > now.Add(tolerance).Unix() { + return errMessageTooNew + } + + return nil +}