From a9f47e75b2a238a752f46ed9034fe39c15f2a3ad Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Fri, 28 Jul 2023 23:20:59 +0100 Subject: [PATCH] Refactor table-driven test that was doing too much. This giant table-driven test was full of branches and was testing a whole bunch of different behaviours. - Split it into separate tests that test one thing each. - Stop doing exact comparisons on floating point numbers. - Clean up the other few tests in the module. - Use table-driven tests where appropriate. --- handlers/redirect_handler_test.go | 263 ++++++++++-------------------- 1 file changed, 82 insertions(+), 181 deletions(-) diff --git a/handlers/redirect_handler_test.go b/handlers/redirect_handler_test.go index adbbee4e..f6b48412 100644 --- a/handlers/redirect_handler_test.go +++ b/handlers/redirect_handler_test.go @@ -1,6 +1,7 @@ package handlers_test import ( + "fmt" "net/http" "net/http/httptest" "time" @@ -14,197 +15,97 @@ import ( "github.com/alphagov/router/handlers" ) -type redirectTableEntry struct { - preserve bool - temporary bool -} - -// TODO: refactor this abomination. -var _ = Describe("Redirect handlers", func() { - DescribeTable("handlers", - func(t redirectTableEntry) { - var redirectCode, redirectType string - rw := httptest.NewRecorder() - handler := handlers.NewRedirectHandler("/source-prefix", "/target-prefix", t.preserve, t.temporary) - - if t.temporary { - redirectCode = "302" - } else { - redirectCode = "301" - } - - if t.preserve { - redirectType = "path-preserving-redirect-handler" - } else { - redirectType = "redirect-handler" - } - - labels := prometheus.Labels{"redirect_code": redirectCode, "redirect_type": redirectType} - beforeCount := promtest.ToFloat64(handlers.RedirectHandlerRedirectCountMetric.With(labels)) - - handler.ServeHTTP(rw, - httptest.NewRequest( - http.MethodGet, - "https://source.gov.uk/source-prefix/path/subpath?query1=a&query2=b", - nil, - ), - ) - - if t.temporary { - Expect(rw.Result().StatusCode).To( - Equal(http.StatusFound), - "when the redirect is temporary we should return HTTP 302", - ) - } else { - Expect(rw.Result().StatusCode).To( - Equal(http.StatusMovedPermanently), - "when the redirect is permanent we should return HTTP 301", - ) - } - - if t.preserve { - Expect(rw.Result().Header.Get("Location")).To( - Equal("/target-prefix/path/subpath?query1=a&query2=b"), - ) - } else { - Expect(rw.Result().Header.Get("Location")).To( - Equal("/target-prefix"), - "when we do not preserve the path, we redirect straight to target", - ) - } - - Expect(rw.Result().Header.Get("Cache-Control")).To( - SatisfyAll( - ContainSubstring("public"), - ContainSubstring("max-age=1800"), - ), - "Declare public and cachable for 30 minutes", - ) - - Expect(rw.Result().Header.Get("Expires")).To( - WithTransform( - func(timestr string) time.Time { - t, err := time.Parse(time.RFC1123, timestr) - Expect(err).NotTo(HaveOccurred(), "Not RFC1123 compliant") - return t - }, - BeTemporally("~", time.Now().Add(30*time.Minute), time.Second), - ), - "Be RFC1123 compliant and expire around 30 minutes in the future", - ) - - afterCount := promtest.ToFloat64( - handlers.RedirectHandlerRedirectCountMetric.With(labels), - ) - - Expect(afterCount-beforeCount).To( - BeNumerically("~", 1.0), - "Making a request should increment the redirect handler count metric", - ) - }, - Entry( - "when redirects are temporary and paths are preserved", - redirectTableEntry{preserve: true, temporary: true}, - ), - Entry( - "when redirects are temporary and paths are not preserved", - redirectTableEntry{preserve: false, temporary: true}, - ), - Entry( - "when redirects are not temporary and paths are preserved", - redirectTableEntry{preserve: true, temporary: false}, - ), - Entry( - "when redirects are not temporary and paths are not preserved", - redirectTableEntry{preserve: false, temporary: false}, - ), - ) +var _ = Describe("A redirect handler", func() { + var handler http.Handler + var rr *httptest.ResponseRecorder + const url = "https://source.example.com/source/path/subpath?q1=a&q2=b" + + BeforeEach(func() { + rr = httptest.NewRecorder() + }) - Context("when we are not preserving paths", func() { - var ( - rw *httptest.ResponseRecorder - handler http.Handler - ) + // These behaviours apply to all combinations of both NewRedirectHandler flags. + for _, preserve := range []bool{true, false} { + for _, temporary := range []bool{true, false} { + Context(fmt.Sprintf("where preserve=%t, temporary=%t", preserve, temporary), func() { + BeforeEach(func() { + handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) + }) + + It("allows its response to be cached publicly for 30m", func() { + Expect(rr.Result().Header.Get("Cache-Control")).To( + SatisfyAll(ContainSubstring("public"), ContainSubstring("max-age=1800"))) + }) + + It("returns an expires header with an RFC1123 datetime 30m in the future", func() { + Expect(rr.Result().Header.Get("Expires")).To(WithTransform( + func(s string) time.Time { + t, err := time.Parse(time.RFC1123, s) + Expect(err).NotTo(HaveOccurred()) + return t + }, + BeTemporally("~", time.Now().Add(30*time.Minute), time.Minute))) + }) + }) + } + } + Context("where preserve=true", func() { BeforeEach(func() { - rw = httptest.NewRecorder() + handler = handlers.NewRedirectHandler("/source", "/target", true, false) + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) + }) - handler = handlers.NewRedirectHandler( - "/source-prefix", "/target-prefix", - false, // preserve - true, // temporary - ) + It("returns the original path in the location header", func() { + Expect(rr.Result().Header.Get("Location")).To(Equal("/target/path/subpath?q1=a&q2=b")) }) + }) - Context("when the _ga query param is present", func() { - It("should persist _ga to the query params", func() { - handler.ServeHTTP( - rw, - httptest.NewRequest( - http.MethodGet, - "https://source.gov.uk/source-prefix?_ga=dontbeevil", - nil, - ), - ) - - Expect(rw.Result().Header.Get("Location")).To( - Equal("/target-prefix?_ga=dontbeevil"), - "Preserve the _ga query parameter", - ) - }) + Context("where preserve=false", func() { + BeforeEach(func() { + handler = handlers.NewRedirectHandler("/source", "/target", false, false) }) - Context("when the _ga query param is not present", func() { - It("should not add _ga to the query params", func() { - handler.ServeHTTP( - rw, - httptest.NewRequest( - http.MethodGet, - "https://source.gov.uk/source-prefix?param=begood", - nil, - ), - ) - - Expect(rw.Result().Header.Get("Location")).To( - Equal("/target-prefix"), - "Do not have any query params", - ) - }) + It("returns only the configured path in the location header", func() { + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) + Expect(rr.Result().Header.Get("Location")).To(Equal("/target")) }) - Context("metrics", func() { - It("should increment the metric with redirect-handler label", func() { - labels := prometheus.Labels{ - "redirect_code": "302", - "redirect_type": "redirect-handler", - } - - beforeCount := promtest.ToFloat64( - handlers.RedirectHandlerRedirectCountMetric.With(labels), - ) - - handler.ServeHTTP( - rw, - httptest.NewRequest( - http.MethodGet, - "https://source.gov.uk/source-prefix", - nil, - ), - ) - - Expect(rw.Result().Header.Get("Location")).To( - Equal("/target-prefix"), - ) - - afterCount := promtest.ToFloat64( - handlers.RedirectHandlerRedirectCountMetric.With(labels), - ) - - Expect(afterCount-beforeCount).To( - Equal(1.0), - "Making a request should increment the redirect handler count metric", - ) - }) + It("still preserves the _ga query parameter as a special case", func() { + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, + "https://source.example.com/source?_ga=dontbeevil", nil)) + Expect(rr.Result().Header.Get("Location")).To(Equal("/target?_ga=dontbeevil")) }) }) + + DescribeTable("responds with the right HTTP status", + EntryDescription("preserve=%t, temporary=%t -> HTTP %d"), + Entry(nil, false, false, http.StatusMovedPermanently), + Entry(nil, false, true, http.StatusFound), + Entry(nil, true, false, http.StatusMovedPermanently), + Entry(nil, true, true, http.StatusFound), + func(preserve, temporary bool, expectedStatus int) { + handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) + Expect(rr.Result().StatusCode).To(Equal(expectedStatus)) + }) + + DescribeTable("increments the redirect-count metric with the right labels", + EntryDescription("preserve=%t, temporary=%t -> {redirect_code=%s,redirect_type=%s}"), + Entry(nil, false, false, "301", "redirect-handler"), + Entry(nil, false, true, "302", "redirect-handler"), + Entry(nil, true, false, "301", "path-preserving-redirect-handler"), + Entry(nil, true, true, "302", "path-preserving-redirect-handler"), + func(preserve, temporary bool, codeLabel, typeLabel string) { + lbls := prometheus.Labels{"redirect_code": codeLabel, "redirect_type": typeLabel} + before := promtest.ToFloat64(handlers.RedirectHandlerRedirectCountMetric.With(lbls)) + + handler = handlers.NewRedirectHandler("/source", "/target", preserve, temporary) + handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, url, nil)) + + after := promtest.ToFloat64(handlers.RedirectHandlerRedirectCountMetric.With(lbls)) + Expect(after - before).To(BeNumerically("~", 1.0)) + }, + ) })