Skip to content

Commit

Permalink
cmd/cue: fix "login" on a clean environment, and add tests
Browse files Browse the repository at this point in the history
`cue login` broke as it treated a missing $CUE_CONFIG_DIR/logins.cue
as a fatal error, which is bad as new users are in this situation.
We never noticed as we didn't have integration tests for this command.

Write the first integration tests with a mock httptest server
which only implements the two OAuth2 endpoints with simple logic.
Rather than having to write the testscripts to emulate a user
going to the /login/device page to insert the user_code string,
run the oauth registry in various modes with predefined behaviors.

Note that I had to teach `cue login` to perform the OAuth2 device flow
over HTTP rather than HTTPS when using an insecure CUE registry.
This required changing the AllHosts method from returning []string,
which was done to avoid exposing the Host type, to returning []Host.
Both the Oauth2 config and the logins.json transport are updated.

Fixes #2925.

Signed-off-by: Daniel Martí <[email protected]>
Change-Id: Ib793fb73c921cd68038645d072271fbe9e8ee0ec
Dispatch-Trailer: {"type":"trybot","CL":1178176,"patchset":2,"ref":"refs/changes/76/1178176/2","targetBranch":"master"}
  • Loading branch information
mvdan authored and cueckoo committed Mar 12, 2024
1 parent 6db0b7f commit ccc8c22
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 22 deletions.
15 changes: 10 additions & 5 deletions cmd/cue/cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ package cmd

import (
"context"
"errors"
"fmt"
"io/fs"
"net/http"
"os"

Expand Down Expand Up @@ -80,16 +82,19 @@ inside your user's config directory, such as $XDG_CONFIG_HOME or %AppData%.
if len(registryHosts) > 1 {
return fmt.Errorf("need a single CUE registry to log into")
}
registry := registryHosts[0]
host := registryHosts[0]
loginsPath, err := cueconfig.LoginConfigPath(os.Getenv)
if err != nil {
return fmt.Errorf("cannot find the path to store CUE registry logins: %v", err)
}
logins, err := cueconfig.ReadLogins(loginsPath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
// No config file yet; create an empty one.
logins = &cueconfig.Logins{Registries: make(map[string]cueconfig.RegistryLogin)}
} else if err != nil {
return fmt.Errorf("cannot load CUE registry logins: %v", err)
}
oauthCfg := cueconfig.RegistryOAuthConfig(registry)
oauthCfg := cueconfig.RegistryOAuthConfig(host)

resp, err := oauthCfg.DeviceAuth(ctx)
if err != nil {
Expand All @@ -105,12 +110,12 @@ inside your user's config directory, such as $XDG_CONFIG_HOME or %AppData%.
return fmt.Errorf("cannot obtain the OAuth2 token: %v", err)
}

logins.Registries[registry] = cueconfig.LoginFromToken(tok)
logins.Registries[host.Name] = cueconfig.LoginFromToken(tok)

if err := cueconfig.WriteLogins(loginsPath, logins); err != nil {
return fmt.Errorf("cannot store CUE registry logins: %v", err)
}
fmt.Printf("Login for %s stored in %s\n", registry, loginsPath)
fmt.Printf("Login for %s stored in %s\n", host.Name, loginsPath)
// TODO: Once we support encryption, we should print a warning if it's not available.
return nil
}),
Expand Down
107 changes: 107 additions & 0 deletions cmd/cue/cmd/script_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,17 @@ import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io/fs"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"time"

Expand All @@ -35,6 +38,7 @@ import (
"github.com/rogpeppe/go-internal/goproxytest"
"github.com/rogpeppe/go-internal/gotooltest"
"github.com/rogpeppe/go-internal/testscript"
"golang.org/x/oauth2"
"golang.org/x/tools/txtar"

"cuelang.org/go/cue/errors"
Expand Down Expand Up @@ -143,6 +147,24 @@ func TestScript(t *testing.T) {
ts.Setenv(args[0], u.Host)
ts.Defer(srv.Close)
},
// memregistry starts an HTTP server with enough endpoints to test `cue login`.
// It takes a single argument to describe the oauth server's behavior:
//
// * device-code-expired: polling for a token with device_code
// always responds with [tokenErrorCodeExpired]
// * pending-success: polling for a token with device_code
// responds with [tokenErrorCodePending] once, and then succeeds
// * immediate-success: polling for a token with device_code succeeds right away
"oauthregistry": func(ts *testscript.TestScript, neg bool, args []string) {
if len(args) != 1 {
ts.Fatalf("usage: oauthregistry <mode>")
}
ts.Setenv("CUE_EXPERIMENT", "modules")
srv := newMockRegistryOauth(args[0])
u, _ := url.Parse(srv.URL)
ts.Setenv("CUE_REGISTRY", u.Host+"+insecure")
ts.Defer(srv.Close)
},
},
Setup: func(e *testscript.Env) error {
// If a testscript loads CUE packages but forgot to set up a cue.mod,
Expand Down Expand Up @@ -355,3 +377,88 @@ func testCmd() error {
return fmt.Errorf("unknown command: %q\n", cmd)
}
}

