diff --git a/cmd/root.go b/cmd/root.go index 7db00421..84836654 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,13 +7,11 @@ import ( "encoding/json" "errors" "fmt" - "net/http" "net/url" "os" "path" "path/filepath" "strings" - "time" "connectrpc.com/connect" "github.com/google/uuid" @@ -26,9 +24,6 @@ import ( "golang.org/x/oauth2" ) -const Auth0ClientId = "j3LylZtIosVPZtouKI8WuVHmE6Lluva1" -const Auth0Domain = "om-prod.eu.auth0.com" - var logLevel string //go:generate sh -c "echo -n $(git describe --tags --long) > commit.txt" @@ -147,6 +142,7 @@ func readLocalToken(homeDir string, expectedScopes []string) (string, []string, } } + log.Debugf("Using local token from %v", path) return token.AccessToken, currentScopes, nil } @@ -204,92 +200,28 @@ func ensureToken(ctx context.Context, requiredScopes []string) (context.Context, // keep replacing it requestScopes := append(requiredScopes, localScopes...) - // Authenticate using the oauth resource owner password flow + // Authenticate using the oauth device authorization flow config := oauth2.Config{ - ClientID: Auth0ClientId, - Scopes: requestScopes, + ClientID: viper.GetString("cli-auth0-client-id"), Endpoint: oauth2.Endpoint{ - AuthURL: fmt.Sprintf("https://%v/authorize", Auth0Domain), - TokenURL: fmt.Sprintf("https://%v/oauth/token", Auth0Domain), + AuthURL: fmt.Sprintf("https://%v/authorize", viper.GetString("cli-auth0-domain")), + TokenURL: fmt.Sprintf("https://%v/oauth/token", viper.GetString("cli-auth0-domain")), + DeviceAuthURL: fmt.Sprintf("https://%v/oauth/device/code", viper.GetString("cli-auth0-domain")), }, - RedirectURL: "http://127.0.0.1:7837/oauth/callback", + Scopes: requestScopes, } - tokenChan := make(chan *oauth2.Token, 1) - // create a random token for this exchange - oAuthStateString := uuid.New().String() - - // Start the web server to listen for the callback - handler := func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - queryParts, err := url.ParseQuery(r.URL.RawQuery) - if err != nil { - log.WithContext(ctx).WithError(err).WithFields(log.Fields{ - "url": r.URL, - }).Error("Failed to parse url") - } - - // Use the authorization code that is pushed to the redirect - // URL. - code := queryParts["code"][0] - log.WithContext(ctx).Debugf("Got code: %v", code) - - state := queryParts["state"][0] - log.WithContext(ctx).Debugf("Got state: %v", state) - - if state != oAuthStateString { - log.WithContext(ctx).Errorf("Invalid state, expected %v, got %v", oAuthStateString, state) - return - } - - // Exchange will do the handshake to retrieve the initial access token. - log.WithContext(ctx).Debug("Exchanging code for token") - tok, err := config.Exchange(ctx, code) - if err != nil { - log.WithContext(ctx).Error(err) - return - } - log.WithContext(ctx).Debug("Got token") - - tokenChan <- tok - - // show success page - msg := "

Success!

" - msg = msg + "

You are authenticated and can now return to the CLI.

" - fmt.Fprint(w, msg) + deviceCode, err := config.DeviceAuth(ctx, oauth2.SetAuthURLParam("audience", "https://api.overmind.tech")) + if err != nil { + return ctx, fmt.Errorf("error getting device code: %w", err) } - audienceOption := oauth2.SetAuthURLParam("audience", "https://api.overmind.tech") - - u := config.AuthCodeURL(oAuthStateString, oauth2.AccessTypeOnline, audienceOption) - log.WithContext(ctx).Infof("Follow this link to authenticate: %v", Underline.TextStyle(u)) + fmt.Printf("Go to %v and verify this code: %v\n", deviceCode.VerificationURIComplete, deviceCode.UserCode) - // Start the webserver - log.WithContext(ctx).Trace("Starting webserver to listen for callback, press Ctrl+C to cancel") - srv := &http.Server{Addr: ":7837", ReadHeaderTimeout: 30 * time.Second} - http.HandleFunc("/oauth/callback", handler) - - go func() { - if err := srv.ListenAndServe(); err != http.ErrServerClosed { - // unexpected error. port in use? - log.WithContext(ctx).Errorf("HTTP Server error: %v", err) - } - }() - - // Wait for the token or cancel - var token *oauth2.Token - select { - case token = <-tokenChan: - // Keep working - case <-ctx.Done(): - return ctx, ctx.Err() - } - - // Stop the server - err = srv.Shutdown(ctx) + token, err := config.DeviceAccessToken(ctx, deviceCode) if err != nil { - log.WithContext(ctx).WithError(err).Warn("failed to shutdown auth callback server, but continuing anyway") + fmt.Printf(": %v\n", err) + return ctx, fmt.Errorf("Error exchanging Device Code for for access token: %w", err) } // Check that we actually got the claims we asked for. If you don't have @@ -438,15 +370,17 @@ func init() { log.WithError(err).Fatal("could not bind api key to env") } - // tracing + // internal configs + rootCmd.PersistentFlags().String("cli-auth0-client-id", "QMfjMww3x4QTpeXiuRtMV3JIQkx6mZa4", "OAuth Client ID to use when connecting with auth0") + rootCmd.PersistentFlags().String("cli-auth0-domain", "om-prod.eu.auth0.com", "Auth0 domain to connect to") rootCmd.PersistentFlags().String("honeycomb-api-key", "", "If specified, configures opentelemetry libraries to submit traces to honeycomb. This requires --otel to be set.") - // Mark this as hidden. This means that it will still be parsed of supplied, + + // Mark these as hidden. This means that it will still be parsed of supplied, // and we will still look for it in the environment, but it won't be shown // in the help - err = rootCmd.PersistentFlags().MarkHidden("honeycomb-api-key") - if err != nil { - log.WithError(err).Fatal("could not mark `honeycomb-api-key` flag as hidden") - } + must(rootCmd.PersistentFlags().MarkHidden("cli-auth0-client-id")) + must(rootCmd.PersistentFlags().MarkHidden("cli-auth0-domain")) + must(rootCmd.PersistentFlags().MarkHidden("honeycomb-api-key")) // Create groups rootCmd.AddGroup(&cobra.Group{ @@ -502,3 +436,11 @@ func initConfig() { viper.SetEnvKeyReplacer(replacer) viper.AutomaticEnv() // read in environment variables that match } + +// must panics if the passed in error is not nil +// use this for init-time error checking of viper/cobra stuff that sometimes errors if the flag does not exist +func must(err error) { + if err != nil { + panic(fmt.Errorf("error initialising: %w", err)) + } +}