Skip to content

Commit

Permalink
middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
kevkevinpal committed Dec 11, 2024
1 parent 79341e5 commit 7c554c9
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 11 deletions.
2 changes: 1 addition & 1 deletion handlers/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestGetAdminPubkeys(t *testing.T) {

handler.ServeHTTP(rr, req)

if status := rr.Code; status != rr.Status {
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
Expand Down
4 changes: 0 additions & 4 deletions routes/bounty.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,11 @@ import (
"github.com/stakwork/sphinx-tribes/auth"
"github.com/stakwork/sphinx-tribes/db"
"github.com/stakwork/sphinx-tribes/handlers"
"github.com/stakwork/sphinx-tribes/utils"
)

func BountyRoutes() chi.Router {
r := chi.NewRouter()
bountyHandler := handlers.NewBountyHandler(http.DefaultClient, db.DB)

r.Use(utils.ErrorHandler)

r.Group(func(r chi.Router) {
r.Get("/all", bountyHandler.GetAllBounties)

Expand Down
27 changes: 27 additions & 0 deletions routes/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,38 @@ func getFromAuth(path string) (*extractResponse, error) {
}, nil
}

// Middleware to handle InternalServerError
func internalServerErrorHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Use a ResponseWriter that allows capturing the status code
rr := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK}
next.ServeHTTP(rr, r)

fmt.Println("Inside handler")
// Check for Internal Server Error
if rr.statusCode == http.StatusOK{
fmt.Println("Internal Server Error: %s %s", r.Method, r.URL.Path)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
if rr.statusCode == http.StatusInternalServerError {
fmt.Println("Internal Server Error: %s %s", r.Method, r.URL.Path)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
})
}

// Custom ResponseWriter to capture status codes
type responseRecorder struct {
http.ResponseWriter
statusCode int
}

func initChi() *chi.Mux {
r := chi.NewRouter()
r.Use(middleware.RequestID)
r.Use(middleware.Logger)
r.Use(middleware.Recoverer)
r.Use(internalServerErrorHandler)
cors := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
Expand Down
3 changes: 0 additions & 3 deletions routes/ticket_routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@ import (
"github.com/stakwork/sphinx-tribes/auth"
"github.com/stakwork/sphinx-tribes/db"
"github.com/stakwork/sphinx-tribes/handlers"
"github.com/stakwork/sphinx-tribes/utils"
)

func TicketRoutes() chi.Router {
r := chi.NewRouter()
ticketHandler := handlers.NewTicketHandler(http.DefaultClient, db.DB)

r.Use(utils.ErrorHandler)

r.Group(func(r chi.Router) {
r.Get("/{uuid}", ticketHandler.GetTicket)
r.Post("/review", ticketHandler.ProcessTicketReview)
Expand Down
21 changes: 18 additions & 3 deletions utils/error_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,26 @@ import (
"net/http"
)

type customError struct {
error
type CustomError struct {
Err error
StatusCode int
}

func NewCustomError(err error, statusCode int) *CustomError {
return &CustomError{
Err: err,
StatusCode: statusCode,
}
}

// Error implements the error interface.
func (e *CustomError) Error() string {
if e.Err != nil {
return e.Err.Error()
}
return fmt.Sprintf("HTTP %d", e.StatusCode)
}

func ErrorHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
Expand All @@ -31,7 +46,7 @@ func ErrorHandler(next http.Handler) http.Handler {
statusCode := http.StatusNotFound
if errors.Is(ww.error, sql.ErrNoRows) {
statusCode = http.StatusNotFound
} else if err, ok := ww.error.(*customError); ok {
} else if err, ok := ww.error.(*CustomError); ok {
statusCode = err.StatusCode
}

Expand Down

0 comments on commit 7c554c9

Please sign in to comment.