// newMockRegistryOauth starts a test HTTP server with the OAuth2 device flow endpoints
// used by `cue login` to obtain an access token.
// Note that this HTTP server isn't an OCI registry yet, as that isn't needed for now.
//
// TODO: once we support refresh tokens, add those endpoints and test them too.
func newMockRegistryOauth(mode string) *httptest.Server {
mux := http.NewServeMux()
ts := httptest.NewServer(mux)
const (
staticUserCode = "user-code"
staticDeviceCode = "device-code-longer-string"
staticAccessToken = "secret-access-token"
intervalSecs = 1 // 1s to keep the tests fast
)
// OAuth2 Device Authorization Request endpoint: https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
mux.HandleFunc("/login/device/code", func(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, oauth2.DeviceAuthResponse{
DeviceCode: staticDeviceCode,
UserCode: staticUserCode,

VerificationURI: ts.URL + "/login/device",
VerificationURIComplete: ts.URL + "/login/device?user_code=" + url.QueryEscape(staticUserCode),

Expiry: time.Now().Add(time.Minute),
Interval: intervalSecs,
})
})
// OAuth2 Token endpoint: https://datatracker.ietf.org/doc/html/rfc6749#section-3.2
var tokenRequestCounter atomic.Int64
mux.HandleFunc("/login/oauth/token", func(w http.ResponseWriter, r *http.Request) {
deviceCode := r.FormValue("device_code")
if deviceCode != staticDeviceCode {
writeJSON(w, http.StatusBadRequest, tokenError{ErrorCode: tokenErrorCodeDenied})
return
}
switch mode {
case "device-code-expired":
writeJSON(w, http.StatusBadRequest, tokenError{ErrorCode: tokenErrorCodeExpired})
case "pending-success":
count := tokenRequestCounter.Add(1)
if count == 1 {
writeJSON(w, http.StatusBadRequest, tokenError{ErrorCode: tokenErrorCodePending})
break
}
fallthrough
case "immediate-success":
writeJSON(w, http.StatusOK, oauth2.Token{
AccessToken: staticAccessToken,
TokenType: "Bearer",
Expiry: time.Now().Add(time.Hour),
})
default:
panic(fmt.Sprintf("unknown mode: %q", mode))
}
})
return ts
}

func writeJSON(w http.ResponseWriter, statusCode int, v any) {
b, err := json.Marshal(v)
if err != nil { // should never happen
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
w.Write(b)
}

const (
// Device flow token error code strings from https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
tokenErrorCodePending = "authorization_pending" // waiting for user
tokenErrorCodeSlowDown = "slow_down" // increase polling interval
tokenErrorCodeDenied = "access_denied" // the user denied the request
tokenErrorCodeExpired = "expired_token" // the device_code expired
)

// tokenError implements the error response structure defined by
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
type tokenError struct {
ErrorCode string `json:"error"` // one of the constants above
ErrorDescription string `json:"error_description,omitempty"`
ErrorURI string `json:"error_uri,omitempty"`
}
10 changes: 10 additions & 0 deletions cmd/cue/cmd/testdata/script/login_expired.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Test that `cue login` fails when given an error
# such as a device_code being expired.

env CUE_CONFIG_DIR=$WORK/cueconfig
oauthregistry device-code-expired

! exec cue login
stdout 'open:.*user_code=user-code'
stderr 'cannot obtain the OAuth2 token.*expired_token'
! exists cueconfig/logins.json
9 changes: 9 additions & 0 deletions cmd/cue/cmd/testdata/script/login_immediate.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test that `cue login` succeeds with the device flow
# when the device login is immediately authorized and successful.

env CUE_CONFIG_DIR=$WORK/cueconfig
oauthregistry immediate-success

exec cue login
stdout 'open:.*user_code=user-code'
grep 'secret-access-token' cueconfig/logins.json
9 changes: 9 additions & 0 deletions cmd/cue/cmd/testdata/script/login_pending.txtar
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Test that `cue login` succeeds with the device flow
# when the device login is authorized after the first polling request.

env CUE_CONFIG_DIR=$WORK/cueconfig
oauthregistry pending-success

exec cue login
stdout 'open:.*user_code=user-code'
grep 'secret-access-token' cueconfig/logins.json
11 changes: 8 additions & 3 deletions internal/cueconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"path/filepath"
"time"

"cuelang.org/go/internal/mod/modresolve"
"golang.org/x/oauth2"
)

