diff --git a/admin/database/database.go b/admin/database/database.go index 9b3441d99cd..780279330b4 100644 --- a/admin/database/database.go +++ b/admin/database/database.go @@ -153,6 +153,11 @@ type DB interface { UpdateDeviceAuthCode(ctx context.Context, id, userID string, state DeviceAuthCodeState) error DeleteExpiredDeviceAuthCodes(ctx context.Context, retention time.Duration) error + FindAuthorizationCode(ctx context.Context, code string) (*AuthorizationCode, error) + InsertAuthorizationCode(ctx context.Context, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod string, expiration time.Time) (*AuthorizationCode, error) + DeleteAuthorizationCode(ctx context.Context, code string) error + DeleteExpiredAuthorizationCodes(ctx context.Context, retention time.Duration) error + FindOrganizationRole(ctx context.Context, name string) (*OrganizationRole, error) FindProjectRole(ctx context.Context, name string) (*ProjectRole, error) ResolveOrganizationRolesForUser(ctx context.Context, userID, orgID string) ([]*OrganizationRole, error) @@ -512,9 +517,10 @@ type AuthClient struct { // Hard-coded auth client IDs (created in the migrations). const ( - AuthClientIDRillWeb = "12345678-0000-0000-0000-000000000001" - AuthClientIDRillCLI = "12345678-0000-0000-0000-000000000002" - AuthClientIDRillSupport = "12345678-0000-0000-0000-000000000003" + AuthClientIDRillWeb = "12345678-0000-0000-0000-000000000001" + AuthClientIDRillCLI = "12345678-0000-0000-0000-000000000002" + AuthClientIDRillSupport = "12345678-0000-0000-0000-000000000003" + AuthClientIDRillWebLocal = "12345678-0000-0000-0000-000000000004" ) // DeviceAuthCodeState is an enum representing the approval state of a DeviceAuthCode @@ -540,6 +546,20 @@ type DeviceAuthCode struct { UpdatedOn time.Time `db:"updated_on"` } +// AuthorizationCode represents an authorization code used for OAuth2 PKCE auth flow. +type AuthorizationCode struct { + ID string `db:"id"` + Code string `db:"code"` + UserID string `db:"user_id"` + ClientID string `db:"client_id"` + RedirectURI string `db:"redirect_uri"` + CodeChallenge string `db:"code_challenge"` + CodeChallengeMethod string `db:"code_challenge_method"` + Expiration time.Time `db:"expires_on"` + CreatedOn time.Time `db:"created_on"` + UpdatedOn time.Time `db:"updated_on"` +} + // Constants for known role names (created in migrations). const ( OrganizationRoleNameAdmin = "admin" diff --git a/admin/database/postgres/migrations/0028.sql b/admin/database/postgres/migrations/0028.sql new file mode 100644 index 00000000000..95820315efc --- /dev/null +++ b/admin/database/postgres/migrations/0028.sql @@ -0,0 +1,20 @@ +-- Hard-coded first-party auth clients +INSERT INTO auth_clients (id, display_name) +VALUES ('12345678-0000-0000-0000-000000000004', 'Rill Localhost'); + +-- Table for storing authorization codes for PKCE auth flow +CREATE TABLE authorization_codes ( + id UUID DEFAULT uuid_generate_v4() PRIMARY KEY, + code TEXT NOT NULL, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + client_id UUID NOT NULL REFERENCES auth_clients(id) ON DELETE CASCADE, + redirect_uri TEXT NOT NULL, + code_challenge TEXT NOT NULL, + code_challenge_method TEXT NOT NULL, + expires_on TIMESTAMP NOT NULL, + created_on TIMESTAMPTZ DEFAULT now() NOT NULL, + updated_on TIMESTAMPTZ DEFAULT now() NOT NULL +); + +-- create index on code column +CREATE UNIQUE INDEX authorization_codes_code_idx ON authorization_codes(code); \ No newline at end of file diff --git a/admin/database/postgres/postgres.go b/admin/database/postgres/postgres.go index a1df1bca3a2..3bba28010f3 100644 --- a/admin/database/postgres/postgres.go +++ b/admin/database/postgres/postgres.go @@ -1006,6 +1006,36 @@ func (c *connection) DeleteExpiredDeviceAuthCodes(ctx context.Context, retention return parseErr("device auth code", err) } +func (c *connection) FindAuthorizationCode(ctx context.Context, code string) (*database.AuthorizationCode, error) { + authCode := &database.AuthorizationCode{} + err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT * FROM authorization_codes WHERE code = $1", code).StructScan(authCode) + if err != nil { + return nil, parseErr("authorization code", err) + } + return authCode, nil +} + +func (c *connection) InsertAuthorizationCode(ctx context.Context, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod string, expiration time.Time) (*database.AuthorizationCode, error) { + res := &database.AuthorizationCode{} + err := c.getDB(ctx).QueryRowxContext(ctx, + `INSERT INTO authorization_codes (code, user_id, client_id, redirect_uri, code_challenge, code_challenge_method, expires_on) + VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING *`, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod, expiration).StructScan(res) + if err != nil { + return nil, parseErr("authorization code", err) + } + return res, nil +} + +func (c *connection) DeleteAuthorizationCode(ctx context.Context, code string) error { + res, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM authorization_codes WHERE code=$1", code) + return checkDeleteRow("authorization code", res, err) +} + +func (c *connection) DeleteExpiredAuthorizationCodes(ctx context.Context, retention time.Duration) error { + _, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM authorization_codes WHERE expires_on + $1 < now()", retention) + return parseErr("authorization code", err) +} + func (c *connection) FindOrganizationRole(ctx context.Context, name string) (*database.OrganizationRole, error) { role := &database.OrganizationRole{} err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT * FROM org_roles WHERE lower(name)=lower($1)", name).StructScan(role) diff --git a/admin/pkg/oauth/utils.go b/admin/pkg/oauth/utils.go new file mode 100644 index 00000000000..2f7aca29a54 --- /dev/null +++ b/admin/pkg/oauth/utils.go @@ -0,0 +1,14 @@ +package oauth + +const ( + FormMediaType = "application/x-www-form-urlencoded" + JSONMediaType = "application/json" +) + +// TokenResponse contains the information returned after fetching an access token from the OAuth server. +type TokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in,string"` + TokenType string `json:"token_type"` + UserID string `json:"user_id"` +} diff --git a/admin/server/auth/device_code.go b/admin/server/auth/device_code.go index 4621911184d..1e0fd741339 100644 --- a/admin/server/auth/device_code.go +++ b/admin/server/auth/device_code.go @@ -12,8 +12,8 @@ import ( "github.com/rilldata/rill/admin" "github.com/rilldata/rill/admin/database" + "github.com/rilldata/rill/admin/pkg/oauth" "github.com/rilldata/rill/admin/pkg/urlutil" - "github.com/rilldata/rill/cli/pkg/deviceauth" ) const deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code" @@ -175,23 +175,8 @@ func (a *Authenticator) handleUserCodeConfirmation(w http.ResponseWriter, r *htt } } -// getAccessToken verifies the device code and returns an access token if the request is approved -func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "expected a POST request", http.StatusBadRequest) - return - } - body, err := io.ReadAll(r.Body) - if err != nil { - internalServerError(w, fmt.Errorf("failed to read request body: %w", err)) - return - } - bodyStr := string(body) - values, err := url.ParseQuery(bodyStr) - if err != nil { - internalServerError(w, fmt.Errorf("failed to parse query: %w", err)) - return - } +// getAccessTokenForDeviceCode verifies the device code and returns an access token if the request is approved +func (a *Authenticator) getAccessTokenForDeviceCode(w http.ResponseWriter, r *http.Request, values url.Values) { deviceCode := values.Get("device_code") if deviceCode == "" { http.Error(w, "device_code is required", http.StatusBadRequest) @@ -253,10 +238,10 @@ func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) { return } - resp := deviceauth.OAuthTokenResponse{ + resp := oauth.TokenResponse{ AccessToken: authToken.Token().String(), TokenType: "Bearer", - ExpiresIn: time.UnixMilli(0).Unix(), // never expires + ExpiresIn: 0, // never expires UserID: *authCode.UserID, } respBytes, err := json.Marshal(resp) diff --git a/admin/server/auth/handlers.go b/admin/server/auth/handlers.go index d83804f3ec9..f0d235bf953 100644 --- a/admin/server/auth/handlers.go +++ b/admin/server/auth/handlers.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "errors" "fmt" + "io" "net/http" "net/url" "strconv" @@ -56,7 +57,8 @@ func (a *Authenticator) RegisterEndpoints(mux *http.ServeMux, limiter ratelimit. observability.MuxHandle(inner, "/auth/logout", middleware.Check(checkLimit("/auth/logout"), http.HandlerFunc(a.authLogout))) observability.MuxHandle(inner, "/auth/logout/callback", middleware.Check(checkLimit("/auth/logout/callback"), http.HandlerFunc(a.authLogoutCallback))) observability.MuxHandle(inner, "/auth/oauth/device_authorization", middleware.Check(checkLimit("/auth/oauth/device_authorization"), http.HandlerFunc(a.handleDeviceCodeRequest))) - observability.MuxHandle(inner, "/auth/oauth/device", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/device"), http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware + observability.MuxHandle(inner, "/auth/oauth/device", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/device"), http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware + observability.MuxHandle(inner, "/auth/oauth/authorize", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/authorize"), http.HandlerFunc(a.handleAuthorizeRequest)))) // NOTE: Uses auth middleware observability.MuxHandle(inner, "/auth/oauth/token", middleware.Check(checkLimit("/auth/oauth/token"), http.HandlerFunc(a.getAccessToken))) mux.Handle("/auth/", observability.Middleware("admin", a.logger, inner)) } @@ -355,3 +357,82 @@ func (a *Authenticator) authLogoutCallback(w http.ResponseWriter, r *http.Reques // Redirect to UI (usually) http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) } + +// handleAuthorizeRequest handles the incoming OAuth2 Authorization request, if he user is not logged redirect to login, currently only PKCE based authorization code flow is supported +func (a *Authenticator) handleAuthorizeRequest(w http.ResponseWriter, r *http.Request) { + claims := GetClaims(r.Context()) + if claims == nil { + internalServerError(w, fmt.Errorf("did not find any claims, %w", errors.New("server error"))) + return + } + if claims.OwnerType() == OwnerTypeAnon { + // not logged in, redirect to login + // after login redirect back to same path so encode the current URL as a redirect parameter + encodedURL := url.QueryEscape(r.URL.String()) + http.Redirect(w, r, "/auth/login?redirect="+encodedURL, http.StatusTemporaryRedirect) + } + if claims.OwnerType() != OwnerTypeUser { + http.Error(w, "only users can be authorized", http.StatusBadRequest) + return + } + userID := claims.OwnerID() + + // Extract necessary details from the query parameters + clientID := r.URL.Query().Get("client_id") + redirectURI := r.URL.Query().Get("redirect_uri") + responseType := r.URL.Query().Get("response_type") + + if clientID == "" || redirectURI == "" || responseType == "" { + http.Error(w, "Missing required parameters - client_id or redirect_uri or response_type", http.StatusBadRequest) + return + } + + codeChallenge := r.URL.Query().Get("code_challenge") + codeChallengeMethod := r.URL.Query().Get("code_challenge_method") + + if codeChallenge != "" { + if codeChallengeMethod == "" { + http.Error(w, "Missing code challenge method", http.StatusBadRequest) + return + } + if responseType != "code" { + http.Error(w, "Invalid response type", http.StatusBadRequest) + return + } + a.handlePKCE(w, r, clientID, userID, codeChallenge, codeChallengeMethod, redirectURI) + } else { + http.Error(w, "only PKCE based authorization code flow is supported", http.StatusBadRequest) + return + } +} + +// getAccessToken depending on the grant_type either verifies the device code and returns an access token if the request is approved or exchanges the authorization code for an access token +func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "expected a POST request", http.StatusBadRequest) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + internalServerError(w, fmt.Errorf("failed to read request body: %w", err)) + return + } + bodyStr := string(body) + values, err := url.ParseQuery(bodyStr) + if err != nil { + internalServerError(w, fmt.Errorf("failed to parse query: %w", err)) + return + } + + grantType := values.Get("grant_type") + if !(grantType == deviceCodeGrantType || grantType == authorizationCodeGrantType) { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + if grantType == deviceCodeGrantType { + a.getAccessTokenForDeviceCode(w, r, values) + } else { + a.getAccessTokenForAuthorizationCode(w, r, values) + } +} diff --git a/admin/server/auth/pkce.go b/admin/server/auth/pkce.go new file mode 100644 index 00000000000..e43f02ceb58 --- /dev/null +++ b/admin/server/auth/pkce.go @@ -0,0 +1,175 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "time" + + "github.com/rilldata/rill/admin/database" + "github.com/rilldata/rill/admin/pkg/oauth" +) + +const authorizationCodeGrantType = "authorization_code" + +func (a *Authenticator) handlePKCE(w http.ResponseWriter, r *http.Request, clientID, userID, codeChallenge, codeChallengeMethod, redirectURI string) { + // Generate a unique authorization code + code, err := generateRandomString(16) // 16 bytes, resulting in a 32-character hex string + if err != nil { + http.Error(w, "Failed to generate authorization code", http.StatusInternalServerError) + return + } + + // Set the expiration time for the authorization code to a minute from now. Note from https://www.oauth.com/oauth2-servers/authorization/the-authorization-response/ - + // The authorization code must expire shortly after it is issued. The OAuth 2.0 spec recommends a maximum lifetime of 10 minutes, but in practice, most services set the expiration much shorter, around 30-60 seconds. + expiration := time.Now().Add(1 * time.Minute) + + // Store the authorization code in the database + _, err = a.admin.DB.InsertAuthorizationCode(r.Context(), code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod, expiration) + if err != nil { + internalServerError(w, fmt.Errorf("failed to store authorization code, %w", err)) + return + } + + // Build the redirection URI with the authorization code as per OAuth2 spec, state is URL-encoded + redirectWithCode := fmt.Sprintf("%s?code=%s&state=%s", redirectURI, code, r.URL.Query().Get("state")) + + // Redirect the user agent to the redirect URI with the authorization code + http.Redirect(w, r, redirectWithCode, http.StatusFound) +} + +// getAccessTokenForAuthorizationCode exchanges an authorization code for an access token +func (a *Authenticator) getAccessTokenForAuthorizationCode(w http.ResponseWriter, r *http.Request, values url.Values) { + // Extract the authorization code + code := values.Get("code") + if code == "" { + http.Error(w, "authorization code is required", http.StatusBadRequest) + return + } + + // Extract the client ID + clientID := values.Get("client_id") + if clientID == "" { + http.Error(w, "client ID is required", http.StatusBadRequest) + return + } + + // Extract the redirect URI + redirectURI := values.Get("redirect_uri") + if redirectURI == "" { + http.Error(w, "redirect URI is required", http.StatusBadRequest) + return + } + + // Extract the code verifier + codeVerifier := values.Get("code_verifier") + if codeVerifier == "" { + http.Error(w, "code verifier is required", http.StatusBadRequest) + return + } + + // get the authorization code from the database + authCode, err := a.admin.DB.FindAuthorizationCode(r.Context(), code) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + http.Error(w, "no such authorization code found", http.StatusBadRequest) + } else { + internalServerError(w, fmt.Errorf("failed to get authorization code, %w", err)) + } + return + } + + userID := authCode.UserID + if userID == "" { + http.Error(w, "no user found for authorization code", http.StatusInternalServerError) + return + } + + // remove the authorization code from the database to prevent reuse + err = a.admin.DB.DeleteAuthorizationCode(r.Context(), code) + if err != nil { + internalServerError(w, fmt.Errorf("failed to delete authorization code, %w", err)) + return + } + + // Check if the client ID matches the stored client ID + if authCode.ClientID != clientID { + http.Error(w, "invalid client ID", http.StatusBadRequest) + return + } + + // Check if the redirect URI matches the stored redirect URI + if authCode.RedirectURI != redirectURI { + http.Error(w, "invalid redirect URI", http.StatusBadRequest) + return + } + + // Check if the authorization code has expired + if time.Now().After(authCode.Expiration) { + http.Error(w, "authorization code has expired", http.StatusBadRequest) + return + } + + // Verify the code verifier against the stored code challenge + if !verifyCodeChallenge(codeVerifier, authCode.CodeChallenge, authCode.CodeChallengeMethod) { + http.Error(w, "invalid code verifier", http.StatusBadRequest) + return + } + + // Issue an access token + authToken, err := a.admin.IssueUserAuthToken(r.Context(), userID, authCode.ClientID, "", nil, nil) + if err != nil { + if errors.Is(err, r.Context().Err()) { + http.Error(w, "request cancelled or timeout", http.StatusRequestTimeout) + return + } + internalServerError(w, fmt.Errorf("failed to issue access token, %w", err)) + return + } + + resp := oauth.TokenResponse{ + AccessToken: authToken.Token().String(), + TokenType: "Bearer", + ExpiresIn: 0, // never expires + UserID: userID, + } + respBytes, err := json.Marshal(resp) + if err != nil { + internalServerError(w, fmt.Errorf("failed to marshal response, %w", err)) + return + } + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(respBytes) + if err != nil { + internalServerError(w, fmt.Errorf("failed to write response, %w", err)) + return + } +} + +// verifyCodeChallenge validates the code verifier with the stored code challenge +func verifyCodeChallenge(verifier, challenge, method string) bool { + switch method { + case "S256": + s256 := sha256.Sum256([]byte(verifier)) + computedChallenge := base64.RawURLEncoding.EncodeToString(s256[:]) + return computedChallenge == challenge + default: + return false + } +} + +// Generates a random string for use as the authorization code +func generateRandomString(n int) (string, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} diff --git a/admin/worker/delete_expired_auth_codes.go b/admin/worker/delete_expired_auth_codes.go new file mode 100644 index 00000000000..9e4177fd498 --- /dev/null +++ b/admin/worker/delete_expired_auth_codes.go @@ -0,0 +1,13 @@ +package worker + +import ( + "context" + "time" +) + +func (w *Worker) deleteExpiredAuthCodes(ctx context.Context) error { + // Delete auth codes that have been expired for more than 24 hours. + // By delaying deletion past the expiration time, we can provide a nicer error message for expired codes. + // (The user will see "code has expired" instead of "code not found".) + return w.admin.DB.DeleteExpiredAuthorizationCodes(ctx, 24*time.Hour) +} diff --git a/admin/worker/worker.go b/admin/worker/worker.go index 606754746b4..7370beea082 100644 --- a/admin/worker/worker.go +++ b/admin/worker/worker.go @@ -47,6 +47,9 @@ func (w *Worker) Run(ctx context.Context) error { group.Go(func() error { return w.schedule(ctx, "delete_expired_device_auth_codes", w.deleteExpiredDeviceAuthCodes, 6*time.Hour) }) + group.Go(func() error { + return w.schedule(ctx, "delete_expired_auth_codes", w.deleteExpiredAuthCodes, 6*time.Hour) + }) group.Go(func() error { return w.schedule(ctx, "delete_expired_virtual_files", w.deleteExpiredVirtualFiles, 6*time.Hour) }) diff --git a/cli/cmd/start/start.go b/cli/cmd/start/start.go index 5fcb4ec9aec..41431a68a07 100644 --- a/cli/cmd/start/start.go +++ b/cli/cmd/start/start.go @@ -148,6 +148,7 @@ func StartCmd(ch *cmdutil.Helper) *cobra.Command { Activity: ch.Telemetry(cmd.Context()), AdminURL: ch.AdminURL, AdminToken: ch.AdminToken(), + CMDHelper: ch, }) if err != nil { return err diff --git a/cli/pkg/deviceauth/authenticator.go b/cli/pkg/deviceauth/authenticator.go index 7c732f6bcb0..82a840c6591 100644 --- a/cli/pkg/deviceauth/authenticator.go +++ b/cli/pkg/deviceauth/authenticator.go @@ -15,15 +15,11 @@ import ( "github.com/benbjohnson/clock" "github.com/rilldata/rill/admin/database" + "github.com/rilldata/rill/admin/pkg/oauth" ) // Most parts of this file are copied from https://github.com/planetscale/cli/blob/main/internal/auth/authenticator.go -const ( - formMediaType = "application/x-www-form-urlencoded" - jsonMediaType = "application/json" -) - var ( ErrAuthenticationTimedout = fmt.Errorf("authentication timed out") ErrCodeRejected = fmt.Errorf("confirmation code rejected") @@ -125,7 +121,7 @@ func (d *DeviceAuthenticator) VerifyDevice(ctx context.Context, redirectURL stri } // GetAccessTokenForDevice uses the device verification response to fetch an access token. -func (d *DeviceAuthenticator) GetAccessTokenForDevice(ctx context.Context, v *DeviceVerification) (*OAuthTokenResponse, error) { +func (d *DeviceAuthenticator) GetAccessTokenForDevice(ctx context.Context, v *DeviceVerification) (*oauth.TokenResponse, error) { for { // This loop begins right after we open the user's browser to send an // authentication code. We don't request a token immediately because the @@ -154,15 +150,7 @@ func (d *DeviceAuthenticator) GetAccessTokenForDevice(ctx context.Context, v *De } } -// OAuthTokenResponse contains the information returned after fetching an access token for a device. -type OAuthTokenResponse struct { - AccessToken string `json:"access_token"` - ExpiresIn int64 `json:"expires_in,string"` - TokenType string `json:"token_type"` - UserID string `json:"user_id"` -} - -func (d *DeviceAuthenticator) requestToken(ctx context.Context, deviceCode, clientID string) (*OAuthTokenResponse, error) { +func (d *DeviceAuthenticator) requestToken(ctx context.Context, deviceCode, clientID string) (*oauth.TokenResponse, error) { req, err := d.newFormRequest(ctx, "auth/oauth/token", url.Values{ "grant_type": []string{"urn:ietf:params:oauth:grant-type:device_code"}, "device_code": []string{deviceCode}, @@ -188,7 +176,7 @@ func (d *DeviceAuthenticator) requestToken(ctx context.Context, deviceCode, clie return nil, nil } - tokenRes := &OAuthTokenResponse{} + tokenRes := &oauth.TokenResponse{} err = json.NewDecoder(res.Body).Decode(tokenRes) if err != nil { return nil, fmt.Errorf("error decoding token response: %w", err) @@ -216,8 +204,8 @@ func (d *DeviceAuthenticator) newFormRequest(ctx context.Context, path string, p return nil, err } - req.Header.Set("Content-Type", formMediaType) - req.Header.Set("Accept", jsonMediaType) + req.Header.Set("Content-Type", oauth.FormMediaType) + req.Header.Set("Accept", oauth.JSONMediaType) return req, nil } diff --git a/cli/pkg/local/app.go b/cli/pkg/local/app.go index 90c8bc8a782..a022d037616 100644 --- a/cli/pkg/local/app.go +++ b/cli/pkg/local/app.go @@ -2,7 +2,9 @@ package local import ( "context" + "crypto/rand" "crypto/tls" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -16,9 +18,11 @@ import ( "github.com/bmatcuk/doublestar/v4" "github.com/c2h5oh/datasize" + "github.com/rilldata/rill/admin/database" "github.com/rilldata/rill/cli/pkg/browser" "github.com/rilldata/rill/cli/pkg/cmdutil" "github.com/rilldata/rill/cli/pkg/dotrill" + "github.com/rilldata/rill/cli/pkg/pkce" "github.com/rilldata/rill/cli/pkg/update" "github.com/rilldata/rill/cli/pkg/web" runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1" @@ -73,6 +77,9 @@ type App struct { observabilityShutdown observability.ShutdownFunc loggerCleanUp func() activity *activity.Client + adminURL string + pkceAuthenticators map[string]*pkce.Authenticator // map of state to pkce authenticators + ch *cmdutil.Helper } type AppOptions struct { @@ -89,6 +96,7 @@ type AppOptions struct { Activity *activity.Client AdminURL string AdminToken string + CMDHelper *cmdutil.Helper } func NewApp(ctx context.Context, opts *AppOptions) (*App, error) { @@ -299,6 +307,9 @@ func NewApp(ctx context.Context, opts *AppOptions) (*App, error) { observabilityShutdown: shutdown, loggerCleanUp: cleanupFn, activity: opts.Activity, + adminURL: opts.AdminURL, + pkceAuthenticators: make(map[string]*pkce.Authenticator), + ch: opts.CMDHelper, } // Collect and emit information about connectors at start time @@ -382,6 +393,9 @@ func (a *App) Serve(httpPort, grpcPort int, enableUI, openBrowser, readonly bool return runtimeServer.ServeGRPC(ctx) }) + // if keypath and certpath are provided + secure := tlsCertPath != "" && tlsKeyPath != "" + // Start the local HTTP server group.Go(func() error { return runtimeServer.ServeHTTP(ctx, func(mux *http.ServeMux) { @@ -392,6 +406,8 @@ func (a *App) Serve(httpPort, grpcPort int, enableUI, openBrowser, readonly bool mux.Handle("/local/config", a.infoHandler(inf)) mux.Handle("/local/version", a.versionHandler()) mux.Handle("/local/track", a.trackingHandler()) + mux.Handle("/auth", a.initiateAuthFlow(httpPort, secure)) + mux.Handle("/auth/callback", a.handleAuthCallback()) }) }) @@ -400,9 +416,6 @@ func (a *App) Serve(httpPort, grpcPort int, enableUI, openBrowser, readonly bool group.Go(func() error { return debugserver.ServeHTTP(ctx, 6060) }) } - // if keypath and certpath are provided - secure := tlsCertPath != "" && tlsKeyPath != "" - // Open the browser when health check succeeds go a.pollServer(ctx, httpPort, enableUI && openBrowser, secure) @@ -578,6 +591,86 @@ func (a *App) emitStartEvent(ctx context.Context) error { return nil } +// initiateAuthFlow starts the OAuth2 PKCE flow to authenticate the user and get a rill access token. +func (a *App) initiateAuthFlow(httpPort int, secure bool) http.Handler { + scheme := "http" + if secure { + scheme = "https" + } + redirectURL := fmt.Sprintf("%s://localhost:%d/auth/callback", scheme, httpPort) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // generate random state + b := make([]byte, 32) + _, err := rand.Read(b) + if err != nil { + http.Error(w, fmt.Sprintf("failed to generate state: %s", err), http.StatusInternalServerError) + return + } + state := base64.URLEncoding.EncodeToString(b) + + // check the request for redirect query param, we will use this to redirect back to this after auth + origin := r.URL.Query().Get("redirect") + if origin == "" { + origin = "/" + } + + authenticator, err := pkce.NewAuthenticator(a.adminURL, redirectURL, database.AuthClientIDRillWebLocal, origin) + if err != nil { + http.Error(w, fmt.Sprintf("failed to generate pkce authenticator: %s", err), http.StatusInternalServerError) + return + } + a.pkceAuthenticators[state] = authenticator + authURL := authenticator.GetAuthURL(state) + http.Redirect(w, r, authURL, http.StatusFound) + }) +} + +// handleAuthCallback handles the OAuth2 PKCE callback to exchange the authorization code for a rill access token. +func (a *App) handleAuthCallback() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get("code") + if code == "" { + http.Error(w, "missing code", http.StatusBadRequest) + return + } + state := r.URL.Query().Get("state") + if code == "" { + http.Error(w, "missing state", http.StatusBadRequest) + return + } + + authenticator, ok := a.pkceAuthenticators[state] + if !ok { + http.Error(w, "invalid state", http.StatusBadRequest) + return + } + + // remove authenticator from map + delete(a.pkceAuthenticators, state) + + if authenticator == nil { + http.Error(w, "failed to get authenticator", http.StatusInternalServerError) + return + } + + // Exchange the code for an access token + token, err := authenticator.ExchangeCodeForToken(code) + if err != nil { + http.Error(w, "failed to exchange code for token", http.StatusInternalServerError) + return + } + // save token and redirect back to url provided by caller when initiating auth flow + err = dotrill.SetAccessToken(token) + if err != nil { + http.Error(w, "failed to save access token", http.StatusInternalServerError) + return + } + a.ch.AdminTokenDefault = token + http.Redirect(w, r, authenticator.OriginURL, http.StatusFound) + }) +} + // IsProjectInit checks if the project is initialized by checking if rill.yaml exists in the project directory. // It doesn't use any runtime functions since we need the ability to check this before creating the instance. func IsProjectInit(projectPath string) bool { diff --git a/cli/pkg/pkce/authenticator.go b/cli/pkg/pkce/authenticator.go new file mode 100644 index 00000000000..24c85040dff --- /dev/null +++ b/cli/pkg/pkce/authenticator.go @@ -0,0 +1,177 @@ +package pkce + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net/http" + "net/url" + "strings" + + "github.com/rilldata/rill/admin/pkg/oauth" +) + +const ( + // characters allowed in the PKCE code verifier + charset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" + codeChallengeMethod = "S256" +) + +type Authenticator struct { + client *http.Client + baseAuthURL string + redirectURL string + codeVerifier string + clientID string + OriginURL string +} + +func NewAuthenticator(baseAuthURL, redirectURL, clientID, origin string) (*Authenticator, error) { + if strings.Contains(baseAuthURL, "http://localhost:9090") { + baseAuthURL = "http://localhost:8080" + } + + // Generate a new code verifier + codeVerifier, err := generateCodeVerifier() + if err != nil { + return nil, err + } + + return &Authenticator{ + client: http.DefaultClient, + baseAuthURL: baseAuthURL, + redirectURL: redirectURL, + codeVerifier: codeVerifier, + clientID: clientID, + OriginURL: origin, + }, nil +} + +func (a *Authenticator) GetAuthURL(state string) string { + // Create the code challenge from the code verifier + codeChallenge := createCodeChallenge(a.codeVerifier) + // Create the authorization request URL + // Create a new URL instance from the authURL string + u, _ := url.Parse(a.baseAuthURL + "/auth/oauth/authorize") + + // Create a new query string from the URL's query + q := u.Query() + + // Set the client_id query parameter + q.Set("client_id", a.clientID) + // Set the redirect_uri query parameter + q.Set("redirect_uri", a.redirectURL) + // Set the response_type query parameter + q.Set("response_type", "code") + // Set the code_challenge query parameter + q.Set("code_challenge", codeChallenge) + // Set the code_challenge_method query parameter + q.Set("code_challenge_method", codeChallengeMethod) + // Set the state, will be used later to retrieve this authenticator + q.Set("state", state) + + // Encode the query string + u.RawQuery = q.Encode() + + // Return the URL as a string + return u.String() +} + +func (a *Authenticator) ExchangeCodeForToken(code string) (string, error) { + // Create the token request + req, err := tokenRequest(a.baseAuthURL, code, a.clientID, a.redirectURL, a.codeVerifier) + if err != nil { + return "", err + } + + // Send the token request + resp, err := a.client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + // Check if the response is an error + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + tokenResponse := &oauth.TokenResponse{} + // Decode the response into the tokenResponse struct + if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil { + return "", err + } + + // Return the access token + return tokenResponse.AccessToken, nil +} + +func tokenRequest(baseAuthURL, code, clientID, redirectURI, codeVerifier string) (*http.Request, error) { + tokenURL := fmt.Sprintf("%s/auth/oauth/token", baseAuthURL) + payload := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "client_id": []string{clientID}, + "redirect_uri": []string{redirectURI}, + "code_verifier": []string{codeVerifier}, + } + req, err := http.NewRequest( + http.MethodPost, + tokenURL, + strings.NewReader(payload.Encode()), + ) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", oauth.FormMediaType) + req.Header.Set("Accept", oauth.JSONMediaType) + return req, nil +} + +// generateCodeVerifier creates a cryptographically secure random string +// which is between 43 and 128 characters long using the specified charset. +func generateCodeVerifier() (string, error) { + // Generate a random number between 0 and 85 to extend the length of the code verifier + r, err := rand.Int(rand.Reader, big.NewInt(86)) + if err != nil { + return "", err + } + // Define the length of the code verifier + // Here, we randomly choose a length between 43 and 128 characters + n := 43 + int(r.Int64()) + + // Create a byte slice of length n to store the characters of our code verifier + b := make([]byte, n) + // Temp slice to read random numbers into + temp := make([]byte, n) + if _, err := rand.Read(temp); err != nil { + return "", err + } + + // Assign a valid character from charset for each byte in b + for i := 0; i < n; i++ { + b[i] = charset[temp[i]%byte(len(charset))] + } + + return string(b), nil +} + +// createCodeChallenge takes a codeVerifier and returns its SHA256 hash +// encoded in Base64 URL encoding without padding, which is the code challenge. +func createCodeChallenge(codeVerifier string) string { + // Create a new SHA256 hash instance + hasher := sha256.New() + + // Write the codeVerifier to the hasher + hasher.Write([]byte(codeVerifier)) + + // Compute the SHA256 hash + hash := hasher.Sum(nil) + + // Base64 URL encode the hash + return base64.RawURLEncoding.EncodeToString(hash) +} diff --git a/cli/pkg/pkce/authenticator_test.go b/cli/pkg/pkce/authenticator_test.go new file mode 100644 index 00000000000..ff232661667 --- /dev/null +++ b/cli/pkg/pkce/authenticator_test.go @@ -0,0 +1,20 @@ +package pkce + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func Test_generateCodeVerifier(t *testing.T) { + for i := 0; i < 1000; i++ { + code, err := generateCodeVerifier() + require.NoError(t, err) + require.NotEmpty(t, code) + require.GreaterOrEqual(t, len(code), 43) + require.LessOrEqual(t, len(code), 128) + // only contains A-Z, a-z, 0-9, and the punctuation characters -._~ (hyphen, period, underscore, and tilde) + for _, c := range code { + require.Contains(t, charset, string(c)) + } + } +}