Skip to content

Commit

Permalink
Made CORS header optional (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
oxisto authored Nov 18, 2022
1 parent e54f9be commit e317081
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
1 change: 1 addition & 0 deletions cmd/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func main() {
oauth2.WithClient("client", *clientSecret, *redirectURI),
oauth2.WithClient("public", "", *redirectURI),
login.WithLoginPage(login.WithUser("admin", *userPassword)),
oauth2.WithAllowedOrigins("*"),
)
srv.BaseContext = ctx

Expand Down
13 changes: 12 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ type AuthorizationServer struct {

// our codes and their expiry time and challenge
codes map[string]*codeInfo

// the allowed CORS origin
allowedOrigin string
}

type AuthorizationServerOption func(srv *AuthorizationServer)
Expand Down Expand Up @@ -80,6 +83,12 @@ func WithSigningKeysFunc(f signingKeysFunc) AuthorizationServerOption {
}
}

func WithAllowedOrigins(origin string) AuthorizationServerOption {
return func(srv *AuthorizationServer) {
srv.allowedOrigin = origin
}
}

func NewServer(addr string, opts ...AuthorizationServerOption) *AuthorizationServer {
mux := http.NewServeMux()

Expand Down Expand Up @@ -181,7 +190,9 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r
client *Client
)

w.Header().Add("Access-Control-Allow-Origin", "*")
if srv.allowedOrigin != "" {
w.Header().Add("Access-Control-Allow-Origin", srv.allowedOrigin)
}

// Retrieve the client
client, err = srv.retrieveClient(r, true)
Expand Down
30 changes: 24 additions & 6 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,10 @@ func TestAuthorizationServer_doClientCredentialsFlow(t *testing.T) {

func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) {
type fields struct {
clients []*Client
signingKeys map[int]*ecdsa.PrivateKey
codes map[string]*codeInfo
clients []*Client
signingKeys map[int]*ecdsa.PrivateKey
codes map[string]*codeInfo
allowedOrigin string
}
type args struct {
r *http.Request
Expand All @@ -434,6 +435,22 @@ func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) {
wantCode int
wantBody string
}{
{
name: "cors header",
fields: fields{
allowedOrigin: "*",
},
args: args{
r: &http.Request{
Method: "POST",
Header: http.Header{
http.CanonicalHeaderKey("Authorization"): []string{"notvalid"},
},
},
},
wantCode: http.StatusUnauthorized,
wantBody: `{"error": "invalid_client"}`,
},
{
name: "missing or invalid authorization",
args: args{
Expand Down Expand Up @@ -541,9 +558,10 @@ func TestAuthorizationServer_doAuthorizationCodeFlow(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
srv := &AuthorizationServer{
clients: tt.fields.clients,
signingKeys: tt.fields.signingKeys,
codes: tt.fields.codes,
clients: tt.fields.clients,
signingKeys: tt.fields.signingKeys,
codes: tt.fields.codes,
allowedOrigin: tt.fields.allowedOrigin,
}
rr := httptest.NewRecorder()
srv.doAuthorizationCodeFlow(rr, tt.args.r)
Expand Down

0 comments on commit e317081

Please sign in to comment.