Expand Down Expand Up @@ -101,18 +102,22 @@ func WriteLogins(path string, logins *Logins) error {

// RegistryOAuthConfig returns the oauth2 configuration
// suitable for talking to the central registry.
func RegistryOAuthConfig(host string) oauth2.Config {
func RegistryOAuthConfig(host modresolve.Host) oauth2.Config {
// For now, we use the OAuth endpoints as implemented by registry.cue.works,
// but other OCI registries may support the OAuth device flow with different ones.
//
// TODO: Query /.well-known/oauth-authorization-server to obtain
// token_endpoint and device_authorization_endpoint per the Oauth RFCs:
// * https://datatracker.ietf.org/doc/html/rfc8414#section-3
// * https://datatracker.ietf.org/doc/html/rfc8628#section-4
scheme := "https://"
if host.Insecure {
scheme = "http://"
}
return oauth2.Config{
Endpoint: oauth2.Endpoint{
DeviceAuthURL: "https://" + host + "/login/device/code",
TokenURL: "https://" + host + "/login/oauth/token",
DeviceAuthURL: scheme + host.Name + "/login/device/code",
TokenURL: scheme + host.Name + "/login/oauth/token",
},
}
}
Expand Down
22 changes: 12 additions & 10 deletions mod/modconfig/modconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,14 @@ func NewResolver(cfg *Config) (*Resolver, error) {
}, nil
}

// AllHosts returns information on all the registry host names referred to
// by the resolver.
func (r *Resolver) AllHosts() []string {
allHosts := r.resolver.AllHosts()
names := make([]string, len(allHosts))
for i, h := range allHosts {
names[i] = h.Name
}
return names
// Host represents a registry host name and whether
// it should be accessed via a secure connection or not.
type Host = modresolve.Host

// AllHosts returns all the registry hosts that the resolver might resolve to,
// ordered lexically by hostname.
func (r *Resolver) AllHosts() []Host {
return r.resolver.AllHosts()
}

// HostLocation represents a registry host and a location with it.
Expand Down Expand Up @@ -228,7 +227,10 @@ func (t *cueLoginsTransport) RoundTrip(req *http.Request) (*http.Response, error
transport := t.cachedTransports[host]
if transport == nil {
tok := cueconfig.TokenFromLogin(login)
oauthCfg := cueconfig.RegistryOAuthConfig(host)
oauthCfg := cueconfig.RegistryOAuthConfig(Host{
Name: host,
Insecure: req.URL.Scheme == "http",
})
// TODO: When this client refreshes an access token,
// we should store the refreshed token on disk.

Expand Down
11 changes: 7 additions & 4 deletions mod/modconfig/modconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (
"os"
"path"
"path/filepath"
"sort"
"slices"
"strings"
"sync/atomic"
"testing"

Expand Down Expand Up @@ -111,9 +112,11 @@ package x
resolver, err := NewResolver(nil)
qt.Assert(t, qt.IsNil(err))
gotAllHosts := resolver.AllHosts()
wantAllHosts := []string{r1.Host(), r2.Host()}
sort.Strings(gotAllHosts)
sort.Strings(wantAllHosts)
wantAllHosts := []Host{{Name: r1.Host(), Insecure: true}, {Name: r2.Host(), Insecure: true}}

byHostname := func(a, b Host) int { return strings.Compare(a.Name, b.Name) }
slices.SortFunc(gotAllHosts, byHostname)
slices.SortFunc(wantAllHosts, byHostname)

qt.Assert(t, qt.DeepEquals(gotAllHosts, wantAllHosts))

Expand Down

0 comments on commit ccc8c22

Please sign in to comment.