diff --git a/middleware.go b/middleware.go index de3e3c1..b75ed6d 100644 --- a/middleware.go +++ b/middleware.go @@ -29,9 +29,17 @@ type AuthBypassedContextKey struct{} // from the JWT type CustomClaimsContextKey struct{} +// AccountNameContextKey is the key that is used to store the currently acting +// account name +type AccountNameContextKey struct{} + // UserTokenContextKey is the key that is used to store the full JWT token of the user type UserTokenContextKey struct{} +// CurrentSubjectContextKey is the key that is used to store the current subject attribute. +// This will be the auth0 `user_id` from the tokens `sub` claim. +type CurrentSubjectContextKey struct{} + // AuthConfig Configuration for the auth middleware type AuthConfig struct { // Bypasses all auth checks, meaning that HasScopes() will always return @@ -178,6 +186,17 @@ func AddBypassAuthConfig(ctx context.Context) context.Context { return context.WithValue(ctx, AuthBypassedContextKey{}, true) } +// OverrideAuthContext overrides the authentication data and token stored in the context. +// This is mostly useful for testing or delegating access locally into a protected API. +func OverrideAuthContext(ctx context.Context, claims *validator.ValidatedClaims) context.Context { + customClaims := claims.CustomClaims.(*CustomClaims) + ctx = context.WithValue(ctx, jwtmiddleware.ContextKey{}, claims) + ctx = context.WithValue(ctx, CustomClaimsContextKey{}, customClaims) + ctx = context.WithValue(ctx, CurrentSubjectContextKey{}, claims.RegisteredClaims.Subject) + ctx = context.WithValue(ctx, AccountNameContextKey{}, customClaims.AccountName) + return ctx +} + // OverrideCustomClaims Overrides the custom claims in the context that have // been set at CustomClaimsContextKey func OverrideCustomClaims(ctx context.Context, scope *string, account *string) context.Context { @@ -276,7 +295,7 @@ func ensureValidTokenHandler(next http.Handler) http.Handler { return middleware.CheckJWT(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // extract account name and setup otel attributes after the JWT was validated, but before the actual handler runs claims := r.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims) - customClaims := claims.CustomClaims.(*CustomClaims) + token, err := tokenExtractor(r) // we should never hit this as the middleware wouldn't call the handler if err != nil { @@ -286,22 +305,32 @@ func ensureValidTokenHandler(next http.Handler) http.Handler { return } - r = r.Clone(context.WithValue(r.Context(), UserTokenContextKey{}, token)) - - if customClaims != nil { - r = r.Clone(context.WithValue(r.Context(), CustomClaimsContextKey{}, customClaims)) - - trace.SpanFromContext(r.Context()).SetAttributes( - attribute.String("om.auth.scopes", customClaims.Scope), - attribute.Int64("om.auth.expiry", claims.RegisteredClaims.Expiry), - attribute.String("om.auth.accountName", customClaims.AccountName), - ) - - next.ServeHTTP(w, r) - } else { + customClaims := claims.CustomClaims.(*CustomClaims) + if customClaims == nil { errorHandler(w, r, fmt.Errorf("couldn't get claims from: %v", claims)) return } + + ctx := r.Context() + + // note that the values are looked up in last-in-first-out order, so + // there is an absolutely minor perf optimisation to have the context + // values set in ascending order of access frequency. + ctx = context.WithValue(ctx, UserTokenContextKey{}, token) + ctx = context.WithValue(ctx, CustomClaimsContextKey{}, customClaims) + ctx = context.WithValue(ctx, CurrentSubjectContextKey{}, claims.RegisteredClaims.Subject) + ctx = context.WithValue(ctx, AccountNameContextKey{}, customClaims.AccountName) + + trace.SpanFromContext(ctx).SetAttributes( + attribute.String("om.auth.accountName", customClaims.AccountName), + attribute.Int64("om.auth.expiry", claims.RegisteredClaims.Expiry), + attribute.String("om.auth.scopes", customClaims.Scope), + attribute.String("om.auth.subject", claims.RegisteredClaims.Subject), + ) + + r = r.Clone(ctx) + + next.ServeHTTP(w, r) })) }