From a0cf034ae165153124cf50cbf901858b2e1b857b Mon Sep 17 00:00:00 2001 From: elraphty Date: Thu, 5 Dec 2024 15:21:39 +0100 Subject: [PATCH] re: fix merged conflicts, and change error codes --- auth/auth.go | 32 ++++++----- auth/auth_test.go | 134 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 13 deletions(-) diff --git a/auth/auth.go b/auth/auth.go index 71a698b44..54562889f 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -42,7 +42,7 @@ func PubKeyContext(next http.Handler) http.Handler { if token == "" { fmt.Println("[auth] no token") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } @@ -53,13 +53,13 @@ func PubKeyContext(next http.Handler) http.Handler { if err != nil { fmt.Println("Failed to parse JWT") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } if claims.VerifyExpiresAt(time.Now().UnixNano(), true) { fmt.Println("Token has expired") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } @@ -73,7 +73,7 @@ func PubKeyContext(next http.Handler) http.Handler { if err != nil { fmt.Println(err) } - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } @@ -93,7 +93,7 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { if token == "" { fmt.Println("[auth] no token") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } @@ -103,20 +103,20 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { if err != nil { fmt.Println("Failed to parse JWT") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } if claims.VerifyExpiresAt(time.Now().UnixNano(), true) { fmt.Println("Token has expired") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } pubkey := fmt.Sprintf("%v", claims["pubkey"]) if !IsFreePass() && !AdminCheck(pubkey) { fmt.Println("Not a super admin") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } @@ -130,13 +130,13 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { if err != nil { fmt.Println(err) } - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } if !IsFreePass() && !AdminCheck(pubkey) { fmt.Println("Not a super admin : auth") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } @@ -149,17 +149,23 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { // ConnectionContext parses token for connection code func ConnectionCodeContext(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if r == nil { + http.Error(w, http.StatusText(500), http.StatusInternalServerError) + return + } + token := r.Header.Get("token") if token == "" { fmt.Println("[auth] no token") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } if token != config.Connection_Auth { fmt.Println("Not a super admin : auth") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } ctx := context.WithValue(r.Context(), ContextKey, token) @@ -175,7 +181,7 @@ func CypressContext(next http.Handler) http.Handler { next.ServeHTTP(w, r.WithContext(ctx)) } else { fmt.Println("Endpoint is for testing only : test endpoint") - http.Error(w, http.StatusText(401), 401) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } }) diff --git a/auth/auth_test.go b/auth/auth_test.go index 6c763af7e..8326a8979 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -6,6 +6,8 @@ import ( "encoding/hex" "errors" "fmt" + "net/http" + "net/http/httptest" "strings" "testing" @@ -837,3 +839,135 @@ func TestSign(t *testing.T) { }) } } + +func TestConnectionCodeContext(t *testing.T) { + config.Connection_Auth = "valid_token" + + tests := []struct { + name string + token string + expectedStatus int + expectNextCall bool + }{ + { + name: "Valid Token in Header", + token: "valid_token", + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Invalid Token in Header", + token: "invalid_token", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Empty Token in Header", + token: "", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "No Token Header Present", + token: "", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Malformed Header", + token: "malformed_header", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token with Special Characters", + token: "special!@#token", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token with Whitespace", + token: " " + config.Connection_Auth + " ", + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Case Sensitivity in Token", + token: strings.ToUpper(config.Connection_Auth), + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tt.token != "" { + req.Header.Set("token", tt.token) + } + + rr := httptest.NewRecorder() + + handler := ConnectionCodeContext(next) + handler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + + assert.Equal(t, tt.expectNextCall, nextCalled) + }) + } + + t.Run("Null Request Object", func(t *testing.T) { + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + handler := ConnectionCodeContext(next) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, nil) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + + assert.False(t, nextCalled) + }) + + t.Run("Large Number of Requests", func(t *testing.T) { + + nextCalled := 0 + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled++ + w.WriteHeader(http.StatusOK) + }) + + for i := 0; i < 1000; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if i%2 == 0 { + req.Header.Set("token", "valid_token") + } else { + req.Header.Set("token", "invalid_token") + } + + rr := httptest.NewRecorder() + handler := ConnectionCodeContext(next) + handler.ServeHTTP(rr, req) + + if i%2 == 0 { + assert.Equal(t, http.StatusOK, rr.Code) + } else { + assert.Equal(t, http.StatusUnauthorized, rr.Code) + } + } + + assert.Equal(t, 500, nextCalled) + }) +}