diff --git a/auth/auth_test.go b/auth/auth_test.go index be7b2d985..d66f638e2 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -1231,3 +1231,160 @@ func TestPubKeyContextSuperAdmin(t *testing.T) { }) } + +func TestCypressContexts(t *testing.T) { + tests := []struct { + name string + isFreePass bool + contextKey interface{} + expectedStatus int + expectNextCalled bool + }{ + { + name: "Free Pass Allowed", + isFreePass: true, + contextKey: "", + expectedStatus: http.StatusOK, + expectNextCalled: true, + }, + { + name: "Free Pass Disabled", + isFreePass: false, + contextKey: "", + expectedStatus: http.StatusUnauthorized, + expectNextCalled: false, + }, + { + name: "Empty Context Key", + isFreePass: true, + contextKey: "", + expectedStatus: http.StatusOK, + expectNextCalled: true, + }, + { + name: "Multiple Requests with Free Pass", + isFreePass: true, + contextKey: "", + expectedStatus: http.StatusOK, + expectNextCalled: true, + }, + { + name: "Multiple Requests without Free Pass", + isFreePass: false, + contextKey: "", + expectedStatus: http.StatusUnauthorized, + expectNextCalled: false, + }, + { + name: "Invalid Context Key Type", + isFreePass: true, + contextKey: 12345, + expectedStatus: http.StatusOK, + expectNextCalled: true, + }, + { + name: "Empty Request with Free Pass", + isFreePass: true, + contextKey: "", + expectedStatus: http.StatusOK, + expectNextCalled: true, + }, + { + name: "Empty Request without Free Pass", + isFreePass: false, + contextKey: "", + expectedStatus: http.StatusUnauthorized, + expectNextCalled: false, + }, + { + name: "Null Context with Free Pass", + isFreePass: true, + contextKey: "", + expectedStatus: http.StatusOK, + expectNextCalled: true, + }, + { + name: "Nil Request Context", + isFreePass: true, + contextKey: "testKey", + expectedStatus: http.StatusOK, + expectNextCalled: true, + }, + { + name: "Null Context without Free Pass", + isFreePass: false, + contextKey: "", + expectedStatus: http.StatusUnauthorized, + expectNextCalled: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + config.AdminStrings = "" + if !tt.isFreePass { + config.AdminStrings = "non-empty" + } + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + + handler := CypressContext(next) + handler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + assert.Equal(t, tt.expectNextCalled, nextCalled) + + if !tt.expectNextCalled { + assert.Equal(t, http.StatusText(http.StatusUnauthorized)+"\n", rr.Body.String()) + } + }) + } + + t.Run("Null Request Object", func(t *testing.T) { + config.AdminStrings = "non-empty" + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + handler := CypressContext(next) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, nil) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.False(t, nextCalled) + assert.Equal(t, http.StatusText(http.StatusUnauthorized)+"\n", rr.Body.String()) + }) + + t.Run("Large Number of Requests", func(t *testing.T) { + config.AdminStrings = "" + + nextCalled := 0 + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled++ + w.WriteHeader(http.StatusOK) + }) + + for i := 0; i < 1000; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rr := httptest.NewRecorder() + + handler := CypressContext(next) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + } + + assert.Equal(t, 1000, nextCalled) + }) +}