Skip to content

Commit

Permalink
Merge pull request #2115 from sophieturner0/Unit-Test-ConnectionCodeC…
Browse files Browse the repository at this point in the history
…ontext

[Unit Tests] - ConnectionCodeContext
  • Loading branch information
elraphty authored Dec 5, 2024
2 parents e14d31c + a0cf034 commit 5e13ad9
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 13 deletions.
32 changes: 19 additions & 13 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func PubKeyContext(next http.Handler) http.Handler {

if token == "" {
fmt.Println("[auth] no token")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

Expand All @@ -53,13 +53,13 @@ func PubKeyContext(next http.Handler) http.Handler {

if err != nil {
fmt.Println("Failed to parse JWT")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

if claims.VerifyExpiresAt(time.Now().UnixNano(), true) {
fmt.Println("Token has expired")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

Expand All @@ -73,7 +73,7 @@ func PubKeyContext(next http.Handler) http.Handler {
if err != nil {
fmt.Println(err)
}
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

Expand All @@ -93,7 +93,7 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler {

if token == "" {
fmt.Println("[auth] no token")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

Expand All @@ -103,20 +103,20 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler {

if err != nil {
fmt.Println("Failed to parse JWT")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

if claims.VerifyExpiresAt(time.Now().UnixNano(), true) {
fmt.Println("Token has expired")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

pubkey := fmt.Sprintf("%v", claims["pubkey"])
if !IsFreePass() && !AdminCheck(pubkey) {
fmt.Println("Not a super admin")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

Expand All @@ -130,13 +130,13 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler {
if err != nil {
fmt.Println(err)
}
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

if !IsFreePass() && !AdminCheck(pubkey) {
fmt.Println("Not a super admin : auth")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

Expand All @@ -149,17 +149,23 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler {
// ConnectionContext parses token for connection code
func ConnectionCodeContext(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if r == nil {
http.Error(w, http.StatusText(500), http.StatusInternalServerError)
return
}

token := r.Header.Get("token")

if token == "" {
fmt.Println("[auth] no token")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}

if token != config.Connection_Auth {
fmt.Println("Not a super admin : auth")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), ContextKey, token)
Expand All @@ -175,7 +181,7 @@ func CypressContext(next http.Handler) http.Handler {
next.ServeHTTP(w, r.WithContext(ctx))
} else {
fmt.Println("Endpoint is for testing only : test endpoint")
http.Error(w, http.StatusText(401), 401)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
})
Expand Down
134 changes: 134 additions & 0 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"encoding/hex"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"

Expand Down Expand Up @@ -837,3 +839,135 @@ func TestSign(t *testing.T) {
})
}
}

func TestConnectionCodeContext(t *testing.T) {
config.Connection_Auth = "valid_token"

tests := []struct {
name string
token string
expectedStatus int
expectNextCall bool
}{
{
name: "Valid Token in Header",
token: "valid_token",
expectedStatus: http.StatusOK,
expectNextCall: true,
},
{
name: "Invalid Token in Header",
token: "invalid_token",
expectedStatus: http.StatusUnauthorized,
expectNextCall: false,
},
{
name: "Empty Token in Header",
token: "",
expectedStatus: http.StatusUnauthorized,
expectNextCall: false,
},
{
name: "No Token Header Present",
token: "",
expectedStatus: http.StatusUnauthorized,
expectNextCall: false,
},
{
name: "Malformed Header",
token: "malformed_header",
expectedStatus: http.StatusUnauthorized,
expectNextCall: false,
},
{
name: "Token with Special Characters",
token: "special!@#token",
expectedStatus: http.StatusUnauthorized,
expectNextCall: false,
},
{
name: "Token with Whitespace",
token: " " + config.Connection_Auth + " ",
expectedStatus: http.StatusUnauthorized,
expectNextCall: false,
},
{
name: "Case Sensitivity in Token",
token: strings.ToUpper(config.Connection_Auth),
expectedStatus: http.StatusUnauthorized,
expectNextCall: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})

req := httptest.NewRequest(http.MethodGet, "/", nil)
if tt.token != "" {
req.Header.Set("token", tt.token)
}

rr := httptest.NewRecorder()

handler := ConnectionCodeContext(next)
handler.ServeHTTP(rr, req)

assert.Equal(t, tt.expectedStatus, rr.Code)

assert.Equal(t, tt.expectNextCall, nextCalled)
})
}

t.Run("Null Request Object", func(t *testing.T) {

nextCalled := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
})

handler := ConnectionCodeContext(next)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, nil)

assert.Equal(t, http.StatusInternalServerError, rr.Code)

assert.False(t, nextCalled)
})

t.Run("Large Number of Requests", func(t *testing.T) {

nextCalled := 0
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled++
w.WriteHeader(http.StatusOK)
})

for i := 0; i < 1000; i++ {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if i%2 == 0 {
req.Header.Set("token", "valid_token")
} else {
req.Header.Set("token", "invalid_token")
}

rr := httptest.NewRecorder()
handler := ConnectionCodeContext(next)
handler.ServeHTTP(rr, req)

if i%2 == 0 {
assert.Equal(t, http.StatusOK, rr.Code)
} else {
assert.Equal(t, http.StatusUnauthorized, rr.Code)
}
}

assert.Equal(t, 500, nextCalled)
})
}

0 comments on commit 5e13ad9

Please sign in to comment.