diff --git a/cmd/root.go b/cmd/root.go index 3809c66c..7153000d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -307,14 +307,48 @@ func ensureToken(ctx context.Context, requiredScopes []string) (context.Context, return ctx, fmt.Errorf("error extracting claims from token: %w", err) } + ok, missing := HasScopesFlexible(claims, requiredScopes) + + if !ok { + return ctx, fmt.Errorf("authenticated successfully, but you don't have the required permission: '%v'", missing) + } + + // Add the token to the context + return context.WithValue(ctx, sdp.UserTokenContextKey{}, accessToken), nil +} + +// Returns whether a set of claims has all of the required scopes. It also +// accounts for when a user has write access but required read access, they +// aren't the same but the user will have access anyway so this will pass +// +// Returns a bool and the missing permission as a string of any +func HasScopesFlexible(claims *sdp.CustomClaims, requiredScopes []string) (bool, string) { + if claims == nil { + return false, "" + } + for _, scope := range requiredScopes { if !claims.HasScope(scope) { - return ctx, fmt.Errorf("authenticated successfully, but you don't have the required permission: '%v'", scope) + // If they don't have the *exact* scope, check to see if they have + // write access to the same service + sections := strings.Split(scope, ":") + var hasWriteInstead bool + + if len(sections) == 2 { + service, action := sections[0], sections[1] + + if action == "read" { + hasWriteInstead = claims.HasScope(fmt.Sprintf("%v:write", service)) + } + } + + if !hasWriteInstead { + return false, scope + } } } - // Add the token to the context - return context.WithValue(ctx, sdp.UserTokenContextKey{}, accessToken), nil + return true, "" } // getChangeUuid returns the UUID of a change, as selected by --uuid or --change, or a state with the specified status and having --ticket-link diff --git a/cmd/root_test.go b/cmd/root_test.go index f3256243..51b3ccc5 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -2,6 +2,8 @@ package cmd import ( "testing" + + "github.com/overmindtech/sdp-go" ) func TestParseChangeUrl(t *testing.T) { @@ -24,3 +26,45 @@ func TestParseChangeUrl(t *testing.T) { } } } + +func TestHasScopesFlexible(t *testing.T) { + claims := &sdp.CustomClaims{ + Scope: "changes:read users:write", + AccountName: "test", + } + + tests := []struct { + Name string + RequiredScopes []string + ShouldPass bool + }{ + { + Name: "Same scope", + RequiredScopes: []string{"changes:read"}, + ShouldPass: true, + }, + { + Name: "Multiple scopes", + RequiredScopes: []string{"changes:read", "users:write"}, + ShouldPass: true, + }, + { + Name: "Missing scope", + RequiredScopes: []string{"changes:read", "users:write", "colours:create"}, + ShouldPass: false, + }, + { + Name: "Write instead of read", + RequiredScopes: []string{"users:read"}, + ShouldPass: true, + }, + } + + for _, tc := range tests { + t.Run(tc.Name, func(t *testing.T) { + if pass, _ := HasScopesFlexible(claims, tc.RequiredScopes); pass != tc.ShouldPass { + t.Fatalf("expected: %v, got: %v", tc.ShouldPass, !tc.ShouldPass) + } + }) + } +}