diff --git a/handlers/auth_test.go b/handlers/auth_test.go index ffcb3000a..2528e4688 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -39,8 +39,7 @@ func TestGetAdminPubkeys(t *testing.T) { handler.ServeHTTP(rr, req) - t.Errorf("handler returned wrong status code: got TEST want TEST") - if status := rr.Code; status != rr.Code { + if status := rr.Code; status != rr.Status { t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusOK) } diff --git a/handlers/bounty_test.go b/handlers/bounty_test.go index ee8567ef8..9e1ff1d3b 100644 --- a/handlers/bounty_test.go +++ b/handlers/bounty_test.go @@ -613,6 +613,7 @@ func TestDeleteBounty(t *testing.T) { } handler.ServeHTTP(rr, req) + //Check that this fails assert.Equal(t, http.StatusInternalServerError, rr.Code) }) diff --git a/routes/bounty.go b/routes/bounty.go index ea4eb4c6e..0974ae849 100644 --- a/routes/bounty.go +++ b/routes/bounty.go @@ -7,11 +7,15 @@ import ( "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/handlers" + "github.com/stakwork/sphinx-tribes/utils" ) func BountyRoutes() chi.Router { r := chi.NewRouter() bountyHandler := handlers.NewBountyHandler(http.DefaultClient, db.DB) + + r.Use(utils.ErrorHandler) + r.Group(func(r chi.Router) { r.Get("/all", bountyHandler.GetAllBounties) diff --git a/routes/ticket_routes.go b/routes/ticket_routes.go index a45cd46c0..34ebfb8ea 100644 --- a/routes/ticket_routes.go +++ b/routes/ticket_routes.go @@ -7,12 +7,15 @@ import ( "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/handlers" + "github.com/stakwork/sphinx-tribes/utils" ) func TicketRoutes() chi.Router { r := chi.NewRouter() ticketHandler := handlers.NewTicketHandler(http.DefaultClient, db.DB) + r.Use(utils.ErrorHandler) + r.Group(func(r chi.Router) { r.Get("/{uuid}", ticketHandler.GetTicket) r.Post("/review", ticketHandler.ProcessTicketReview) diff --git a/utils/error_handler.go b/utils/error_handler.go new file mode 100644 index 000000000..f62195d47 --- /dev/null +++ b/utils/error_handler.go @@ -0,0 +1,54 @@ +package utils + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + + "net/http" + ) + +type customError struct { + error + StatusCode int +} + +func ErrorHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + // Add logic here to then send to jarvis with the correct values + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusNotFound) + } + }() + + ww := &responseWriterWrapper{ResponseWriter: w} + next.ServeHTTP(ww, r) + + if ww.error != nil { + //statusCode := http.StatusInternalServerError + statusCode := http.StatusNotFound + if errors.Is(ww.error, sql.ErrNoRows) { + statusCode = http.StatusNotFound + } else if err, ok := ww.error.(*customError); ok { + statusCode = err.StatusCode + } + + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(map[string]string{"error": ww.error.Error()}) + } + }) +} + +type responseWriterWrapper struct { + http.ResponseWriter + error error +} + +func (w *responseWriterWrapper) WriteHeader(statusCode int) { + if statusCode >= http.StatusBadRequest { + w.error = fmt.Errorf("HTTP %d: %s", statusCode, http.StatusText(statusCode)) + } + w.ResponseWriter.WriteHeader(statusCode) +}