Skip to content

Commit

Permalink
Adding CORS header everywhere (#58)
Browse files Browse the repository at this point in the history
Previously, it was only added to some calls, making some things not work.
  • Loading branch information
oxisto authored Mar 14, 2023
1 parent 87f4554 commit f3a2b96
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ func (srv *AuthorizationServer) handleJWKS(w http.ResponseWriter, r *http.Reques
})
}

writeJSON(w, keySet)
srv.writeJSON(w, keySet)
}
2 changes: 1 addition & 1 deletion metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ func (srv *AuthorizationServer) handleMetadata(w http.ResponseWriter, r *http.Re
return
}

writeJSON(w, srv.metadata)
srv.writeJSON(w, srv.metadata)
}
24 changes: 14 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (srv *AuthorizationServer) doClientCredentialsFlow(w http.ResponseWriter, r
return
}

writeToken(w, token)
srv.writeToken(w, token)
}

// doAuthorizationCodeFlow implements the Authorization Code Grant
Expand All @@ -213,10 +213,6 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r
client *Client
)

if srv.allowedOrigin != "" {
w.Header().Add("Access-Control-Allow-Origin", srv.allowedOrigin)
}

// Retrieve the client
client, err = srv.retrieveClient(r, true)
if err != nil {
Expand Down Expand Up @@ -245,7 +241,7 @@ func (srv *AuthorizationServer) doAuthorizationCodeFlow(w http.ResponseWriter, r
return
}

writeToken(w, token)
srv.writeToken(w, token)
}

// doRefreshTokenFlow implements refreshing an access token.
Expand Down Expand Up @@ -304,7 +300,7 @@ issue:
return
}

writeToken(w, token)
srv.writeToken(w, token)
}

// GetClient returns the client for the given ID or ErrClientNotFound.
Expand Down Expand Up @@ -444,6 +440,12 @@ func (srv *AuthorizationServer) GenerateToken(clientID string, signingKeyID int,
return
}

func (srv *AuthorizationServer) cors(w http.ResponseWriter) {
if srv.allowedOrigin != "" {
w.Header().Add("Access-Control-Allow-Origin", srv.allowedOrigin)
}
}

func Error(w http.ResponseWriter, error string, statusCode int) {
w.Header().Set("Content-Type", "application/json")

Expand All @@ -463,7 +465,7 @@ func RedirectError(w http.ResponseWriter,
http.Redirect(w, r, fmt.Sprintf("%s?%s", redirectURI, params.Encode()), http.StatusFound)
}

func writeToken(w http.ResponseWriter, token *oauth2.Token) {
func (srv *AuthorizationServer) writeToken(w http.ResponseWriter, token *oauth2.Token) {
// We need to transform this into our own struct, otherwise
// the expiry will be translated into a string representation,
// while it should be represented as seconds.
Expand All @@ -479,12 +481,14 @@ func writeToken(w http.ResponseWriter, token *oauth2.Token) {
Expiry: int(time.Until(token.Expiry).Seconds()),
}

writeJSON(w, s)
srv.writeJSON(w, s)
}

func writeJSON(w http.ResponseWriter, value interface{}) {
func (srv *AuthorizationServer) writeJSON(w http.ResponseWriter, value interface{}) {
w.Header().Set("Content-Type", "application/json")

srv.cors(w)

if err := json.NewEncoder(w).Encode(value); err != nil {
Error(w, "could not encode JSON", http.StatusInternalServerError)
return
Expand Down
6 changes: 4 additions & 2 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ func TestAuthorizationServer_retrieveClient(t *testing.T) {
}
}

func Test_writeJSON(t *testing.T) {
func TestAuthorizationServer_writeJSON(t *testing.T) {
type fields struct{}
type args struct {
w http.ResponseWriter
value interface{}
Expand All @@ -239,7 +240,8 @@ func Test_writeJSON(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
writeJSON(tt.args.w, tt.args.value)
srv := &AuthorizationServer{}
srv.writeJSON(tt.args.w, tt.args.value)

var rr *httptest.ResponseRecorder
switch v := tt.args.w.(type) {
Expand Down

0 comments on commit f3a2b96

Please sign in to comment.