Skip to content

Commit

Permalink
PKCE authentication support for the local UI (#4843)
Browse files Browse the repository at this point in the history
* runtime local auth flow

* remove print stmt

* review comments

* lint

* review comments
  • Loading branch information
pjain1 authored May 9, 2024
1 parent 1db9449 commit 38796b4
Show file tree
Hide file tree
Showing 14 changed files with 665 additions and 45 deletions.
26 changes: 23 additions & 3 deletions admin/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ type DB interface {
UpdateDeviceAuthCode(ctx context.Context, id, userID string, state DeviceAuthCodeState) error
DeleteExpiredDeviceAuthCodes(ctx context.Context, retention time.Duration) error

FindAuthorizationCode(ctx context.Context, code string) (*AuthorizationCode, error)
InsertAuthorizationCode(ctx context.Context, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod string, expiration time.Time) (*AuthorizationCode, error)
DeleteAuthorizationCode(ctx context.Context, code string) error
DeleteExpiredAuthorizationCodes(ctx context.Context, retention time.Duration) error

FindOrganizationRole(ctx context.Context, name string) (*OrganizationRole, error)
FindProjectRole(ctx context.Context, name string) (*ProjectRole, error)
ResolveOrganizationRolesForUser(ctx context.Context, userID, orgID string) ([]*OrganizationRole, error)
Expand Down Expand Up @@ -512,9 +517,10 @@ type AuthClient struct {

// Hard-coded auth client IDs (created in the migrations).
const (
AuthClientIDRillWeb = "12345678-0000-0000-0000-000000000001"
AuthClientIDRillCLI = "12345678-0000-0000-0000-000000000002"
AuthClientIDRillSupport = "12345678-0000-0000-0000-000000000003"
AuthClientIDRillWeb = "12345678-0000-0000-0000-000000000001"
AuthClientIDRillCLI = "12345678-0000-0000-0000-000000000002"
AuthClientIDRillSupport = "12345678-0000-0000-0000-000000000003"
AuthClientIDRillWebLocal = "12345678-0000-0000-0000-000000000004"
)

// DeviceAuthCodeState is an enum representing the approval state of a DeviceAuthCode
Expand All @@ -540,6 +546,20 @@ type DeviceAuthCode struct {
UpdatedOn time.Time `db:"updated_on"`
}

// AuthorizationCode represents an authorization code used for OAuth2 PKCE auth flow.
type AuthorizationCode struct {
ID string `db:"id"`
Code string `db:"code"`
UserID string `db:"user_id"`
ClientID string `db:"client_id"`
RedirectURI string `db:"redirect_uri"`
CodeChallenge string `db:"code_challenge"`
CodeChallengeMethod string `db:"code_challenge_method"`
Expiration time.Time `db:"expires_on"`
CreatedOn time.Time `db:"created_on"`
UpdatedOn time.Time `db:"updated_on"`
}

// Constants for known role names (created in migrations).
const (
OrganizationRoleNameAdmin = "admin"
Expand Down
20 changes: 20 additions & 0 deletions admin/database/postgres/migrations/0028.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- Hard-coded first-party auth clients
INSERT INTO auth_clients (id, display_name)
VALUES ('12345678-0000-0000-0000-000000000004', 'Rill Localhost');

-- Table for storing authorization codes for PKCE auth flow
CREATE TABLE authorization_codes (
id UUID DEFAULT uuid_generate_v4() PRIMARY KEY,
code TEXT NOT NULL,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
client_id UUID NOT NULL REFERENCES auth_clients(id) ON DELETE CASCADE,
redirect_uri TEXT NOT NULL,
code_challenge TEXT NOT NULL,
code_challenge_method TEXT NOT NULL,
expires_on TIMESTAMP NOT NULL,
created_on TIMESTAMPTZ DEFAULT now() NOT NULL,
updated_on TIMESTAMPTZ DEFAULT now() NOT NULL
);

-- create index on code column
CREATE UNIQUE INDEX authorization_codes_code_idx ON authorization_codes(code);
30 changes: 30 additions & 0 deletions admin/database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,36 @@ func (c *connection) DeleteExpiredDeviceAuthCodes(ctx context.Context, retention
return parseErr("device auth code", err)
}

func (c *connection) FindAuthorizationCode(ctx context.Context, code string) (*database.AuthorizationCode, error) {
authCode := &database.AuthorizationCode{}
err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT * FROM authorization_codes WHERE code = $1", code).StructScan(authCode)
if err != nil {
return nil, parseErr("authorization code", err)
}
return authCode, nil
}

func (c *connection) InsertAuthorizationCode(ctx context.Context, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod string, expiration time.Time) (*database.AuthorizationCode, error) {
res := &database.AuthorizationCode{}
err := c.getDB(ctx).QueryRowxContext(ctx,
`INSERT INTO authorization_codes (code, user_id, client_id, redirect_uri, code_challenge, code_challenge_method, expires_on)
VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING *`, code, userID, clientID, redirectURI, codeChallenge, codeChallengeMethod, expiration).StructScan(res)
if err != nil {
return nil, parseErr("authorization code", err)
}
return res, nil
}

func (c *connection) DeleteAuthorizationCode(ctx context.Context, code string) error {
res, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM authorization_codes WHERE code=$1", code)
return checkDeleteRow("authorization code", res, err)
}

func (c *connection) DeleteExpiredAuthorizationCodes(ctx context.Context, retention time.Duration) error {
_, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM authorization_codes WHERE expires_on + $1 < now()", retention)
return parseErr("authorization code", err)
}

func (c *connection) FindOrganizationRole(ctx context.Context, name string) (*database.OrganizationRole, error) {
role := &database.OrganizationRole{}
err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT * FROM org_roles WHERE lower(name)=lower($1)", name).StructScan(role)
Expand Down
14 changes: 14 additions & 0 deletions admin/pkg/oauth/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package oauth

const (
FormMediaType = "application/x-www-form-urlencoded"
JSONMediaType = "application/json"
)

// TokenResponse contains the information returned after fetching an access token from the OAuth server.
type TokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in,string"`
TokenType string `json:"token_type"`
UserID string `json:"user_id"`
}
25 changes: 5 additions & 20 deletions admin/server/auth/device_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (

"github.com/rilldata/rill/admin"
"github.com/rilldata/rill/admin/database"
"github.com/rilldata/rill/admin/pkg/oauth"
"github.com/rilldata/rill/admin/pkg/urlutil"
"github.com/rilldata/rill/cli/pkg/deviceauth"
)

const deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code"
Expand Down Expand Up @@ -175,23 +175,8 @@ func (a *Authenticator) handleUserCodeConfirmation(w http.ResponseWriter, r *htt
}
}

// getAccessToken verifies the device code and returns an access token if the request is approved
func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "expected a POST request", http.StatusBadRequest)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
internalServerError(w, fmt.Errorf("failed to read request body: %w", err))
return
}
bodyStr := string(body)
values, err := url.ParseQuery(bodyStr)
if err != nil {
internalServerError(w, fmt.Errorf("failed to parse query: %w", err))
return
}
// getAccessTokenForDeviceCode verifies the device code and returns an access token if the request is approved
func (a *Authenticator) getAccessTokenForDeviceCode(w http.ResponseWriter, r *http.Request, values url.Values) {
deviceCode := values.Get("device_code")
if deviceCode == "" {
http.Error(w, "device_code is required", http.StatusBadRequest)
Expand Down Expand Up @@ -253,10 +238,10 @@ func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
return
}

resp := deviceauth.OAuthTokenResponse{
resp := oauth.TokenResponse{
AccessToken: authToken.Token().String(),
TokenType: "Bearer",
ExpiresIn: time.UnixMilli(0).Unix(), // never expires
ExpiresIn: 0, // never expires
UserID: *authCode.UserID,
}
respBytes, err := json.Marshal(resp)
Expand Down
83 changes: 82 additions & 1 deletion admin/server/auth/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -56,7 +57,8 @@ func (a *Authenticator) RegisterEndpoints(mux *http.ServeMux, limiter ratelimit.
observability.MuxHandle(inner, "/auth/logout", middleware.Check(checkLimit("/auth/logout"), http.HandlerFunc(a.authLogout)))
observability.MuxHandle(inner, "/auth/logout/callback", middleware.Check(checkLimit("/auth/logout/callback"), http.HandlerFunc(a.authLogoutCallback)))
observability.MuxHandle(inner, "/auth/oauth/device_authorization", middleware.Check(checkLimit("/auth/oauth/device_authorization"), http.HandlerFunc(a.handleDeviceCodeRequest)))
observability.MuxHandle(inner, "/auth/oauth/device", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/device"), http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/device", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/device"), http.HandlerFunc(a.handleUserCodeConfirmation)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/authorize", a.HTTPMiddleware(middleware.Check(checkLimit("/auth/oauth/authorize"), http.HandlerFunc(a.handleAuthorizeRequest)))) // NOTE: Uses auth middleware
observability.MuxHandle(inner, "/auth/oauth/token", middleware.Check(checkLimit("/auth/oauth/token"), http.HandlerFunc(a.getAccessToken)))
mux.Handle("/auth/", observability.Middleware("admin", a.logger, inner))
}
Expand Down Expand Up @@ -355,3 +357,82 @@ func (a *Authenticator) authLogoutCallback(w http.ResponseWriter, r *http.Reques
// Redirect to UI (usually)
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}

// handleAuthorizeRequest handles the incoming OAuth2 Authorization request, if he user is not logged redirect to login, currently only PKCE based authorization code flow is supported
func (a *Authenticator) handleAuthorizeRequest(w http.ResponseWriter, r *http.Request) {
claims := GetClaims(r.Context())
if claims == nil {
internalServerError(w, fmt.Errorf("did not find any claims, %w", errors.New("server error")))
return
}
if claims.OwnerType() == OwnerTypeAnon {
// not logged in, redirect to login
// after login redirect back to same path so encode the current URL as a redirect parameter
encodedURL := url.QueryEscape(r.URL.String())
http.Redirect(w, r, "/auth/login?redirect="+encodedURL, http.StatusTemporaryRedirect)
}
if claims.OwnerType() != OwnerTypeUser {
http.Error(w, "only users can be authorized", http.StatusBadRequest)
return
}
userID := claims.OwnerID()

// Extract necessary details from the query parameters
clientID := r.URL.Query().Get("client_id")
redirectURI := r.URL.Query().Get("redirect_uri")
responseType := r.URL.Query().Get("response_type")

if clientID == "" || redirectURI == "" || responseType == "" {
http.Error(w, "Missing required parameters - client_id or redirect_uri or response_type", http.StatusBadRequest)
return
}

codeChallenge := r.URL.Query().Get("code_challenge")
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")

if codeChallenge != "" {
if codeChallengeMethod == "" {
http.Error(w, "Missing code challenge method", http.StatusBadRequest)
return
}
if responseType != "code" {
http.Error(w, "Invalid response type", http.StatusBadRequest)
return
}
a.handlePKCE(w, r, clientID, userID, codeChallenge, codeChallengeMethod, redirectURI)
} else {
http.Error(w, "only PKCE based authorization code flow is supported", http.StatusBadRequest)
return
}
}

// getAccessToken depending on the grant_type either verifies the device code and returns an access token if the request is approved or exchanges the authorization code for an access token
func (a *Authenticator) getAccessToken(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "expected a POST request", http.StatusBadRequest)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
internalServerError(w, fmt.Errorf("failed to read request body: %w", err))
return
}
bodyStr := string(body)
values, err := url.ParseQuery(bodyStr)
if err != nil {
internalServerError(w, fmt.Errorf("failed to parse query: %w", err))
return
}

grantType := values.Get("grant_type")
if !(grantType == deviceCodeGrantType || grantType == authorizationCodeGrantType) {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}

if grantType == deviceCodeGrantType {
a.getAccessTokenForDeviceCode(w, r, values)
} else {
a.getAccessTokenForAuthorizationCode(w, r, values)
}
}
Loading

0 comments on commit 38796b4

Please sign in to comment.