diff --git a/cmd/server/server.go b/cmd/server/server.go index d5c1014..9a54da5 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -40,6 +40,7 @@ func main() { oauth2.WithClient("client", *clientSecret, *redirectURI), oauth2.WithClient("public", "", *redirectURI), login.WithLoginPage(login.WithUser("admin", *userPassword)), + oauth2.WithAllowedOrigins("*"), ) srv.BaseContext = ctx diff --git a/server.go b/server.go index 9f1a8e4..7afa8e5 100644 --- a/server.go +++ b/server.go @@ -49,6 +49,9 @@ type AuthorizationServer struct { // our codes and their expiry time and challenge codes map[string]*codeInfo + + // the allowed CORS origin + allowedOrigin string } type AuthorizationServerOption func(srv *AuthorizationServer) @@ -80,6 +83,12 @@ func WithSigningKeysFunc(f signingKeysFunc) AuthorizationServerOption { } } +func WithAllowedOrigins(origin string) AuthorizationServerOption { + return func(srv *AuthorizationServer) { + srv.allowedOrigin = origin + } +} + func NewServer(addr string, opts ...AuthorizationServerOption) *AuthorizationServer { mux := http.NewServeMux() @@ -181,7 +190,9 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r client *Client ) - w.Header().Add("Access-Control-Allow-Origin", "*") + if srv.allowedOrigin != "" { + w.Header().Add("Access-Control-Allow-Origin", srv.allowedOrigin) + } // Retrieve the client client, err = srv.retrieveClient(r, true) diff --git a/server_test.go b/server_test.go index b68eda6..5191c44 100644 --- a/server_test.go +++ b/server_test.go @@ -420,9 +420,10 @@ func TestAuthorizationServer_doClientCredentialsFlow(t *testing.T) { func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) { type fields struct { - clients []*Client - signingKeys map[int]*ecdsa.PrivateKey - codes map[string]*codeInfo + clients []*Client + signingKeys map[int]*ecdsa.PrivateKey + codes map[string]*codeInfo + allowedOrigin string } type args struct { r *http.Request @@ -434,6 +435,22 @@ func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) { wantCode int wantBody string }{ + { + name: "cors header", + fields: fields{ + allowedOrigin: "*", + }, + args: args{ + r: &http.Request{ + Method: "POST", + Header: http.Header{ + http.CanonicalHeaderKey("Authorization"): []string{"notvalid"}, + }, + }, + }, + wantCode: http.StatusUnauthorized, + wantBody: `{"error": "invalid_client"}`, + }, { name: "missing or invalid authorization", args: args{ @@ -541,9 +558,10 @@ func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { srv := &AuthorizationServer{ - clients: tt.fields.clients, - signingKeys: tt.fields.signingKeys, - codes: tt.fields.codes, + clients: tt.fields.clients, + signingKeys: tt.fields.signingKeys, + codes: tt.fields.codes, + allowedOrigin: tt.fields.allowedOrigin, } rr := httptest.NewRecorder() srv.doAuthorizationCodeFlow(rr, tt.args.r)