From 37b11149da9aef842266f07b1e7e10fd645eac27 Mon Sep 17 00:00:00 2001 From: tarun-29 <29kantiwaltarun@gmail.com> Date: Wed, 2 Oct 2024 19:18:06 +0530 Subject: [PATCH 1/4] add: trailing slash middleware to prevent unexpected api crash --- internal/middleware/trailingslash.go | 23 +++++++++++++++++++++++ internal/server/httpServer.go | 9 ++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) create mode 100644 internal/middleware/trailingslash.go diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go new file mode 100644 index 0000000..9beb317 --- /dev/null +++ b/internal/middleware/trailingslash.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "net/http" + "strings" +) + +// TrailingSlashMiddleware is a middleware function that removes the trailing slash from the URL path. +func TrailingSlashMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if the URL path ends with a slash and is not the root path ("/") + if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { + // Remove the trailing slash + r.URL.Path = strings.TrimSuffix(r.URL.Path, "/") + // Redirect to the new path (optional, for SEO) + http.Redirect(w, r, r.URL.Path, http.StatusMovedPermanently) + return + } + + // Call the next handler + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index be713f8..7b54334 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -28,9 +28,12 @@ type HandlerMux struct { func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Convert the path to lowercase before passing to the underlying mux. - r.URL.Path = strings.ToLower(r.URL.Path) - // Apply rate limiter - cim.rateLimiter(w, r, cim.mux) + middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.ToLower(r.URL.Path) + // Apply rate limiter + cim.rateLimiter(w, r, cim.mux) + })).ServeHTTP(w, r) + } func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit, window int) *HTTPServer { From 79f3cba222469afcb02b3527aa19329c0ed0abdc Mon Sep 17 00:00:00 2001 From: tarun-29 <29kantiwaltarun@gmail.com> Date: Thu, 3 Oct 2024 20:32:18 +0530 Subject: [PATCH 2/4] fix: modify trailing middleware and also added test cases --- internal/middleware/trailing_slash_test.go | 72 ++++++++++++++++++++++ internal/middleware/trailingslash.go | 16 ++--- 2 files changed, 80 insertions(+), 8 deletions(-) create mode 100644 internal/middleware/trailing_slash_test.go diff --git a/internal/middleware/trailing_slash_test.go b/internal/middleware/trailing_slash_test.go new file mode 100644 index 0000000..9ed6c8f --- /dev/null +++ b/internal/middleware/trailing_slash_test.go @@ -0,0 +1,72 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "server/internal/middleware" + "testing" +) + +// demo test case +// { +// name: string +// requestURL: string, +// expectedCode: int, +// expectedLocation +// } + +func TestTrailingSlashMiddleware(t *testing.T) { + + handler := middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + requestURL string + expectedCode int + expectedUrl string + }{ + { + name: "url with trailing slash", + requestURL: "/example/", + expectedCode: http.StatusMovedPermanently, + expectedUrl: "/example", + }, + { + name: "url without trailing slash", + requestURL: "/example", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "root url with trailing slash", + requestURL: "/", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "URL with Query Parameters", + requestURL: "/example?query=1", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.requestURL, nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != tt.expectedCode { + t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + } + + if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { + t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + } + }) + } +} diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 9beb317..09ac1b7 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -5,19 +5,19 @@ import ( "strings" ) -// TrailingSlashMiddleware is a middleware function that removes the trailing slash from the URL path. func TrailingSlashMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if the URL path ends with a slash and is not the root path ("/") if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - // Remove the trailing slash - r.URL.Path = strings.TrimSuffix(r.URL.Path, "/") - // Redirect to the new path (optional, for SEO) - http.Redirect(w, r, r.URL.Path, http.StatusMovedPermanently) + // remove slash + newPath := strings.TrimSuffix(r.URL.Path, "/") + // if query params exist append them + newURL := newPath + if r.URL.RawQuery != "" { + newURL += "?" + r.URL.RawQuery + } + http.Redirect(w, r, newURL, http.StatusMovedPermanently) return } - - // Call the next handler next.ServeHTTP(w, r) }) } From 4afe3e30836391667a7063fd3537a24b29869bf5 Mon Sep 17 00:00:00 2001 From: tarun-29 <29kantiwaltarun@gmail.com> Date: Fri, 4 Oct 2024 16:36:46 +0530 Subject: [PATCH 3/4] fix: remove comments --- internal/middleware/trailing_slash_test.go | 8 -------- 1 file changed, 8 deletions(-) diff --git a/internal/middleware/trailing_slash_test.go b/internal/middleware/trailing_slash_test.go index 9ed6c8f..203fb1a 100644 --- a/internal/middleware/trailing_slash_test.go +++ b/internal/middleware/trailing_slash_test.go @@ -7,14 +7,6 @@ import ( "testing" ) -// demo test case -// { -// name: string -// requestURL: string, -// expectedCode: int, -// expectedLocation -// } - func TestTrailingSlashMiddleware(t *testing.T) { handler := middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { From b714acd000d37e408c2d4ec92ecc4bcfce13df4a Mon Sep 17 00:00:00 2001 From: tarun-29 <29kantiwaltarun@gmail.com> Date: Fri, 4 Oct 2024 20:46:36 +0530 Subject: [PATCH 4/4] fix: lint issue --- internal/server/httpServer.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index 7b54334..c3a024c 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -33,7 +33,6 @@ func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Apply rate limiter cim.rateLimiter(w, r, cim.mux) })).ServeHTTP(w, r) - } func NewHTTPServer(addr string, mux *http.ServeMux, client *db.DiceDB, limit, window int) *HTTPServer {