Skip to content

Commit

Permalink
Merge pull request #152 from overmindtech/account_names
Browse files Browse the repository at this point in the history
Extract and cache the subject and account name into separate context values
  • Loading branch information
DavidS-ovm authored Oct 2, 2023
2 parents f8f3363 + 5bd78dd commit 728d288
Showing 1 changed file with 43 additions and 14 deletions.
57 changes: 43 additions & 14 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,17 @@ type AuthBypassedContextKey struct{}
// from the JWT
type CustomClaimsContextKey struct{}

// AccountNameContextKey is the key that is used to store the currently acting
// account name
type AccountNameContextKey struct{}

// UserTokenContextKey is the key that is used to store the full JWT token of the user
type UserTokenContextKey struct{}

// CurrentSubjectContextKey is the key that is used to store the current subject attribute.
// This will be the auth0 `user_id` from the tokens `sub` claim.
type CurrentSubjectContextKey struct{}

// AuthConfig Configuration for the auth middleware
type AuthConfig struct {
// Bypasses all auth checks, meaning that HasScopes() will always return
Expand Down Expand Up @@ -178,6 +186,17 @@ func AddBypassAuthConfig(ctx context.Context) context.Context {
return context.WithValue(ctx, AuthBypassedContextKey{}, true)
}

// OverrideAuthContext overrides the authentication data and token stored in the context.
// This is mostly useful for testing or delegating access locally into a protected API.
func OverrideAuthContext(ctx context.Context, claims *validator.ValidatedClaims) context.Context {
customClaims := claims.CustomClaims.(*CustomClaims)
ctx = context.WithValue(ctx, jwtmiddleware.ContextKey{}, claims)
ctx = context.WithValue(ctx, CustomClaimsContextKey{}, customClaims)
ctx = context.WithValue(ctx, CurrentSubjectContextKey{}, claims.RegisteredClaims.Subject)
ctx = context.WithValue(ctx, AccountNameContextKey{}, customClaims.AccountName)
return ctx
}

// OverrideCustomClaims Overrides the custom claims in the context that have
// been set at CustomClaimsContextKey
func OverrideCustomClaims(ctx context.Context, scope *string, account *string) context.Context {
Expand Down Expand Up @@ -276,7 +295,7 @@ func ensureValidTokenHandler(next http.Handler) http.Handler {
return middleware.CheckJWT(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// extract account name and setup otel attributes after the JWT was validated, but before the actual handler runs
claims := r.Context().Value(jwtmiddleware.ContextKey{}).(*validator.ValidatedClaims)
customClaims := claims.CustomClaims.(*CustomClaims)

token, err := tokenExtractor(r)
// we should never hit this as the middleware wouldn't call the handler
if err != nil {
Expand All @@ -286,22 +305,32 @@ func ensureValidTokenHandler(next http.Handler) http.Handler {
return
}

r = r.Clone(context.WithValue(r.Context(), UserTokenContextKey{}, token))

if customClaims != nil {
r = r.Clone(context.WithValue(r.Context(), CustomClaimsContextKey{}, customClaims))

trace.SpanFromContext(r.Context()).SetAttributes(
attribute.String("om.auth.scopes", customClaims.Scope),
attribute.Int64("om.auth.expiry", claims.RegisteredClaims.Expiry),
attribute.String("om.auth.accountName", customClaims.AccountName),
)

next.ServeHTTP(w, r)
} else {
customClaims := claims.CustomClaims.(*CustomClaims)
if customClaims == nil {
errorHandler(w, r, fmt.Errorf("couldn't get claims from: %v", claims))
return
}

ctx := r.Context()

// note that the values are looked up in last-in-first-out order, so
// there is an absolutely minor perf optimisation to have the context
// values set in ascending order of access frequency.
ctx = context.WithValue(ctx, UserTokenContextKey{}, token)
ctx = context.WithValue(ctx, CustomClaimsContextKey{}, customClaims)
ctx = context.WithValue(ctx, CurrentSubjectContextKey{}, claims.RegisteredClaims.Subject)
ctx = context.WithValue(ctx, AccountNameContextKey{}, customClaims.AccountName)

trace.SpanFromContext(ctx).SetAttributes(
attribute.String("om.auth.accountName", customClaims.AccountName),
attribute.Int64("om.auth.expiry", claims.RegisteredClaims.Expiry),
attribute.String("om.auth.scopes", customClaims.Scope),
attribute.String("om.auth.subject", claims.RegisteredClaims.Subject),
)

r = r.Clone(ctx)

next.ServeHTTP(w, r)
}))
}

Expand Down

0 comments on commit 728d288

Please sign in to comment.