From c97b092b875c7384e29c352a90129757b6513061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20D=C3=B6ll?= Date: Sat, 4 May 2024 08:25:13 +0000 Subject: [PATCH] fix: configure auth checker --- openapi.go | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/openapi.go b/openapi.go index 877cb45..eccdcc8 100644 --- a/openapi.go +++ b/openapi.go @@ -10,6 +10,44 @@ import ( middleware "github.com/oapi-codegen/fiber-middleware" ) +// OpenAPIAuthenticatorOpts are the OpenAPI authenticator options. +type OpenAPIAuthenticatorOpts struct { + PathParam string + Checker AuthzChecker +} + +// Conigure the OpenAPI authenticator. +func (o *OpenAPIAuthenticatorOpts) Conigure(opts ...OpenAPIAuthenticatorOpt) { + for _, opt := range opts { + opt(o) + } +} + +// OpenAPIAuthenticatorOpt is a function that sets an option on the OpenAPI authenticator. +type OpenAPIAuthenticatorOpt func(*OpenAPIAuthenticatorOpts) + +// OpenAPIAuthenticatorDefaultOpts are the default OpenAPI authenticator options. +func OpenAPIAuthenticatorDefaultOpts() OpenAPIAuthenticatorOpts { + return OpenAPIAuthenticatorOpts{ + PathParam: "teamId", + Checker: NewNoop(), + } +} + +// WithPathParam sets the path parameter. +func WithPathParam(param string) OpenAPIAuthenticatorOpt { + return func(opts *OpenAPIAuthenticatorOpts) { + opts.PathParam = param + } +} + +// WithChecker sets the authz checker. +func WithChecker(checker AuthzChecker) OpenAPIAuthenticatorOpt { + return func(opts *OpenAPIAuthenticatorOpts) { + opts.Checker = checker + } +} + // NewOpenAPIErrorHandler creates a new OpenAPI error handler. func NewOpenAPIErrorHandler() middleware.ErrorHandler { return func(c *fiber.Ctx, message string, statusCode int) { @@ -21,9 +59,13 @@ func NewOpenAPIErrorHandler() middleware.ErrorHandler { } // NewOpenAPIAuthenticator creates a new OpenAPI authenticator. -func NewOpenAPIAuthenticator() openapi3filter.AuthenticationFunc { +func NewOpenAPIAuthenticator(opts ...OpenAPIAuthenticatorOpt) openapi3filter.AuthenticationFunc { return func(ctx context.Context, input *openapi3filter.AuthenticationInput) error { + opt := OpenAPIAuthenticatorDefaultOpts() + opt.Conigure(opts...) + c := middleware.GetFiberContext(ctx) + obj := AuthzObject(c.Params(opt.PathParam, "")) key, err := GetAPIKeyFromRequest(input.RequestValidationInput.Request) if err != nil { @@ -35,6 +77,18 @@ func NewOpenAPIAuthenticator() openapi3filter.AuthenticationFunc { return fiber.NewError(fiber.StatusUnauthorized, "Invalid API key") } + allowed := len(input.Scopes) == 0 + if len(input.Scopes) > 0 { + allowed, err = opt.Checker.Allowed(ctx, AuthzPrincipal(key), obj, AuthzAction(input.Scopes[0])) + if err != nil { + return fiber.NewError(fiber.StatusInternalServerError, "Internal Server Error") + } + } + + if !allowed { + return fiber.NewError(fiber.StatusForbidden, "Forbidden") + } + // Create a new context with the API key. usrCtx := c.UserContext() authCtx := context.WithValue(usrCtx, authzAPIKey, key)