diff --git a/controller/stories/stories.go b/controller/stories/stories.go index cc3cd8b..214c66a 100644 --- a/controller/stories/stories.go +++ b/controller/stories/stories.go @@ -40,8 +40,9 @@ func HandleRead(w http.ResponseWriter, r *http.Request) error { storyIDStr := chi.URLParam(r, "storyID") storyID, err := strconv.Atoi(storyIDStr) if err != nil { - http.Error(w, "Invalid storyID", http.StatusBadRequest) - return err + return apierrors.ClientBadRequestError{ + Message: fmt.Sprintf("Invalid storyID: %v", err), + } } // Get DB instance @@ -66,7 +67,9 @@ func HandleCreate(w http.ResponseWriter, r *http.Request) error { e, ok := err.(*json.UnmarshalTypeError) if !ok { logrus.Error(err) - return err + return apierrors.ClientBadRequestError{ + Message: fmt.Sprintf("Bad JSON parsing: %v", err), + } } // TODO: Investigate if we should use errors.Wrap instead diff --git a/controller/users/users.go b/controller/users/users.go index e20cd77..46806c1 100644 --- a/controller/users/users.go +++ b/controller/users/users.go @@ -40,8 +40,9 @@ func HandleRead(w http.ResponseWriter, r *http.Request) error { userIDStr := chi.URLParam(r, "userID") userID, err := strconv.Atoi(userIDStr) if err != nil { - http.Error(w, "Invalid userID", http.StatusBadRequest) - return err + return apierrors.ClientBadRequestError{ + Message: fmt.Sprintf("Invalid userID: %v", err), + } } // Get DB instance @@ -66,18 +67,23 @@ func HandleCreate(w http.ResponseWriter, r *http.Request) error { e, ok := err.(*json.UnmarshalTypeError) if !ok { logrus.Error(err) - return err + return apierrors.ClientBadRequestError{ + Message: fmt.Sprintf("Bad JSON parsing: %v", err), + } } // TODO: Investigate if we should use errors.Wrap instead - return apierrors.ClientBadRequestError{ + return apierrors.ClientUnprocessableEntityError{ Message: fmt.Sprintf("Invalid JSON format: %s should be a %s.", e.Field, e.Type), } } err := params.Validate() if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + logrus.Error(err) + return apierrors.ClientUnprocessableEntityError{ + Message: fmt.Sprintf("JSON validation failed: %v", err), + } } userModel := *params.ToModel() diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go index 8ce5d54..7fb75b8 100644 --- a/internal/auth/middleware.go +++ b/internal/auth/middleware.go @@ -7,6 +7,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/source-academy/stories-backend/internal/config" + apierrors "github.com/source-academy/stories-backend/internal/errors" envutils "github.com/source-academy/stories-backend/internal/utils/env" ) @@ -22,6 +23,7 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler { key, ok := keySet.Key(0) if !ok { // Block all access if JWKS source is down, since we can't verify JWTs + // TODO: Investigate if 500 is appropriate return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -34,7 +36,9 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler { // Get JWT from request authHeader := r.Header.Get("Authorization") if authHeader == "" { - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + apierrors.ServeHTTP(w, r, apierrors.ClientUnauthorizedError{ + Message: "Missing Authorization header", + }) return } @@ -42,7 +46,9 @@ func MakeMiddlewareFrom(conf *config.Config) func(http.Handler) http.Handler { toParse := authHeader[len("Bearer "):] token, err := jwt.ParseString(toParse, jwt.WithKey(jwa.RS256, key)) if err != nil { - fmt.Printf("Failed to verify JWS: %s\n", err) + apierrors.ServeHTTP(w, r, apierrors.ClientForbiddenError{ + Message: fmt.Sprintf("Failed to verify JWT: %s\n", err), + }) return } diff --git a/internal/errors/401.go b/internal/errors/401.go new file mode 100644 index 0000000..5c59322 --- /dev/null +++ b/internal/errors/401.go @@ -0,0 +1,17 @@ +package apierrors + +import ( + "net/http" +) + +type ClientUnauthorizedError struct { + Message string +} + +func (e ClientUnauthorizedError) Error() string { + return e.Message +} + +func (e ClientUnauthorizedError) HTTPStatusCode() int { + return http.StatusUnauthorized +} diff --git a/internal/errors/403.go b/internal/errors/403.go new file mode 100644 index 0000000..d986f5c --- /dev/null +++ b/internal/errors/403.go @@ -0,0 +1,17 @@ +package apierrors + +import ( + "net/http" +) + +type ClientForbiddenError struct { + Message string +} + +func (e ClientForbiddenError) Error() string { + return e.Message +} + +func (e ClientForbiddenError) HTTPStatusCode() int { + return http.StatusForbidden +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 1566ea1..0e948fb 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -1,8 +1,25 @@ package apierrors +import ( + "errors" + "net/http" +) + // ClientError is an interface for errors that should be returned to the client. // They generally start with a 4xx HTTP status code. type ClientError interface { error HTTPStatusCode() int } + +func ServeHTTP(w http.ResponseWriter, r *http.Request, err error) { + var clientError ClientError + if errors.As(err, &clientError) { + // Client error (status 4xx), write error message and status code + http.Error(w, clientError.Error(), clientError.HTTPStatusCode()) + return + } + + // 500 Internal Server Error as a catch-all + http.Error(w, err.Error(), http.StatusInternalServerError) +} diff --git a/internal/router/errors.go b/internal/router/errors.go index 562d14e..01ad118 100644 --- a/internal/router/errors.go +++ b/internal/router/errors.go @@ -1,7 +1,6 @@ package router import ( - "errors" "net/http" apierrors "github.com/source-academy/stories-backend/internal/errors" @@ -15,14 +14,7 @@ func handleAPIError(handler func(w http.ResponseWriter, r *http.Request) error) return } - var clientError apierrors.ClientError - if errors.As(err, &clientError) { - // Client error (status 4xx), write error message and status code - http.Error(w, clientError.Error(), clientError.HTTPStatusCode()) - return - } - - // 500 Internal Server Error as a catch-all - http.Error(w, err.Error(), http.StatusInternalServerError) + // Error, write error message and status code + apierrors.ServeHTTP(w, r, err) } }