Skip to content

Commit

Permalink
Merge pull request #178 from overmindtech/device-authz
Browse files Browse the repository at this point in the history
Use Device Authorization flow
  • Loading branch information
DavidS-ovm authored Feb 15, 2024
2 parents f53b002 + 1085746 commit d1d71bf
Showing 1 changed file with 30 additions and 88 deletions.
118 changes: 30 additions & 88 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"time"

"connectrpc.com/connect"
"github.com/google/uuid"
Expand All @@ -26,9 +24,6 @@ import (
"golang.org/x/oauth2"
)

const Auth0ClientId = "j3LylZtIosVPZtouKI8WuVHmE6Lluva1"
const Auth0Domain = "om-prod.eu.auth0.com"

var logLevel string

//go:generate sh -c "echo -n $(git describe --tags --long) > commit.txt"
Expand Down Expand Up @@ -147,6 +142,7 @@ func readLocalToken(homeDir string, expectedScopes []string) (string, []string,
}
}

log.Debugf("Using local token from %v", path)
return token.AccessToken, currentScopes, nil
}

Expand Down Expand Up @@ -204,92 +200,28 @@ func ensureToken(ctx context.Context, requiredScopes []string) (context.Context,
// keep replacing it
requestScopes := append(requiredScopes, localScopes...)

// Authenticate using the oauth resource owner password flow
// Authenticate using the oauth device authorization flow
config := oauth2.Config{
ClientID: Auth0ClientId,
Scopes: requestScopes,
ClientID: viper.GetString("cli-auth0-client-id"),
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("https://%v/authorize", Auth0Domain),
TokenURL: fmt.Sprintf("https://%v/oauth/token", Auth0Domain),
AuthURL: fmt.Sprintf("https://%v/authorize", viper.GetString("cli-auth0-domain")),
TokenURL: fmt.Sprintf("https://%v/oauth/token", viper.GetString("cli-auth0-domain")),
DeviceAuthURL: fmt.Sprintf("https://%v/oauth/device/code", viper.GetString("cli-auth0-domain")),
},
RedirectURL: "http://127.0.0.1:7837/oauth/callback",
Scopes: requestScopes,
}

tokenChan := make(chan *oauth2.Token, 1)
// create a random token for this exchange
oAuthStateString := uuid.New().String()

// Start the web server to listen for the callback
handler := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

queryParts, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
log.WithContext(ctx).WithError(err).WithFields(log.Fields{
"url": r.URL,
}).Error("Failed to parse url")
}

// Use the authorization code that is pushed to the redirect
// URL.
code := queryParts["code"][0]
log.WithContext(ctx).Debugf("Got code: %v", code)

state := queryParts["state"][0]
log.WithContext(ctx).Debugf("Got state: %v", state)

if state != oAuthStateString {
log.WithContext(ctx).Errorf("Invalid state, expected %v, got %v", oAuthStateString, state)
return
}

// Exchange will do the handshake to retrieve the initial access token.
log.WithContext(ctx).Debug("Exchanging code for token")
tok, err := config.Exchange(ctx, code)
if err != nil {
log.WithContext(ctx).Error(err)
return
}
log.WithContext(ctx).Debug("Got token")

tokenChan <- tok

// show success page
msg := "<p><strong>Success!</strong></p>"
msg = msg + "<p>You are authenticated and can now return to the CLI.</p>"
fmt.Fprint(w, msg)
deviceCode, err := config.DeviceAuth(ctx, oauth2.SetAuthURLParam("audience", "https://api.overmind.tech"))
if err != nil {
return ctx, fmt.Errorf("error getting device code: %w", err)
}

audienceOption := oauth2.SetAuthURLParam("audience", "https://api.overmind.tech")

u := config.AuthCodeURL(oAuthStateString, oauth2.AccessTypeOnline, audienceOption)
log.WithContext(ctx).Infof("Follow this link to authenticate: %v", Underline.TextStyle(u))
fmt.Printf("Go to %v and verify this code: %v\n", deviceCode.VerificationURIComplete, deviceCode.UserCode)

// Start the webserver
log.WithContext(ctx).Trace("Starting webserver to listen for callback, press Ctrl+C to cancel")
srv := &http.Server{Addr: ":7837", ReadHeaderTimeout: 30 * time.Second}
http.HandleFunc("/oauth/callback", handler)

go func() {
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
// unexpected error. port in use?
log.WithContext(ctx).Errorf("HTTP Server error: %v", err)
}
}()

// Wait for the token or cancel
var token *oauth2.Token
select {
case token = <-tokenChan:
// Keep working
case <-ctx.Done():
return ctx, ctx.Err()
}

// Stop the server
err = srv.Shutdown(ctx)
token, err := config.DeviceAccessToken(ctx, deviceCode)
if err != nil {
log.WithContext(ctx).WithError(err).Warn("failed to shutdown auth callback server, but continuing anyway")
fmt.Printf(": %v\n", err)
return ctx, fmt.Errorf("Error exchanging Device Code for for access token: %w", err)
}

// Check that we actually got the claims we asked for. If you don't have
Expand Down Expand Up @@ -438,15 +370,17 @@ func init() {
log.WithError(err).Fatal("could not bind api key to env")
}

// tracing
// internal configs
rootCmd.PersistentFlags().String("cli-auth0-client-id", "QMfjMww3x4QTpeXiuRtMV3JIQkx6mZa4", "OAuth Client ID to use when connecting with auth0")
rootCmd.PersistentFlags().String("cli-auth0-domain", "om-prod.eu.auth0.com", "Auth0 domain to connect to")
rootCmd.PersistentFlags().String("honeycomb-api-key", "", "If specified, configures opentelemetry libraries to submit traces to honeycomb. This requires --otel to be set.")
// Mark this as hidden. This means that it will still be parsed of supplied,

// Mark these as hidden. This means that it will still be parsed of supplied,
// and we will still look for it in the environment, but it won't be shown
// in the help
err = rootCmd.PersistentFlags().MarkHidden("honeycomb-api-key")
if err != nil {
log.WithError(err).Fatal("could not mark `honeycomb-api-key` flag as hidden")
}
must(rootCmd.PersistentFlags().MarkHidden("cli-auth0-client-id"))
must(rootCmd.PersistentFlags().MarkHidden("cli-auth0-domain"))
must(rootCmd.PersistentFlags().MarkHidden("honeycomb-api-key"))

// Create groups
rootCmd.AddGroup(&cobra.Group{
Expand Down Expand Up @@ -502,3 +436,11 @@ func initConfig() {
viper.SetEnvKeyReplacer(replacer)
viper.AutomaticEnv() // read in environment variables that match
}

// must panics if the passed in error is not nil
// use this for init-time error checking of viper/cobra stuff that sometimes errors if the flag does not exist
func must(err error) {
if err != nil {
panic(fmt.Errorf("error initialising: %w", err))
}
}

0 comments on commit d1d71bf

Please sign in to comment.