diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 92a09abf2..d9d23cecd 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -103,21 +103,21 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut switch t.Method.(type) { case *jwt.SigningMethodRSA: - key, err := f.findClientPublicJWK(oidcClient, t) + key, err := f.findClientPublicJWK(oidcClient, t, true) if err != nil { - return nil, errors.WithStack(ErrInvalidRequestObject.WithHintf("Unable to retrieve signing key from OAuth 2.0 Client because %s.", err)) + return nil, errors.WithStack(ErrInvalidRequestObject.WithHintf("Unable to retrieve RSA signing key from OAuth 2.0 Client because %s.", err)) } return key, nil case *jwt.SigningMethodECDSA: - key, err := f.findClientPublicJWK(oidcClient, t) + key, err := f.findClientPublicJWK(oidcClient, t, false) if err != nil { - return nil, errors.WithStack(ErrInvalidRequestObject.WithHintf("Unable to retrieve signing key from OAuth 2.0 Client because %s.", err)) + return nil, errors.WithStack(ErrInvalidRequestObject.WithHintf("Unable to retrieve ECDSA signing key from OAuth 2.0 Client because %s.", err)) } return key, nil case *jwt.SigningMethodRSAPSS: - key, err := f.findClientPublicJWK(oidcClient, t) + key, err := f.findClientPublicJWK(oidcClient, t, true) if err != nil { - return nil, errors.WithStack(ErrInvalidRequestObject.WithHintf("Unable to retrieve signing key from OAuth 2.0 Client because %s.", err)) + return nil, errors.WithStack(ErrInvalidRequestObject.WithHintf("Unable to retrieve RSA signing key from OAuth 2.0 Client because %s.", err)) } return key, nil default: diff --git a/client.go b/client.go index 61d3d17cf..9a4ee32e1 100644 --- a/client.go +++ b/client.go @@ -94,11 +94,12 @@ type DefaultClient struct { type DefaultOpenIDConnectClient struct { *DefaultClient - JSONWebKeysURI string `json:"jwks_uri"` - JSONWebKeys *jose.JSONWebKeySet `json:"jwks"` - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` - RequestURIs []string `json:"request_uris"` - RequestObjectSigningAlgorithm string `json:"request_object_signing_alg"` + JSONWebKeysURI string `json:"jwks_uri"` + JSONWebKeys *jose.JSONWebKeySet `json:"jwks"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + RequestURIs []string `json:"request_uris"` + RequestObjectSigningAlgorithm string `json:"request_object_signing_alg"` + TokenEndpointAuthSigningAlgorithm string `json:"token_endpoint_auth_signing_alg"` } func (c *DefaultClient) GetID() string { @@ -158,7 +159,11 @@ func (c *DefaultOpenIDConnectClient) GetJSONWebKeys() *jose.JSONWebKeySet { } func (c *DefaultOpenIDConnectClient) GetTokenEndpointAuthSigningAlgorithm() string { - return "RS256" + if c.TokenEndpointAuthSigningAlgorithm == "" { + return "RS256" + } else { + return c.TokenEndpointAuthSigningAlgorithm + } } func (c *DefaultOpenIDConnectClient) GetRequestObjectSigningAlgorithm() string { diff --git a/client_authentication.go b/client_authentication.go index 1516f8ed6..7bff15ac3 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -23,6 +23,7 @@ package fosite import ( "context" + "crypto/ecdsa" "crypto/rsa" "encoding/json" "fmt" @@ -37,9 +38,9 @@ import ( const clientAssertionJWTBearerType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" -func (f *Fosite) findClientPublicJWK(oidcClient OpenIDConnectClient, t *jwt.Token) (interface{}, error) { +func (f *Fosite) findClientPublicJWK(oidcClient OpenIDConnectClient, t *jwt.Token, expectsRSAKey bool) (interface{}, error) { if set := oidcClient.GetJSONWebKeys(); set != nil { - return findPublicKey(t, set) + return findPublicKey(t, set, expectsRSAKey) } if location := oidcClient.GetJSONWebKeysURI(); len(location) > 0 { @@ -48,7 +49,7 @@ func (f *Fosite) findClientPublicJWK(oidcClient OpenIDConnectClient, t *jwt.Toke return nil, err } - if key, err := findPublicKey(t, keys); err == nil { + if key, err := findPublicKey(t, keys, expectsRSAKey); err == nil { return key, nil } @@ -57,7 +58,7 @@ func (f *Fosite) findClientPublicJWK(oidcClient OpenIDConnectClient, t *jwt.Toke return nil, err } - return findPublicKey(t, keys) + return findPublicKey(t, keys, expectsRSAKey) } return nil, errors.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) @@ -115,11 +116,11 @@ func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form u } if _, ok := t.Method.(*jwt.SigningMethodRSA); ok { - return f.findClientPublicJWK(oidcClient, t) + return f.findClientPublicJWK(oidcClient, t, true) } else if _, ok := t.Method.(*jwt.SigningMethodECDSA); ok { - return f.findClientPublicJWK(oidcClient, t) + return f.findClientPublicJWK(oidcClient, t, false) } else if _, ok := t.Method.(*jwt.SigningMethodRSAPSS); ok { - return f.findClientPublicJWK(oidcClient, t) + return f.findClientPublicJWK(oidcClient, t, true) } else if _, ok := t.Method.(*jwt.SigningMethodHMAC); ok { return nil, errors.WithStack(ErrInvalidClient.WithHint("This authorization server does not support client authentication method \"client_secret_jwt\".")) } @@ -231,7 +232,7 @@ func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form u return client, nil } -func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet) (*rsa.PublicKey, error) { +func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool) (interface{}, error) { kid, ok := t.Header["kid"].(string) if !ok { return nil, errors.WithStack(ErrInvalidRequest.WithHint("The JSON Web Token must contain a kid header value but did not.")) @@ -246,12 +247,22 @@ func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet) (*rsa.PublicKey, error if key.Use != "sig" { continue } - if k, ok := key.Key.(*rsa.PublicKey); ok { - return k, nil + if expectsRSAKey { + if k, ok := key.Key.(*rsa.PublicKey); ok { + return k, nil + } + } else { + if k, ok := key.Key.(*ecdsa.PublicKey); ok { + return k, nil + } } } - return nil, errors.WithStack(ErrInvalidRequest.WithHintf("Unable to find RSA public key with use=\"sig\" for kid \"%s\" in JSON Web Key Set.", kid)) + if expectsRSAKey { + return nil, errors.WithStack(ErrInvalidRequest.WithHintf("Unable to find RSA public key with use=\"sig\" for kid \"%s\" in JSON Web Key Set.", kid)) + } else { + return nil, errors.WithStack(ErrInvalidRequest.WithHintf("Unable to find ECDSA public key with use=\"sig\" for kid \"%s\" in JSON Web Key Set.", kid)) + } } func clientCredentialsFromRequest(r *http.Request, form url.Values) (clientID, clientSecret string, err error) { diff --git a/client_authentication_test.go b/client_authentication_test.go index fdd0fa3c3..f8710e9d9 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -23,6 +23,7 @@ package fosite_test import ( "context" + "crypto/ecdsa" "crypto/rsa" "encoding/base64" "encoding/json" @@ -44,7 +45,7 @@ import ( "github.com/ory/fosite/storage" ) -func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { +func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) token.Header["kid"] = kid tokenString, err := token.SignedString(key) @@ -52,6 +53,14 @@ func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateK return tokenString } +func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.PrivateKey, kid string) string { + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + token.Header["kid"] = kid + tokenString, err := token.SignedString(key) + require.NoError(t, err) + return tokenString +} + func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) @@ -95,20 +104,31 @@ func TestAuthenticateClient(t *testing.T) { complexSecret, err := hasher.Hash(context.TODO(), []byte(complexSecretRaw)) require.NoError(t, err) - key := internal.MustRSAKey() - jwks := &jose.JSONWebKeySet{ + rsaKey := internal.MustRSAKey() + rsaJwks := &jose.JSONWebKeySet{ Keys: []jose.JSONWebKey{ { KeyID: "kid-foo", Use: "sig", - Key: &key.PublicKey, + Key: &rsaKey.PublicKey, + }, + }, + } + + ecdsaKey := internal.MustECDSAKey() + ecdsaJwks := &jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + { + KeyID: "kid-foo", + Use: "sig", + Key: &ecdsaKey.PublicKey, }, }, } var h http.HandlerFunc h = func(w http.ResponseWriter, r *http.Request) { - require.NoError(t, json.NewEncoder(w).Encode(jwks)) + require.NoError(t, json.NewEncoder(w).Encode(rsaJwks)) } ts := httptest.NewServer(h) defer ts.Close() @@ -237,153 +257,165 @@ func TestAuthenticateClient(t *testing.T) { expectErr: ErrInvalidRequest, }, { - d: "should pass with proper assertion when JWKs are set within the client and client_id is not set in the request", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + d: "should pass with proper RSA assertion when JWKs are set within the client and client_id is not set in the request", + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{"client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ + "sub": "bar", + "exp": time.Now().Add(time.Hour).Unix(), + "iss": "bar", + "jti": "12345", + "aud": "token-url", + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, + r: new(http.Request), + }, + { + d: "should pass with proper ECDSA assertion when JWKs are set within the client and client_id is not set in the request", + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: ecdsaJwks, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlgorithm: "ES256"}, + form: url.Values{"client_assertion": {mustGenerateECDSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, ecdsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), }, { d: "should fail because token auth method is not private_key_jwt", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "client_secret_jwt"}, - form: url.Values{"client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "client_secret_jwt"}, + form: url.Values{"client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), expectErr: ErrInvalidClient, }, { d: "should pass with proper assertion when JWKs are set within the client and client_id is not set in the request (aud is array)", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{"client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": []string{"token-url-2", "token-url"}, - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), }, { d: "should fail because audience (array) does not match token url", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{"client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": []string{"token-url-1", "token-url-2"}, - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), expectErr: ErrInvalidClient, }, { d: "should pass with proper assertion when JWKs are set within the client", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), }, { d: "should fail because JWT algorithm is HS256", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateHSAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), expectErr: ErrInvalidClient, }, { d: "should fail because JWT algorithm is none", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateNoneAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), expectErr: ErrInvalidClient, }, { d: "should pass with proper assertion when JWKs URI is set", client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeysURI: ts.URL, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), }, { d: "should fail because client_assertion sub does not match client", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "not-bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), expectErr: ErrInvalidClient, }, { d: "should fail because client_assertion iss does not match client", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "not-bar", "jti": "12345", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), expectErr: ErrInvalidClient, }, { d: "should fail because client_assertion jti is not set", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "aud": "token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), expectErr: ErrInvalidClient, }, { d: "should fail because client_assertion aud is not set", - client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: jwks, TokenEndpointAuthMethod: "private_key_jwt"}, - form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", "jti": "12345", "aud": "not-token-url", - }, key, "kid-foo")}, "client_assertion_type": []string{at}}, + }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), expectErr: ErrInvalidClient, }, @@ -444,7 +476,7 @@ func TestAuthenticateClientTwice(t *testing.T) { TokenURL: "token-url", } - formValues := url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateAssertion(t, jwt.MapClaims{ + formValues := url.Values{"client_id": []string{"bar"}, "client_assertion": {mustGenerateRSAAssertion(t, jwt.MapClaims{ "sub": "bar", "exp": time.Now().Add(time.Hour).Unix(), "iss": "bar", diff --git a/internal/key.go b/internal/key.go index be4113b6c..c80f4d621 100644 --- a/internal/key.go +++ b/internal/key.go @@ -22,6 +22,8 @@ package internal import ( + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" ) @@ -34,3 +36,11 @@ func MustRSAKey() *rsa.PrivateKey { } return key } + +func MustECDSAKey() *ecdsa.PrivateKey { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + return key +} diff --git a/token/jwt/jwt.go b/token/jwt/jwt.go index 134345f83..5ac82b13f 100644 --- a/token/jwt/jwt.go +++ b/token/jwt/jwt.go @@ -26,6 +26,7 @@ package jwt import ( "context" + "crypto/ecdsa" "crypto/rsa" "crypto/sha256" "fmt" @@ -126,6 +127,86 @@ func (j *RS256JWTStrategy) GetSigningMethodLength() int { return jwt.SigningMethodRS256.Hash.Size() } +// ES256JWTStrategy is responsible for generating and validating JWT challenges +type ES256JWTStrategy struct { + PrivateKey *ecdsa.PrivateKey +} + +// Generate generates a new authorize code or returns an error. set secret +func (j *ES256JWTStrategy) Generate(ctx context.Context, claims jwt.Claims, header Mapper) (string, string, error) { + if header == nil || claims == nil { + return "", "", errors.New("Either claims or header is nil.") + } + + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + token.Header = assign(token.Header, header.ToMap()) + + var sig, sstr string + var err error + if sstr, err = token.SigningString(); err != nil { + return "", "", errors.WithStack(err) + } + + if sig, err = token.Method.Sign(sstr, j.PrivateKey); err != nil { + return "", "", errors.WithStack(err) + } + + return fmt.Sprintf("%s.%s", sstr, sig), sig, nil +} + +// Validate validates a token and returns its signature or an error if the token is not valid. +func (j *ES256JWTStrategy) Validate(ctx context.Context, token string) (string, error) { + if _, err := j.Decode(ctx, token); err != nil { + return "", errors.WithStack(err) + } + + return j.GetSignature(ctx, token) +} + +// Decode will decode a JWT token +func (j *ES256JWTStrategy) Decode(ctx context.Context, token string) (*jwt.Token, error) { + // Parse the token. + parsedToken, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { + return nil, errors.Errorf("Unexpected signing method: %v", t.Header["alg"]) + } + return &j.PrivateKey.PublicKey, nil + }) + + if err != nil { + return parsedToken, errors.WithStack(err) + } else if !parsedToken.Valid { + return parsedToken, errors.WithStack(fosite.ErrInactiveToken) + } + + return parsedToken, err +} + +// GetSignature will return the signature of a token +func (j *ES256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) { + split := strings.Split(token, ".") + if len(split) != 3 { + return "", errors.New("Header, body and signature must all be set") + } + return split[2], nil +} + +// Hash will return a given hash based on the byte input or an error upon fail +func (j *ES256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) { + // SigningMethodES256 + hash := sha256.New() + _, err := hash.Write(in) + if err != nil { + return []byte{}, errors.WithStack(err) + } + return hash.Sum([]byte{}), nil +} + +// GetSigningMethodLength will return the length of the signing method +func (j *ES256JWTStrategy) GetSigningMethodLength() int { + return jwt.SigningMethodES256.Hash.Size() +} + func assign(a, b map[string]interface{}) map[string]interface{} { for k, w := range b { if _, ok := a[k]; ok { diff --git a/token/jwt/jwt_test.go b/token/jwt/jwt_test.go index 5aed83896..b822a76a2 100644 --- a/token/jwt/jwt_test.go +++ b/token/jwt/jwt_test.go @@ -23,6 +23,7 @@ package jwt import ( "context" + "fmt" "strings" "testing" "time" @@ -40,13 +41,30 @@ var header = &Headers{ } func TestHash(t *testing.T) { - j := RS256JWTStrategy{ - PrivateKey: internal.MustRSAKey(), + for k, tc := range []struct { + d string + strategy JWTStrategy + }{ + { + d: "RS256JWTStrategy", + strategy: &RS256JWTStrategy{ + PrivateKey: internal.MustRSAKey(), + }, + }, + { + d: "ES256JWTStrategy", + strategy: &ES256JWTStrategy{ + PrivateKey: internal.MustECDSAKey(), + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { + in := []byte("foo") + out, err := tc.strategy.Hash(context.TODO(), in) + assert.NoError(t, err) + assert.NotEqual(t, in, out) + }) } - in := []byte("foo") - out, err := j.Hash(context.TODO(), in) - assert.NoError(t, err) - assert.NotEqual(t, in, out) } func TestAssign(t *testing.T) { @@ -77,72 +95,110 @@ func TestAssign(t *testing.T) { } func TestGenerateJWT(t *testing.T) { - claims := &JWTClaims{ - ExpiresAt: time.Now().UTC().Add(time.Hour), - } - - j := RS256JWTStrategy{ - PrivateKey: internal.MustRSAKey(), - } - - token, sig, err := j.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - - sig, err = j.Validate(context.TODO(), token) - require.NoError(t, err) - - sig, err = j.Validate(context.TODO(), token+"."+"0123456789") - require.Error(t, err) - - partToken := strings.Split(token, ".")[2] - - sig, err = j.Validate(context.TODO(), partToken) - require.Error(t, err) - - // Reset private key - j.PrivateKey = internal.MustRSAKey() - - // Lets validate the exp claim - claims = &JWTClaims{ - ExpiresAt: time.Now().UTC().Add(-time.Hour), - } - token, sig, err = j.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - //t.Logf("%s.%s", token, sig) - - sig, err = j.Validate(context.TODO(), token) - require.Error(t, err) - - // Lets validate the nbf claim - claims = &JWTClaims{ - NotBefore: time.Now().UTC().Add(time.Hour), + for k, tc := range []struct { + d string + strategy JWTStrategy + resetKey func(strategy JWTStrategy) + }{ + { + d: "RS256JWTStrategy", + strategy: &RS256JWTStrategy{ + PrivateKey: internal.MustRSAKey(), + }, + resetKey: func(strategy JWTStrategy) { + strategy.(*RS256JWTStrategy).PrivateKey = internal.MustRSAKey() + }, + }, + { + d: "ES256JWTStrategy", + strategy: &ES256JWTStrategy{ + PrivateKey: internal.MustECDSAKey(), + }, + resetKey: func(strategy JWTStrategy) { + strategy.(*ES256JWTStrategy).PrivateKey = internal.MustECDSAKey() + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { + claims := &JWTClaims{ + ExpiresAt: time.Now().UTC().Add(time.Hour), + } + + token, sig, err := tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) + require.NoError(t, err) + require.NotNil(t, token) + + sig, err = tc.strategy.Validate(context.TODO(), token) + require.NoError(t, err) + + sig, err = tc.strategy.Validate(context.TODO(), token+"."+"0123456789") + require.Error(t, err) + + partToken := strings.Split(token, ".")[2] + + sig, err = tc.strategy.Validate(context.TODO(), partToken) + require.Error(t, err) + + // Reset private key + tc.resetKey(tc.strategy) + + // Lets validate the exp claim + claims = &JWTClaims{ + ExpiresAt: time.Now().UTC().Add(-time.Hour), + } + token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) + require.NoError(t, err) + require.NotNil(t, token) + //t.Logf("%s.%s", token, sig) + + sig, err = tc.strategy.Validate(context.TODO(), token) + require.Error(t, err) + + // Lets validate the nbf claim + claims = &JWTClaims{ + NotBefore: time.Now().UTC().Add(time.Hour), + } + token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) + require.NoError(t, err) + require.NotNil(t, token) + //t.Logf("%s.%s", token, sig) + sig, err = tc.strategy.Validate(context.TODO(), token) + require.Error(t, err) + require.Empty(t, sig, "%s", err) + }) } - token, sig, err = j.Generate(context.TODO(), claims.ToMapClaims(), header) - require.NoError(t, err) - require.NotNil(t, token) - //t.Logf("%s.%s", token, sig) - sig, err = j.Validate(context.TODO(), token) - require.Error(t, err) - require.Empty(t, sig, "%s", err) } func TestValidateSignatureRejectsJWT(t *testing.T) { - var err error - j := RS256JWTStrategy{ - PrivateKey: internal.MustRSAKey(), - } - - for k, c := range []string{ - "", - " ", - "foo.bar", - "foo.", - ".foo", + for k, tc := range []struct { + d string + strategy JWTStrategy + }{ + { + d: "RS256JWTStrategy", + strategy: &RS256JWTStrategy{ + PrivateKey: internal.MustRSAKey(), + }, + }, + { + d: "ES256JWTStrategy", + strategy: &ES256JWTStrategy{ + PrivateKey: internal.MustECDSAKey(), + }, + }, } { - _, err = j.Validate(context.TODO(), c) - assert.Error(t, err) - t.Logf("Passed test case %d", k) + t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { + for k, c := range []string{ + "", + " ", + "foo.bar", + "foo.", + ".foo", + } { + _, err := tc.strategy.Validate(context.TODO(), c) + assert.Error(t, err) + t.Logf("Passed test case %d", k) + } + }) } }