diff --git a/cmd/root.go b/cmd/root.go index 84836654..3809c66c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -146,31 +146,54 @@ func readLocalToken(homeDir string, expectedScopes []string) (string, []string, return token.AccessToken, currentScopes, nil } -// ensureToken -func ensureToken(ctx context.Context, requiredScopes []string) (context.Context, error) { - // get a token from the api key if present - if viper.GetString("api-key") != "" { - log.WithContext(ctx).Debug("using provided token for authentication") - apiKey := viper.GetString("api-key") - if strings.HasPrefix(apiKey, "ovm_api_") { - // exchange api token for JWT - client := UnauthenticatedApiKeyClient(ctx) - resp, err := client.ExchangeKeyForToken(ctx, &connect.Request[sdp.ExchangeKeyForTokenRequest]{ - Msg: &sdp.ExchangeKeyForTokenRequest{ - ApiKey: apiKey, - }, - }) - if err != nil { - return ctx, fmt.Errorf("error authenticating the API token: %w", err) - } - log.WithContext(ctx).Debug("successfully authenticated") - apiKey = resp.Msg.GetAccessToken() - } else { - return ctx, errors.New("OVM_API_KEY does not match pattern 'ovm_api_*'") +// Check whether or not a token has all of the required scopes. Returns a +// boolean and an error which will be populated if we couldn't read the token +func tokenHasAllScopes(token string, requiredScopes []string) (bool, error) { + claims, err := extractClaims(token) + + if err != nil { + return false, fmt.Errorf("error extracting claims from token: %w", err) + } + + // Check that the token has the right scopes + for _, scope := range requiredScopes { + if !claims.HasScope(scope) { + return false, nil + } + } + + return true, nil +} + +// Gets a token using an API key +func getAPIKeyToken(ctx context.Context, apiKey string) (string, error) { + log.WithContext(ctx).Debug("using provided token for authentication") + + var accessToken string + + if strings.HasPrefix(apiKey, "ovm_api_") { + // exchange api token for JWT + client := UnauthenticatedApiKeyClient(ctx) + resp, err := client.ExchangeKeyForToken(ctx, &connect.Request[sdp.ExchangeKeyForTokenRequest]{ + Msg: &sdp.ExchangeKeyForTokenRequest{ + ApiKey: apiKey, + }, + }) + if err != nil { + return "", fmt.Errorf("error authenticating the API token: %w", err) } - return context.WithValue(ctx, sdp.UserTokenContextKey{}, apiKey), nil + log.WithContext(ctx).Debug("successfully authenticated") + accessToken = resp.Msg.GetAccessToken() + } else { + return "", errors.New("OVM_API_KEY does not match pattern 'ovm_api_*'") } + return accessToken, nil +} + +// Gets a token from Oauth with the required scopes. This method will also cache +// that token locally for use later, and will use the cached token if possible +func getOauthToken(ctx context.Context, requiredScopes []string) (string, error) { var localScopes []string // Check for a locally saved token in ~/.overmind @@ -182,92 +205,116 @@ func ensureToken(ctx context.Context, requiredScopes []string) (context.Context, if err != nil { log.WithContext(ctx).Debugf("Error reading local token, ignoring: %v", err) } else { - return context.WithValue(ctx, sdp.UserTokenContextKey{}, localToken), nil + // If we already have the right scopes, return the token + return localToken, nil } } + // If we need to get a new token, request the required scopes on top of + // whatever ones the current local, valid token has so that we don't + // keep replacing it + // Check to see if the URL is secure appurl := viper.GetString("url") parsed, err := url.Parse(appurl) if err != nil { log.WithContext(ctx).WithError(err).Error("Failed to parse --url") - return ctx, fmt.Errorf("error parsing --url: %w", err) - } - - if parsed.Scheme == "wss" || parsed.Scheme == "https" || parsed.Hostname() == "localhost" { - // If we need to get a new token, request the required scopes on top of - // whatever ones the current local, valid token has so that we don't - // keep replacing it - requestScopes := append(requiredScopes, localScopes...) - - // Authenticate using the oauth device authorization flow - config := oauth2.Config{ - ClientID: viper.GetString("cli-auth0-client-id"), - Endpoint: oauth2.Endpoint{ - AuthURL: fmt.Sprintf("https://%v/authorize", viper.GetString("cli-auth0-domain")), - TokenURL: fmt.Sprintf("https://%v/oauth/token", viper.GetString("cli-auth0-domain")), - DeviceAuthURL: fmt.Sprintf("https://%v/oauth/device/code", viper.GetString("cli-auth0-domain")), - }, - Scopes: requestScopes, - } + return "", fmt.Errorf("error parsing --url: %w", err) + } - deviceCode, err := config.DeviceAuth(ctx, oauth2.SetAuthURLParam("audience", "https://api.overmind.tech")) - if err != nil { - return ctx, fmt.Errorf("error getting device code: %w", err) - } + if !(parsed.Scheme == "wss" || parsed.Scheme == "https" || parsed.Hostname() == "localhost") { + return "", fmt.Errorf("target URL (%v) is insecure", parsed) + } + // If we need to get a new token, request the required scopes on top of + // whatever ones the current local, valid token has so that we don't + // keep replacing it + requestScopes := append(requiredScopes, localScopes...) + + // Authenticate using the oauth device authorization flow + config := oauth2.Config{ + ClientID: viper.GetString("cli-auth0-client-id"), + Endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("https://%v/authorize", viper.GetString("cli-auth0-domain")), + TokenURL: fmt.Sprintf("https://%v/oauth/token", viper.GetString("cli-auth0-domain")), + DeviceAuthURL: fmt.Sprintf("https://%v/oauth/device/code", viper.GetString("cli-auth0-domain")), + }, + Scopes: requestScopes, + } + + deviceCode, err := config.DeviceAuth(ctx, oauth2.SetAuthURLParam("audience", "https://api.overmind.tech")) + if err != nil { + return "", fmt.Errorf("error getting device code: %w", err) + } + + fmt.Printf("Go to %v and verify this code: %v\n", deviceCode.VerificationURIComplete, deviceCode.UserCode) + + token, err := config.DeviceAccessToken(ctx, deviceCode) + if err != nil { + fmt.Printf(": %v\n", err) + return "", fmt.Errorf("Error exchanging Device Code for for access token: %w", err) + } - fmt.Printf("Go to %v and verify this code: %v\n", deviceCode.VerificationURIComplete, deviceCode.UserCode) + log.WithContext(ctx).Info("Authenticated successfully ✅") - token, err := config.DeviceAccessToken(ctx, deviceCode) + // Save the token locally + if home, err := os.UserHomeDir(); err == nil { + // Create the directory if it doesn't exist + err = os.MkdirAll(filepath.Join(home, ".overmind"), 0700) if err != nil { - fmt.Printf(": %v\n", err) - return ctx, fmt.Errorf("Error exchanging Device Code for for access token: %w", err) + log.WithContext(ctx).WithError(err).Error("Failed to create ~/.overmind directory") } - // Check that we actually got the claims we asked for. If you don't have - // permission auth0 will just not assign those scopes rather than fail - claims, err := extractClaims(token.AccessToken) - + // Write the token to a file + path := filepath.Join(home, ".overmind", "token.json") + file, err := os.Create(path) if err != nil { - return ctx, fmt.Errorf("error extracting claims from token: %w", err) + log.WithContext(ctx).WithError(err).Errorf("Failed to create token file at %v", path) } - for _, scope := range requiredScopes { - if !claims.HasScope(scope) { - return ctx, fmt.Errorf("authenticated successfully, but you don't have the required permission: '%v'", scope) - } + // Encode the token + err = json.NewEncoder(file).Encode(token) + if err != nil { + log.WithContext(ctx).WithError(err).Errorf("Failed to encode token file at %v", path) } - log.WithContext(ctx).Info("Authenticated successfully ✅") + log.WithContext(ctx).Debugf("Saved token to %v", path) + } - // Save the token locally - if home, err := os.UserHomeDir(); err == nil { - // Create the directory if it doesn't exist - err = os.MkdirAll(filepath.Join(home, ".overmind"), 0700) - if err != nil { - log.WithContext(ctx).WithError(err).Error("Failed to create ~/.overmind directory") - } + return token.AccessToken, nil +} - // Write the token to a file - path := filepath.Join(home, ".overmind", "token.json") - file, err := os.Create(path) - if err != nil { - log.WithContext(ctx).WithError(err).Errorf("Failed to create token file at %v", path) - } +// ensureToken +func ensureToken(ctx context.Context, requiredScopes []string) (context.Context, error) { + var accessToken string + var err error - // Encode the token - err = json.NewEncoder(file).Encode(token) - if err != nil { - log.WithContext(ctx).WithError(err).Errorf("Failed to encode token file at %v", path) - } + // get a token from the api key if present + if apiKey := viper.GetString("api-key"); apiKey != "" { + accessToken, err = getAPIKeyToken(ctx, apiKey) + } else { + accessToken, err = getOauthToken(ctx, requiredScopes) + } - log.WithContext(ctx).Debugf("Saved token to %v", path) - } + if err != nil { + return ctx, fmt.Errorf("error getting token: %w", err) + } + + // Check that we actually got the claims we asked for. If you don't have + // permission auth0 will just not assign those scopes rather than fail + claims, err := extractClaims(accessToken) - // Set the token - return context.WithValue(ctx, sdp.UserTokenContextKey{}, token.AccessToken), nil + if err != nil { + return ctx, fmt.Errorf("error extracting claims from token: %w", err) + } + + for _, scope := range requiredScopes { + if !claims.HasScope(scope) { + return ctx, fmt.Errorf("authenticated successfully, but you don't have the required permission: '%v'", scope) + } } - return ctx, fmt.Errorf("no OVM_API_KEY configured and target URL (%v) is insecure", parsed) + + // Add the token to the context + return context.WithValue(ctx, sdp.UserTokenContextKey{}, accessToken), nil } // getChangeUuid returns the UUID of a change, as selected by --uuid or --change, or a state with the specified status and having --ticket-link @@ -302,16 +349,16 @@ func getChangeUuid(ctx context.Context, expectedStatus sdp.ChangeStatus, errNotF // Finally look through all open changes to find one with a matching ticket link client := AuthenticatedChangesClient(ctx) - var maybeChangeUuid *uuid.UUID changesList, err := client.ListChangesByStatus(ctx, &connect.Request[sdp.ListChangesByStatusRequest]{ Msg: &sdp.ListChangesByStatusRequest{ Status: expectedStatus, }, }) if err != nil { - return uuid.Nil, fmt.Errorf("failed to searching for existing changes: %w", err) + return uuid.Nil, fmt.Errorf("failed to search for existing changes: %w", err) } + var maybeChangeUuid *uuid.UUID for _, c := range changesList.Msg.GetChanges() { if c.GetProperties().GetTicketLink() == ticketLink { maybeChangeUuid = c.GetMetadata().GetUUIDParsed()