Skip to content

Commit

Permalink
Implementation of PKCE (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto authored Feb 28, 2022
1 parent 482fc19 commit 7433da3
Show file tree
Hide file tree
Showing 9 changed files with 351 additions and 59 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ A login form is available on http://localhost:8008/login.

## (To be) Implemented Standards

* [RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749). The basics
* [RFC 6750](https://datatracker.ietf.org/doc/html/rfc6750). We are exclusively using JWTs as bearer tokens
* [RFC 7517](https://datatracker.ietf.org/doc/html/rfc7517). JSON Web Key
* [RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749). The OAuth 2.0 Authorization Framework
* [RFC 6750](https://datatracker.ietf.org/doc/html/rfc6750). The OAuth 2.0 Authorization Framework: Bearer Token Usage
* [RFC 7517](https://datatracker.ietf.org/doc/html/rfc7517). JSON Web Key (JWK)
* [RFC 7636](https://datatracker.ietf.org/doc/html/rfc7636). Proof Key for Code Exchange by OAuth Public Clients
6 changes: 6 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ type Client struct {

RedirectURI string
}

// Public returns true if this client is a public client. We enforce
// certain additional security measures for public clients, for example PKCE.
func (c *Client) Public() bool {
return c.ClientSecret == ""
}
9 changes: 6 additions & 3 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
)

var port = flag.Int("port", 8080, "the default port")
var redirectURI = flag.String("redirect-uri", "http://localhost", "the default redirect URI")
var srv *oauth2.AuthorizationServer
var ctx func(net.Listener) context.Context = nil

Expand All @@ -22,16 +23,18 @@ func main() {
clientPassword := oauth2.GenerateSecret()

log.Printf(`Creating new user "admin" with password %s`, userPassword)
log.Printf(`Creating new client "client" with password %s`, clientPassword)
log.Printf(`Creating new confidential client "client" with password %s`, clientPassword)
log.Printf(`Creating new public client "public"`)

srv = oauth2.NewServer(
fmt.Sprintf(":%d", *port),
oauth2.WithClient("client", clientPassword, ""),
oauth2.WithClient("client", clientPassword, *redirectURI),
oauth2.WithClient("public", "", *redirectURI),
login.WithLoginPage(login.WithUser("admin", userPassword)),
)
srv.BaseContext = ctx

log.Printf("Creating new OAuth 2.0 server on :%d", *port)
log.Printf("Starting new OAuth 2.0 server on :%d", *port)

log.Fatal(srv.ListenAndServe())
}
3 changes: 3 additions & 0 deletions compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ type TokenSource = oauth2.TokenSource

// Transport is a type alias for https://pkg.go.dev/golang.org/x/oauth2#Transport.
type Transport = oauth2.Transport

// SetAuthURLParam is a function alias for https://pkg.go.dev/golang.org/x/oauth2#SetAuthURLParam.
var SetAuthURLParam = oauth2.SetAuthURLParam
130 changes: 127 additions & 3 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package oauth2_test

import (
"context"
"crypto/sha256"
"encoding/base64"
"fmt"
"log"
"net"
Expand Down Expand Up @@ -55,7 +57,120 @@ func TestIntegration(t *testing.T) {
log.Printf("JWT: %+v", jwtoken)
}

func TestThreeLeggedFlow(t *testing.T) {
func TestThreeLeggedFlowPublicClient(t *testing.T) {
var (
res *http.Response
req *http.Request
client *http.Client
form url.Values
session *http.Cookie
token *oauth2.Token
code string
challenge string
verifier string
)

srv := oauth2.NewServer(":0",
oauth2.WithClient("public", "", "/test"),
login.WithLoginPage(login.WithUser("admin", "admin")),
)

ln, err := net.Listen("tcp", srv.Addr)
if err != nil {
t.Errorf("Error while listening key: %v", err)
}

port := ln.Addr().(*net.TCPAddr).Port

go srv.Serve(ln)
defer srv.Close()

config := oauth2.Config{
ClientID: "public",
ClientSecret: "",
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("http://localhost:%d/authorize", port),
TokenURL: fmt.Sprintf("http://localhost:%d/token", port),
},
RedirectURL: "/test",
}

// create a challenge and verifier
verifier = "012345678901234567890123456789012345678901234567890123456789"
challenge = base64.URLEncoding.EncodeToString(sha256.New().Sum([]byte(verifier)))

// Let's pretend to be a browser
res, err = http.Get(config.AuthCodeURL("some-state",
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
))
if err != nil {
t.Errorf("Error while POST /authorize: %v", err)
}

// We are interested in two things
// - The session ID (or the cookie)
// - The CSRF token
for _, c := range res.Cookies() {
if c.Name == "id" {
session = c
break
}
}

if session == nil {
t.Errorf("Error session is nil")
}

// Parse the HTML body to look for the csrf_token
root, _ := html.Parse(res.Body)

form = url.Values{}
walker := func(node *html.Node) {
if node.Type == html.ElementNode &&
node.Data == "input" &&
len(node.Attr) == 3 {
form.Add(node.Attr[1].Val, node.Attr[2].Val)
}
}

traverse(root, walker)

form.Add("username", "admin")
form.Add("password", "admin")

req, _ = http.NewRequest("POST", fmt.Sprintf("http://localhost:%d/login", port), strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(session)

// Let's POST our login
client = &http.Client{}
res, err = client.Do(req)
if err != nil {
t.Errorf("Error while POST /login: %v", err)
}

// Extract the code from the response
code = res.Request.URL.Query().Get("code")

token, err = config.Exchange(context.Background(),
code,
oauth2.SetAuthURLParam("code_verifier", verifier),
)
if err != nil {
t.Errorf("Error while Exchange: %v", err)
}

if token.AccessToken == "" {
t.Error("Access token is empty", err)
}

if token.RefreshToken == "" {
t.Error("Access token is empty", err)
}
}

func TestThreeLeggedFlowConfidentialClient(t *testing.T) {
var (
res *http.Response
req *http.Request
Expand Down Expand Up @@ -92,7 +207,10 @@ func TestThreeLeggedFlow(t *testing.T) {
}

// Let's pretend to be a browser
res, _ = http.Get(config.AuthCodeURL("some-state"))
res, err = http.Get(config.AuthCodeURL("some-state"))
if err != nil {
t.Errorf("Error while POST /authorize: %v", err)
}

// We are interested in two things
// - The session ID (or the cookie)
Expand All @@ -104,6 +222,10 @@ func TestThreeLeggedFlow(t *testing.T) {
}
}

if session == nil {
t.Errorf("Error session is nil")
}

// Parse the HTML body to look for the csrf_token
root, _ := html.Parse(res.Body)

Expand Down Expand Up @@ -135,7 +257,9 @@ func TestThreeLeggedFlow(t *testing.T) {
// Extract the code from the response
code = res.Request.URL.Query().Get("code")

token, err = config.Exchange(context.Background(), code)
token, err = config.Exchange(context.Background(),
code,
)
if err != nil {
t.Errorf("Error while Exchange: %v", err)
}
Expand Down
18 changes: 16 additions & 2 deletions login/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ func (h *handler) handleAuthorize(w http.ResponseWriter, r *http.Request) {
client *oauth2.Client
redirectURI string
state string
challenge string
method string
err error
query url.Values
session *session
Expand All @@ -34,7 +36,19 @@ func (h *handler) handleAuthorize(w http.ResponseWriter, r *http.Request) {
}

if query.Get("response_type") != "code" {
oauth2.RedirectError(w, r, redirectURI, "invalid_request")
oauth2.RedirectError(w, r, redirectURI, "invalid_request", "")
return
}

challenge = query.Get("code_challenge")
if client.Public() && challenge == "" {
oauth2.RedirectError(w, r, redirectURI, "invalid_request", "Code challenge is required")
return
}

method = query.Get("code_challenge_method")
if client.Public() && method != "S256" {
oauth2.RedirectError(w, r, redirectURI, "invalid_request", "Only transform algorithm S265 is supported")
return
}

Expand All @@ -51,7 +65,7 @@ func (h *handler) handleAuthorize(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, fmt.Sprintf("/login?%s", params.Encode()), http.StatusFound)
} else {
var params = url.Values{}
params.Add("code", h.srv.IssueCode())
params.Add("code", h.srv.IssueCode(challenge))
params.Add("state", state)

http.Redirect(w, r, fmt.Sprintf("%s?%s", redirectURI, params.Encode()), http.StatusFound)
Expand Down
40 changes: 37 additions & 3 deletions login/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,40 @@ func Test_handler_handleAuthorize(t *testing.T) {
"Location": []string{"/test\\?error=invalid_request"},
},
},
{
name: "missing code challenge for public client",
fields: fields{
sessions: map[string]*session{},
users: []*User{},
log: log.Default(),
pwh: bcryptHasher{},
srv: oauth2.NewServer(":0", oauth2.WithClient("public", "", "/test")),
},
args: args{
r: httptest.NewRequest("GET", "/authorize?client_id=public&redirect_uri=/test&response_type=code", nil),
},
wantCode: http.StatusFound,
wantHeaderRegexp: http.Header{
"Location": []string{"/test\\?error=invalid_request&error_description=Code\\+challenge\\+is\\+required"},
},
},
{
name: "invalid code challenge method for public client",
fields: fields{
sessions: map[string]*session{},
users: []*User{},
log: log.Default(),
pwh: bcryptHasher{},
srv: oauth2.NewServer(":0", oauth2.WithClient("public", "", "/test")),
},
args: args{
r: httptest.NewRequest("GET", "/authorize?client_id=public&redirect_uri=/test&response_type=code&code_challenge=0123456789&code_challenge_method=WHAT", nil),
},
wantCode: http.StatusFound,
wantHeaderRegexp: http.Header{
"Location": []string{"/test\\?error=invalid_request&error_description=Only\\+transform\\+algorithm\\+S265\\+is\\+supported"},
},
},
{
name: "valid request, no session",
fields: fields{
Expand All @@ -128,12 +162,12 @@ func Test_handler_handleAuthorize(t *testing.T) {
srv: oauth2.NewServer(":0", oauth2.WithClient("client", "secret", "/test")),
},
args: args{
r: httptest.NewRequest("GET", "/authorize?client_id=client&redirect_uri=/test&response_type=code", nil),
r: httptest.NewRequest("GET", "/authorize?client_id=client&redirect_uri=/test&response_type=code&code_challenge=0123456789&code_challenge_method=S256", nil),
},
wantCode: http.StatusFound,
wantHeaderRegexp: http.Header{
// Should redirect to login page but with this authorize endpoint as return URL
"Location": []string{"/login\\?return_url=%2Fauthorize%3Fclient_id%3Dclient%26redirect_uri%3D%2Ftest%26response_type%3Dcode"},
"Location": []string{"/login\\?return_url=%2Fauthorize%3Fclient_id%3Dclient%26redirect_uri%3D%2Ftest%26response_type%3Dcode%26code_challenge%3D0123456789%26code_challenge_method%3DS256"},
},
},
{
Expand All @@ -156,7 +190,7 @@ func Test_handler_handleAuthorize(t *testing.T) {
},
args: args{
r: func() *http.Request {
r := httptest.NewRequest("GET", "/authorize?client_id=client&redirect_uri=/test&response_type=code", nil)
r := httptest.NewRequest("GET", "/authorize?client_id=client&redirect_uri=/test&response_type=code&code_challenge=0123456789&code_challenge_method=S256", nil)
r.AddCookie(&http.Cookie{
Name: "id",
Value: "mysession",
Expand Down
Loading

0 comments on commit 7433da3

Please sign in to comment.