From 79f3cba222469afcb02b3527aa19329c0ed0abdc Mon Sep 17 00:00:00 2001 From: tarun-29 <29kantiwaltarun@gmail.com> Date: Thu, 3 Oct 2024 20:32:18 +0530 Subject: [PATCH] fix: modify trailing middleware and also added test cases --- internal/middleware/trailing_slash_test.go | 72 ++++++++++++++++++++++ internal/middleware/trailingslash.go | 16 ++--- 2 files changed, 80 insertions(+), 8 deletions(-) create mode 100644 internal/middleware/trailing_slash_test.go diff --git a/internal/middleware/trailing_slash_test.go b/internal/middleware/trailing_slash_test.go new file mode 100644 index 0000000..9ed6c8f --- /dev/null +++ b/internal/middleware/trailing_slash_test.go @@ -0,0 +1,72 @@ +package middleware_test + +import ( + "net/http" + "net/http/httptest" + "server/internal/middleware" + "testing" +) + +// demo test case +// { +// name: string +// requestURL: string, +// expectedCode: int, +// expectedLocation +// } + +func TestTrailingSlashMiddleware(t *testing.T) { + + handler := middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + tests := []struct { + name string + requestURL string + expectedCode int + expectedUrl string + }{ + { + name: "url with trailing slash", + requestURL: "/example/", + expectedCode: http.StatusMovedPermanently, + expectedUrl: "/example", + }, + { + name: "url without trailing slash", + requestURL: "/example", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "root url with trailing slash", + requestURL: "/", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + { + name: "URL with Query Parameters", + requestURL: "/example?query=1", + expectedCode: http.StatusOK, + expectedUrl: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("GET", tt.requestURL, nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != tt.expectedCode { + t.Errorf("expected status %d, got %d", tt.expectedCode, w.Code) + } + + if tt.expectedUrl != "" && w.Header().Get("Location") != tt.expectedUrl { + t.Errorf("expected location %s, got %s", tt.expectedUrl, w.Header().Get("Location")) + } + }) + } +} diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 9beb317..09ac1b7 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -5,19 +5,19 @@ import ( "strings" ) -// TrailingSlashMiddleware is a middleware function that removes the trailing slash from the URL path. func TrailingSlashMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if the URL path ends with a slash and is not the root path ("/") if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - // Remove the trailing slash - r.URL.Path = strings.TrimSuffix(r.URL.Path, "/") - // Redirect to the new path (optional, for SEO) - http.Redirect(w, r, r.URL.Path, http.StatusMovedPermanently) + // remove slash + newPath := strings.TrimSuffix(r.URL.Path, "/") + // if query params exist append them + newURL := newPath + if r.URL.RawQuery != "" { + newURL += "?" + r.URL.RawQuery + } + http.Redirect(w, r, newURL, http.StatusMovedPermanently) return } - - // Call the next handler next.ServeHTTP(w, r) }) }