diff --git a/routes/index.go b/routes/index.go index 030a04db8..302117b8d 100644 --- a/routes/index.go +++ b/routes/index.go @@ -1,15 +1,18 @@ package routes import ( + "bufio" "encoding/json" "fmt" "io" + "net" "net/http" "os" "time" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" + "github.com/gorilla/websocket" "github.com/rs/cors" "github.com/stakwork/sphinx-tribes/auth" @@ -63,6 +66,7 @@ func NewRouter() *http.Server { r.Post("/save", db.PostSave) r.Get("/save/{key}", db.PollSave) r.Get("/migrate_bounties", handlers.MigrateBounties) + r.Get("/test/internal-server-error", testInternalServerError) r.Get("/websocket", handlers.HandleWebSocket) }) @@ -141,36 +145,66 @@ func getFromAuth(path string) (*extractResponse, error) { }, nil } -// Middleware to handle InternalServerError -// func internalServerErrorHandler(next http.Handler) http.Handler { -// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// rr := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK} -// next.ServeHTTP(rr, r) - -// if rr.statusCode == http.StatusInternalServerError { -// fmt.Printf("Internal Server Error: %s %s\n", r.Method, r.URL.Path) -// http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) -// } -// }) -// } - -// Custom ResponseWriter to capture status codes -// type responseRecorder struct { -// http.ResponseWriter -// statusCode int -// } - -// func (rr *responseRecorder) WriteHeader(code int) { -// rr.statusCode = code -// rr.ResponseWriter.WriteHeader(code) -// } +type responseRecorder struct { + http.ResponseWriter + statusCode int + written bool +} + +func (rr *responseRecorder) WriteHeader(code int) { + if !rr.written { + rr.statusCode = code + rr.written = true + rr.ResponseWriter.WriteHeader(code) + } +} + +func (rr *responseRecorder) Write(b []byte) (int, error) { + if !rr.written { + rr.statusCode = http.StatusOK + rr.written = true + } + return rr.ResponseWriter.Write(b) +} + +func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := rr.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, fmt.Errorf("hijacking not supported") +} + +func internalServerErrorHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if websocket.IsWebSocketUpgrade(r) { + next.ServeHTTP(w, r) + return + } + + rr := &responseRecorder{ + ResponseWriter: w, + statusCode: http.StatusOK, + written: false, + } + + next.ServeHTTP(rr, r) + + if rr.statusCode == http.StatusInternalServerError { + fmt.Printf("Inside Internal Server Middleware: %s %s\n", r.Method, r.URL.Path) + } + }) +} + +func testInternalServerError(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) +} func initChi() *chi.Mux { r := chi.NewRouter() r.Use(middleware.RequestID) r.Use(middleware.Logger) r.Use(middleware.Recoverer) - // r.Use(internalServerErrorHandler) + r.Use(internalServerErrorHandler) cors := cors.New(cors.Options{ AllowedOrigins: []string{"*"}, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},