diff --git a/integration_tests/backend_helpers.go b/integration_tests/backend_helpers.go index fd88d6c6..f5795bef 100644 --- a/integration_tests/backend_helpers.go +++ b/integration_tests/backend_helpers.go @@ -12,13 +12,19 @@ import ( "github.com/onsi/gomega/ghttp" ) -func startSimpleBackend(identifier string) *httptest.Server { +func startDummyBackend(id string, statusCode int) *httptest.Server { return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte(identifier)) + w.Header().Add("Server", id) + w.WriteHeader(statusCode) + _, err := w.Write([]byte(id)) Expect(err).NotTo(HaveOccurred()) })) } +func startSimpleBackend(id string) *httptest.Server { + return startDummyBackend(id, 200) +} + func startTarpitBackend(delays ...time.Duration) *httptest.Server { responseDelay := 2 * time.Second if len(delays) > 0 { @@ -29,18 +35,14 @@ func startTarpitBackend(delays ...time.Duration) *httptest.Server { bodyDelay = delays[1] } return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - body := "Tarpit\n" + const body = "Tarpit\n" - if responseDelay > 0 { - time.Sleep(responseDelay) - } + time.Sleep(responseDelay) w.Header().Add("Content-Length", strconv.Itoa(len(body))) w.WriteHeader(http.StatusOK) w.(http.Flusher).Flush() - if bodyDelay > 0 { - time.Sleep(bodyDelay) - } + time.Sleep(bodyDelay) _, err := w.Write([]byte(body)) Expect(err).NotTo(HaveOccurred()) })) diff --git a/integration_tests/disabled_routes_test.go b/integration_tests/disabled_routes_test.go index 7c1ae267..5da6f413 100644 --- a/integration_tests/disabled_routes_test.go +++ b/integration_tests/disabled_routes_test.go @@ -16,12 +16,12 @@ var _ = Describe("marking routes as disabled", func() { It("should return a 503 to the client", func() { resp := routerRequest(routerPort, "/unavailable") - Expect(resp.StatusCode).To(Equal(503)) + Expect(resp).To(HaveHTTPStatus(503)) }) It("should continue to route other requests", func() { resp := routerRequest(routerPort, "/something-live") - Expect(resp.StatusCode).To(Equal(301)) + Expect(resp).To(HaveHTTPStatus(301)) Expect(resp.Header.Get("Location")).To(Equal("/somewhere-else")) }) }) diff --git a/integration_tests/error_handling_test.go b/integration_tests/error_handling_test.go index 1c8cf534..e9166fe0 100644 --- a/integration_tests/error_handling_test.go +++ b/integration_tests/error_handling_test.go @@ -8,18 +8,18 @@ import ( ) var _ = Describe("error handling", func() { - - Describe("handling an empty routing table", func() { + Describe("when no routes are loaded", func() { BeforeEach(func() { reloadRoutes(apiPort) }) - It("should return a 503 error to the client", func() { + It("should forward to the default backend", func() { resp := routerRequest(routerPort, "/") - Expect(resp.StatusCode).To(Equal(503)) + Expect(resp).To(HaveHTTPHeaderWithValue("Server", "dummy-default-backend")) resp = routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(503)) + Expect(resp).To(HaveHTTPStatus(404)) + Expect(resp).To(HaveHTTPHeaderWithValue("Server", "dummy-default-backend")) }) }) @@ -31,7 +31,7 @@ var _ = Describe("error handling", func() { It("should return a 500 error to the client", func() { resp := routerRequest(routerPort, "/boom") - Expect(resp.StatusCode).To(Equal(500)) + Expect(resp).To(HaveHTTPStatus(500)) }) It("should log the fact", func() { diff --git a/integration_tests/gone_test.go b/integration_tests/gone_test.go index 990606cd..c9cbfdff 100644 --- a/integration_tests/gone_test.go +++ b/integration_tests/gone_test.go @@ -6,7 +6,6 @@ import ( ) var _ = Describe("Gone routes", func() { - BeforeEach(func() { addRoute("/foo", NewGoneRoute()) addRoute("/bar", NewGoneRoute("prefix")) @@ -15,21 +14,17 @@ var _ = Describe("Gone routes", func() { It("should support an exact gone route", func() { resp := routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(410)) - Expect(readBody(resp)).To(Equal("410 Gone\n")) + Expect(resp).To(HaveHTTPStatus(410)) - resp = routerRequest(routerPort, "/foo/bar") - Expect(resp.StatusCode).To(Equal(404)) - Expect(readBody(resp)).To(Equal("404 page not found\n")) + resp = routerRequest(routerPort, "/foo/no-match") + Expect(resp).To(HaveHTTPStatus(404)) }) It("should support a prefix gone route", func() { resp := routerRequest(routerPort, "/bar") - Expect(resp.StatusCode).To(Equal(410)) - Expect(readBody(resp)).To(Equal("410 Gone\n")) + Expect(resp).To(HaveHTTPStatus(410)) resp = routerRequest(routerPort, "/bar/baz") - Expect(resp.StatusCode).To(Equal(410)) - Expect(readBody(resp)).To(Equal("410 Gone\n")) + Expect(resp).To(HaveHTTPStatus(410)) }) }) diff --git a/integration_tests/integration_test.go b/integration_tests/integration_test.go index a5fcbd34..a9a3c3c2 100644 --- a/integration_tests/integration_test.go +++ b/integration_tests/integration_test.go @@ -1,7 +1,6 @@ package integration import ( - "runtime" "testing" . "github.com/onsi/ginkgo/v2" @@ -14,9 +13,7 @@ func TestEverything(t *testing.T) { } var _ = BeforeSuite(func() { - runtime.GOMAXPROCS(runtime.NumCPU()) - var err error - err = setupTempLogfile() + err := setupTempLogfile() if err != nil { Fail(err.Error()) } diff --git a/integration_tests/metrics_test.go b/integration_tests/metrics_test.go index ba4812e9..09a83ed6 100644 --- a/integration_tests/metrics_test.go +++ b/integration_tests/metrics_test.go @@ -11,7 +11,7 @@ var _ = Describe("/metrics API endpoint", func() { BeforeEach(func() { resp := doRequest(newRequest("GET", routerURL(apiPort, "/metrics"))) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) responseBody = readBody(resp) }) diff --git a/integration_tests/proxy_function_test.go b/integration_tests/proxy_function_test.go index e9782347..5334b2c5 100644 --- a/integration_tests/proxy_function_test.go +++ b/integration_tests/proxy_function_test.go @@ -14,11 +14,11 @@ import ( "github.com/onsi/gomega/ghttp" ) -var _ = Describe("Functioning as a reverse proxy", func() { +var _ = Describe("As a reverse proxy", func() { var recorder *ghttp.Server - Describe("connecting to the backend", func() { - It("should return a 502 if the connection to the backend is refused", func() { + Describe("when connecting to the backend", func() { + It("should return 502 on backend connection refused", func() { addBackend("not-running", "http://127.0.0.1:3164/") addRoute("/not-running", NewBackendRoute("not-running")) reloadRoutes(apiPort) @@ -28,7 +28,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { req.Header.Set("X-Varnish", "12345678") resp := doRequest(req) - Expect(resp.StatusCode).To(Equal(502)) + Expect(resp).To(HaveHTTPStatus(502)) logDetails := lastRouterErrorLogEntry() Expect(logDetails.Fields).To(Equal(map[string]interface{}{ @@ -42,7 +42,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { Expect(logDetails.Timestamp).To(BeTemporally("~", time.Now(), time.Second)) }) - It("should log and return a 504 if the connection times out in the configured time", func() { + It("should log and return 504 on backend connection timeout", func() { err := startRouter(3167, 3166, []string{"ROUTER_BACKEND_CONNECT_TIMEOUT=0.3s"}) Expect(err).NotTo(HaveOccurred()) defer stopRouter(3167) @@ -59,8 +59,8 @@ var _ = Describe("Functioning as a reverse proxy", func() { resp := doRequest(req) duration := time.Since(start) - Expect(resp.StatusCode).To(Equal(504)) - Expect(duration).To(BeNumerically("~", 320*time.Millisecond, 20*time.Millisecond)) // 300 - 340 ms + Expect(resp).To(HaveHTTPStatus(504)) + Expect(duration).To(BeNumerically("~", 320*time.Millisecond, 20*time.Millisecond)) logDetails := lastRouterErrorLogEntry() Expect(logDetails.Fields).To(Equal(map[string]interface{}{ @@ -99,7 +99,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { req := newRequest(http.MethodGet, routerURL(3167, "/tarpit1")) req.Header.Set("X-Varnish", "12341112") resp := doRequest(req) - Expect(resp.StatusCode).To(Equal(504)) + Expect(resp).To(HaveHTTPStatus(504)) logDetails := lastRouterErrorLogEntry() tarpitURL, _ := url.Parse(tarpit1.URL) @@ -116,7 +116,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should still return the response if the body takes longer than the header timeout", func() { resp := routerRequest(3167, "/tarpit2") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(readBody(resp)).To(Equal("Tarpit\n")) }) }) @@ -130,16 +130,14 @@ var _ = Describe("Functioning as a reverse proxy", func() { reloadRoutes(apiPort) }) - AfterEach(func() { - recorder.Close() - }) + AfterEach(func() { recorder.Close() }) It("should pass through most http headers to the backend", func() { resp := routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Foo": "bar", "User-Agent": "Router test suite 2.7182", }) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -151,7 +149,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { resp := routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Host": "www.example.com", }) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) recorderURL, err := url.Parse(recorder.URL()) Expect(err).NotTo(HaveOccurred()) @@ -164,7 +162,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should not add a default User-Agent if there isn't one in the request", func() { // Most http libraries add a default User-Agent header. resp := routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -174,7 +172,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should add the client IP to X-Forwarded-For", func() { resp := routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -183,7 +181,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { resp = routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "X-Forwarded-For": "10.9.8.7", }) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(2)) beReq = recorder.ReceivedRequests()[1] @@ -195,7 +193,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should add itself to the Via request header for an HTTP/1.1 request", func() { resp := routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -204,7 +202,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { resp = routerRequestWithHeaders(routerPort, "/foo", map[string]string{ "Via": "1.0 fred, 1.1 barney", }) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(2)) beReq = recorder.ReceivedRequests()[1] @@ -214,7 +212,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should add itself to the Via request header for an HTTP/1.0 request", func() { req := newRequest(http.MethodGet, routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -224,7 +222,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { "Via": "1.0 fred, 1.1 barney", }) resp = doHTTP10Request(req) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(2)) beReq = recorder.ReceivedRequests()[1] @@ -233,20 +231,22 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should add itself to the Via response heaver", func() { resp := routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(resp.Header.Get("Via")).To(Equal("1.1 router")) recorder.AppendHandlers(ghttp.RespondWith(200, "body", http.Header{ "Via": []string{"1.0 fred, 1.1 barney"}, })) resp = routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(resp.Header.Get("Via")).To(Equal("1.0 fred, 1.1 barney, 1.1 router")) }) }) }) - Describe("request verb, path, query and body handling", func() { + Describe("request method, path, query and body handling", func() { + var recorder *ghttp.Server + BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) @@ -254,11 +254,9 @@ var _ = Describe("Functioning as a reverse proxy", func() { reloadRoutes(apiPort) }) - AfterEach(func() { - recorder.Close() - }) + AfterEach(func() { recorder.Close() }) - It("should use the same verb and path when proxying", func() { + It("should use the same HTTP method and path when proxying", func() { recorder.AppendHandlers( ghttp.VerifyRequest("POST", "/foo"), ghttp.VerifyRequest("DELETE", "/foo/bar/baz.json"), @@ -266,21 +264,19 @@ var _ = Describe("Functioning as a reverse proxy", func() { req := newRequest("POST", routerURL(routerPort, "/foo")) resp := doRequest(req) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) req = newRequest("DELETE", routerURL(routerPort, "/foo/bar/baz.json")) resp = doRequest(req) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(2)) }) It("should pass through the query string unmodified", func() { - recorder.AppendHandlers( - ghttp.VerifyRequest("GET", "/foo/bar", "baz=qux"), - ) + recorder.AppendHandlers(ghttp.VerifyRequest("GET", "/foo/bar", "baz=qux")) resp := routerRequest(routerPort, "/foo/bar?baz=qux") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) }) @@ -296,13 +292,15 @@ var _ = Describe("Functioning as a reverse proxy", func() { req := newRequest("POST", routerURL(routerPort, "/foo")) req.Body = io.NopCloser(strings.NewReader("I am the request body. Woohoo!")) resp := doRequest(req) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) }) }) Describe("handling a backend with a non '/' path", func() { + var recorder *ghttp.Server + BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()+"/something") @@ -310,13 +308,11 @@ var _ = Describe("Functioning as a reverse proxy", func() { reloadRoutes(apiPort) }) - AfterEach(func() { - recorder.Close() - }) + AfterEach(func() { recorder.Close() }) It("should merge the 2 paths", func() { resp := routerRequest(routerPort, "/foo/bar") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -325,7 +321,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should preserve the request query string", func() { resp := routerRequest(routerPort, "/foo/bar?baz=qux") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -334,6 +330,8 @@ var _ = Describe("Functioning as a reverse proxy", func() { }) Describe("handling HTTP/1.0 requests", func() { + var recorder *ghttp.Server + BeforeEach(func() { recorder = startRecordingBackend() addBackend("backend", recorder.URL()) @@ -341,14 +339,12 @@ var _ = Describe("Functioning as a reverse proxy", func() { reloadRoutes(apiPort) }) - AfterEach(func() { - recorder.Close() - }) + AfterEach(func() { recorder.Close() }) It("should work with incoming HTTP/1.1 requests", func() { req := newRequest("GET", routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -358,7 +354,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should proxy to the backend as HTTP/1.1 requests", func() { req := newRequest("GET", routerURL(routerPort, "/foo")) resp := doHTTP10Request(req) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] @@ -384,7 +380,7 @@ var _ = Describe("Functioning as a reverse proxy", func() { It("should correctly reverse proxy to a HTTPS backend", func() { req := newRequest("GET", routerURL(3167, "/foo")) resp := doRequest(req) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) beReq := recorder.ReceivedRequests()[0] diff --git a/integration_tests/redirect_test.go b/integration_tests/redirect_test.go index 9bd20637..289e0c10 100644 --- a/integration_tests/redirect_test.go +++ b/integration_tests/redirect_test.go @@ -22,12 +22,12 @@ var _ = Describe("Redirection", func() { It("should redirect permanently by default", func() { resp := routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(301)) + Expect(resp).To(HaveHTTPStatus(301)) }) It("should redirect temporarily when asked to", func() { resp := routerRequest(routerPort, "/foo-temp") - Expect(resp.StatusCode).To(Equal(302)) + Expect(resp).To(HaveHTTPStatus(302)) }) It("should contain the redirect location", func() { @@ -79,13 +79,13 @@ var _ = Describe("Redirection", func() { It("should redirect permanently to the destination", func() { resp := routerRequest(routerPort, "/foo") - Expect(resp.StatusCode).To(Equal(301)) + Expect(resp).To(HaveHTTPStatus(301)) Expect(resp.Header.Get("Location")).To(Equal("/bar")) }) It("should redirect temporarily to the destination when asked to", func() { resp := routerRequest(routerPort, "/foo-temp") - Expect(resp.StatusCode).To(Equal(302)) + Expect(resp).To(HaveHTTPStatus(302)) Expect(resp.Header.Get("Location")).To(Equal("/bar-temp")) }) @@ -127,7 +127,7 @@ var _ = Describe("Redirection", func() { reloadRoutes(apiPort) resp := routerRequest(routerPort, "/foo bar/something") - Expect(resp.StatusCode).To(Equal(301)) + Expect(resp).To(HaveHTTPStatus(301)) Expect(resp.Header.Get("Location")).To(Equal("/bar%20baz/something")) }) }) diff --git a/integration_tests/reload_api_test.go b/integration_tests/reload_api_test.go index 0b52cd17..3947f85a 100644 --- a/integration_tests/reload_api_test.go +++ b/integration_tests/reload_api_test.go @@ -12,23 +12,23 @@ var _ = Describe("reload API endpoint", func() { Describe("request handling", func() { It("should return 202 for POST /reload", func() { resp := doRequest(newRequest("POST", routerURL(apiPort, "/reload"))) - Expect(resp.StatusCode).To(Equal(202)) + Expect(resp).To(HaveHTTPStatus(202)) Expect(readBody(resp)).To(Equal("Reload queued")) }) It("should return 404 for POST /foo", func() { resp := doRequest(newRequest("POST", routerURL(apiPort, "/foo"))) - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) }) It("should return 404 for POST /reload/foo", func() { resp := doRequest(newRequest("POST", routerURL(apiPort, "/reload/foo"))) - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) }) It("should return 405 for GET /reload", func() { resp := doRequest(newRequest("GET", routerURL(apiPort, "/reload"))) - Expect(resp.StatusCode).To(Equal(405)) + Expect(resp).To(HaveHTTPStatus(405)) Expect(resp.Header.Get("Allow")).To(Equal("POST")) }) @@ -50,13 +50,13 @@ var _ = Describe("reload API endpoint", func() { Describe("healthcheck", func() { It("should return HTTP 200 OK on GET", func() { resp := doRequest(newRequest("GET", routerURL(apiPort, "/healthcheck"))) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(readBody(resp)).To(Equal("OK")) }) It("should return HTTP 405 Method Not Allowed on POST", func() { resp := doRequest(newRequest("POST", routerURL(apiPort, "/healthcheck"))) - Expect(resp.StatusCode).To(Equal(405)) + Expect(resp).To(HaveHTTPStatus(405)) Expect(resp.Header.Get("Allow")).To(Equal("GET")) }) }) @@ -69,7 +69,7 @@ var _ = Describe("reload API endpoint", func() { reloadRoutes(apiPort) resp := doRequest(newRequest("GET", routerURL(apiPort, "/memory-stats"))) - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) var data map[string]interface{} readJSONBody(resp, &data) diff --git a/integration_tests/route_loading_test.go b/integration_tests/route_loading_test.go index 827c1deb..417d0cd0 100644 --- a/integration_tests/route_loading_test.go +++ b/integration_tests/route_loading_test.go @@ -36,7 +36,7 @@ var _ = Describe("loading routes from the db", func() { It("should skip the invalid route", func() { resp := routerRequest(routerPort, "/bar") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) }) It("should continue to load other routes", func() { @@ -59,7 +59,7 @@ var _ = Describe("loading routes from the db", func() { It("should skip the invalid route", func() { resp := routerRequest(routerPort, "/bar") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) }) It("should continue to load other routes", func() { @@ -99,7 +99,7 @@ var _ = Describe("loading routes from the db", func() { It("should send requests to the backend_url provided in the env var", func() { resp := routerRequest(routerPort, "/oof") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(readBody(resp)).To(Equal("backend 3")) }) }) diff --git a/integration_tests/route_selection_test.go b/integration_tests/route_selection_test.go index 4b60e44b..579d388b 100644 --- a/integration_tests/route_selection_test.go +++ b/integration_tests/route_selection_test.go @@ -44,18 +44,18 @@ var _ = Describe("Route selection", func() { It("should 404 for children of the exact route", func() { resp := routerRequest(routerPort, "/foo/bar") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) }) It("should 404 for non-matching requests", func() { resp := routerRequest(routerPort, "/wibble") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) resp = routerRequest(routerPort, "/") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) resp = routerRequest(routerPort, "/foo.json") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) }) }) @@ -104,13 +104,13 @@ var _ = Describe("Route selection", func() { It("should 404 for non-matching requests", func() { resp := routerRequest(routerPort, "/wibble") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) resp = routerRequest(routerPort, "/") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) resp = routerRequest(routerPort, "/foo.json") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) }) }) @@ -296,7 +296,7 @@ var _ = Describe("Route selection", func() { Expect(readBody(resp)).To(Equal("other backend")) resp = routerRequest(routerPort, "/bar") - Expect(resp.StatusCode).To(Equal(404)) + Expect(resp).To(HaveHTTPStatus(404)) }) It("should handle a prefix route at the root level", func() { @@ -341,14 +341,14 @@ var _ = Describe("Route selection", func() { It("should not be redirected by our recorder backend", func() { resp := routerRequest(routerPort, "/foo/bar/baz//qux") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].URL.Path).To(Equal("/foo/bar/baz//qux")) }) It("should collapse double slashes when looking up route, but pass request as-is", func() { resp := routerRequest(routerPort, "/foo//bar") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].URL.Path).To(Equal("/foo//bar")) }) @@ -370,7 +370,7 @@ var _ = Describe("Route selection", func() { reloadRoutes(apiPort) resp := routerRequest(routerPort, "/foo bar") - Expect(resp.StatusCode).To(Equal(200)) + Expect(resp).To(HaveHTTPStatus(200)) Expect(recorder.ReceivedRequests()).To(HaveLen(1)) Expect(recorder.ReceivedRequests()[0].RequestURI).To(Equal("/foo%20bar")) }) diff --git a/integration_tests/router_support.go b/integration_tests/router_support.go index d03d526f..46193114 100644 --- a/integration_tests/router_support.go +++ b/integration_tests/router_support.go @@ -36,7 +36,7 @@ func reloadRoutes(port int) { resp, err := http.DefaultClient.Do(req) Expect(err).NotTo(HaveOccurred()) - Expect(resp.StatusCode).To(Equal(202)) + Expect(resp).To(HaveHTTPStatus(202)) resp.Body.Close() // Now that reloading is done asynchronously, we need a small sleep to ensure // it has actually been performed. @@ -50,6 +50,8 @@ func startRouter(port, apiPort int, extraEnv []string) error { pubAddr := net.JoinHostPort(host, strconv.Itoa(port)) apiAddr := net.JoinHostPort(host, strconv.Itoa(apiPort)) + defaultBackend := startDummyBackend("dummy-default-backend", 404) + bin := os.Getenv("BINARY") if bin == "" { bin = "../router" @@ -59,6 +61,7 @@ func startRouter(port, apiPort int, extraEnv []string) error { cmd.Env = append(cmd.Environ(), "ROUTER_MONGO_DB=router_test") cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_PUBADDR=%s", pubAddr)) cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_APIADDR=%s", apiAddr)) + cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_DEFAULT_BACKEND_URL=%s", defaultBackend.URL)) cmd.Env = append(cmd.Env, fmt.Sprintf("ROUTER_ERROR_LOG=%s", tempLogfile.Name())) cmd.Env = append(cmd.Env, extraEnv...) diff --git a/lib/router.go b/lib/router.go index a05ae423..d2621917 100644 --- a/lib/router.go +++ b/lib/router.go @@ -2,6 +2,7 @@ package router import ( "fmt" + "log" "net/http" "net/url" "os" @@ -40,6 +41,7 @@ const ( // MongoReplicaSet, MongoReplicaSetMember etc. should move out of this module. type Router struct { mux *triemux.Mux + defaultHandler http.Handler lock sync.RWMutex mongoReadToOptime bson.MongoTimestamp logger logger.Logger @@ -92,7 +94,7 @@ func RegisterMetrics(r prometheus.Registerer) { // NewRouter returns a new empty router instance. You will need to call // SelfUpdateRoutes() to initialise the self-update process for routes. -func NewRouter(o Options) (rt *Router, err error) { +func NewRouter(defaultBackend *url.URL, o Options) (rt *Router, err error) { logInfo("router: using mongo poll interval:", o.MongoPollInterval) logInfo("router: using backend connect timeout:", o.BackendConnTimeout) logInfo("router: using backend header timeout:", o.BackendHeaderTimeout) @@ -108,9 +110,22 @@ func NewRouter(o Options) (rt *Router, err error) { return nil, err } + defaultHandler := handlers.NewBackendHandler( + "default", + defaultBackend, + o.BackendConnTimeout, + o.BackendHeaderTimeout, + l, + ) + if err != nil { + log.Fatal(err) + } + logDebug("defaultHandler:", defaultHandler) + reloadChan := make(chan bool, 1) rt = &Router{ - mux: triemux.NewMux(), + mux: triemux.NewMux(defaultHandler), + defaultHandler: defaultHandler, mongoReadToOptime: mongoReadToOptime, logger: l, opts: o, @@ -235,7 +250,7 @@ func (rt *Router) reloadRoutes(db *mgo.Database, currentOptime bson.MongoTimesta }() logInfo("router: reloading routes") - newmux := triemux.NewMux() + newmux := triemux.NewMux(rt.defaultHandler) backends := rt.loadBackends(db.C("backends")) loadRoutes(db.C("routes"), newmux, backends) diff --git a/main.go b/main.go index f2d878e3..4b8932f8 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "net/http" + "net/url" "os" "runtime" "time" @@ -26,6 +27,7 @@ ROUTER_APIADDR=:8081 Address on which to receive reload requests ROUTER_MONGO_URL=127.0.0.1 Address of mongo cluster (e.g. 'mongo1,mongo2,mongo3') ROUTER_MONGO_DB=router Name of mongo database to use ROUTER_MONGO_POLL_INTERVAL=2s Interval to poll mongo for route changes +ROUTER_DEFAULT_BACKEND_URL Where to forward requests that don't match any route ROUTER_ERROR_LOG=STDERR File to log errors to (in JSON format) ROUTER_DEBUG= Enable debug output if non-empty @@ -96,6 +98,7 @@ func main() { beHeaderTimeout = getenvDuration("ROUTER_BACKEND_HEADER_TIMEOUT", "20s") feReadTimeout = getenvDuration("ROUTER_FRONTEND_READ_TIMEOUT", "60s") feWriteTimeout = getenvDuration("ROUTER_FRONTEND_WRITE_TIMEOUT", "60s") + defaultBackend = getenv("ROUTER_DEFAULT_BACKEND_URL", "http://government-frontend") ) log.Printf("using frontend read timeout: %v", feReadTimeout) @@ -110,7 +113,12 @@ func main() { router.RegisterMetrics(prometheus.DefaultRegisterer) - rout, err := router.NewRouter(router.Options{ + d, err := url.Parse(defaultBackend) + if err != nil { + log.Fatal(err) + } + + rout, err := router.NewRouter(d, router.Options{ MongoURL: mongoURL, MongoDBName: mongoDBName, MongoPollInterval: mongoPollInterval, diff --git a/triemux/metrics.go b/triemux/metrics.go index 187e7177..de415f40 100644 --- a/triemux/metrics.go +++ b/triemux/metrics.go @@ -5,13 +5,6 @@ import ( ) var ( - entryNotFoundCountMetric = prometheus.NewCounter( - prometheus.CounterOpts{ - Name: "router_triemux_entry_not_found_total", - Help: "Number of route lookups for which no route was found", - }, - ) - internalServiceUnavailableCountMetric = prometheus.NewCounter( prometheus.CounterOpts{ Name: "router_service_unavailable_error_total", @@ -22,7 +15,6 @@ var ( func RegisterMetrics(r prometheus.Registerer) { r.MustRegister( - entryNotFoundCountMetric, internalServiceUnavailableCountMetric, ) } diff --git a/triemux/mux.go b/triemux/mux.go index 57e09d18..ce9912fd 100644 --- a/triemux/mux.go +++ b/triemux/mux.go @@ -10,12 +10,12 @@ import ( "sync" "github.com/alphagov/router/handlers" - "github.com/alphagov/router/logger" "github.com/alphagov/router/trie" ) type Mux struct { mu sync.RWMutex + defaultHandler http.Handler exactTrie *trie.Trie[http.Handler] prefixTrie *trie.Trie[http.Handler] count int @@ -23,40 +23,25 @@ type Mux struct { } // NewMux makes a new empty Mux. -func NewMux() *Mux { +func NewMux(defaultHandler http.Handler) *Mux { return &Mux{ + defaultHandler: defaultHandler, exactTrie: trie.NewTrie[http.Handler](), prefixTrie: trie.NewTrie[http.Handler](), downcaser: handlers.NewDowncaseRedirectHandler(), } } -// ServeHTTP forwards the request to a backend with a registered route matching -// the request path. Serves 404 when there is no backend. Serves 301 redirect -// to lowercase path when the URL path is entirely uppercase. Serves 503 when -// no routes are loaded. +// ServeHTTP forwards the request to the backend based on the longest-matching +// URL path prefix, or to the default backend if there is no match. Serves a +// 301 redirect to lowercase path when the URL path is entirely uppercase. func (mux *Mux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if mux.count == 0 { - w.WriteHeader(http.StatusServiceUnavailable) - logger.NotifySentry(logger.ReportableError{ - Error: logger.RecoveredError{ErrorMessage: "route table is empty"}, - Request: r, - }) - internalServiceUnavailableCountMetric.Inc() - return - } - if shouldRedirToLowercasePath(r.URL.Path) { mux.downcaser.ServeHTTP(w, r) return } - handler, ok := mux.lookup(r.URL.Path) - if !ok { - http.NotFound(w, r) - return - } - handler.ServeHTTP(w, r) + mux.lookup(r.URL.Path).ServeHTTP(w, r) } // shouldRedirToLowercasePath takes a URL path string (such as "/government/guidance") @@ -71,19 +56,19 @@ func shouldRedirToLowercasePath(path string) (match bool) { } // lookup finds a URL path in the Mux and returns the corresponding handler. -func (mux *Mux) lookup(path string) (handler http.Handler, ok bool) { +func (mux *Mux) lookup(path string) http.Handler { mux.mu.RLock() defer mux.mu.RUnlock() pathSegments := splitPath(path) - if handler, ok = mux.exactTrie.Get(pathSegments); !ok { + handler, ok := mux.exactTrie.Get(pathSegments) + if !ok { handler, ok = mux.prefixTrie.GetLongestPrefix(pathSegments) } if !ok { - entryNotFoundCountMetric.Inc() - return nil, false + return mux.defaultHandler } - return + return handler } // Handle adds a route (either an exact path or a path prefix) to the Mux and diff --git a/triemux/mux_test.go b/triemux/mux_test.go index 8da612b9..1561fe3f 100644 --- a/triemux/mux_test.go +++ b/triemux/mux_test.go @@ -6,8 +6,6 @@ import ( "os" "strings" "testing" - - promtest "github.com/prometheus/client_golang/prometheus/testutil" ) func TestSplitPath(t *testing.T) { @@ -172,48 +170,19 @@ var lookupExamples = []LookupExample{ } func TestLookup(t *testing.T) { - beforeCount := promtest.ToFloat64(entryNotFoundCountMetric) - for _, ex := range lookupExamples { testLookup(t, ex) } - - afterCount := promtest.ToFloat64(entryNotFoundCountMetric) - notFoundCount := afterCount - beforeCount - - var expectedNotFoundCount int - - for _, ex := range lookupExamples { - for _, c := range ex.checks { - if !c.ok { - expectedNotFoundCount++ - } - } - } - - if expectedNotFoundCount == 0 { - t.Errorf("expectedNotFoundCount should not be zero") - } - - if notFoundCount != float64(expectedNotFoundCount) { - t.Errorf( - "Expected notFoundCount (%f) ok to be %f", - notFoundCount, float64(expectedNotFoundCount), - ) - } } func testLookup(t *testing.T, ex LookupExample) { - mux := NewMux() + mux := NewMux(nil) for _, r := range ex.registrations { t.Logf("Register(path:%v, prefix:%v, handler:%v)", r.path, r.prefix, r.handler) mux.Handle(r.path, r.prefix, r.handler) } for _, c := range ex.checks { - handler, ok := mux.lookup(c.path) - if ok != c.ok { - t.Errorf("Expected lookup(%v) ok to be %v, was %v", c.path, c.ok, ok) - } + handler := mux.lookup(c.path) if handler != c.handler { t.Errorf("Expected lookup(%v) to map to handler %v, was %v", c.path, c.handler, handler) } @@ -227,7 +196,7 @@ var statsExample = []Registration{ } func TestRouteCount(t *testing.T) { - mux := NewMux() + mux := NewMux(nil) for _, reg := range statsExample { mux.Handle(reg.path, reg.prefix, reg.handler) } @@ -248,7 +217,7 @@ func loadStrings(filename string) []string { func benchSetup() *Mux { routes := loadStrings("testdata/routes") - tm := NewMux() + tm := NewMux(nil) tm.Handle("/government", true, a) for _, l := range routes { @@ -257,8 +226,7 @@ func benchSetup() *Mux { return tm } -// Test behaviour looking up extant urls -func BenchmarkLookup(b *testing.B) { +func BenchmarkLookupFound(b *testing.B) { b.StopTimer() tm := benchSetup() urls := loadStrings("testdata/urls") @@ -270,8 +238,7 @@ func BenchmarkLookup(b *testing.B) { } } -// Test behaviour when looking up nonexistent urls -func BenchmarkLookupBogus(b *testing.B) { +func BenchmarkLookupNotFound(b *testing.B) { b.StopTimer() tm := benchSetup() urls := loadStrings("testdata/bogus") @@ -283,9 +250,7 @@ func BenchmarkLookupBogus(b *testing.B) { } } -// Test worst-case lookup behaviour (see comment in findlongestmatch for -// details) -func BenchmarkLookupMalicious(b *testing.B) { +func BenchmarkLookupWorstCase(b *testing.B) { b.StopTimer() tm := benchSetup() b.StartTimer()