diff --git a/auth/auth.go b/auth/auth.go index 54562889f..ff41839b3 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -86,6 +86,12 @@ func PubKeyContext(next http.Handler) http.Handler { // PubKeyContext parses pukey from signed timestamp func PubKeyContextSuperAdmin(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if r == nil { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + token := r.URL.Query().Get("token") if token == "" { token = r.Header.Get("x-jwt") diff --git a/auth/auth_test.go b/auth/auth_test.go index 8326a8979..be7b2d985 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -3,6 +3,7 @@ package auth import ( "bytes" "encoding/base64" + "encoding/binary" "encoding/hex" "errors" "fmt" @@ -10,6 +11,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" btcec "github.com/btcsuite/btcd/btcec/v2" btcecdsa "github.com/btcsuite/btcd/btcec/v2/ecdsa" @@ -971,3 +973,261 @@ func TestConnectionCodeContext(t *testing.T) { assert.Equal(t, 500, nextCalled) }) } + +func TestPubKeyContextSuperAdmin(t *testing.T) { + + config.InitConfig() + InitJwt() + + privKey, err := btcec.NewPrivateKey() + assert.NoError(t, err) + expectedPubKeyHex := hex.EncodeToString(privKey.PubKey().SerializeCompressed()) + + config.SuperAdmins = []string{expectedPubKeyHex} + config.AdminDevFreePass = "freepass" + originalSuperAdmins := config.SuperAdmins + originalAdminDevFreePass := config.AdminDevFreePass + + createValidJWT := func(pubkey string, expireHours int) string { + claims := map[string]interface{}{ + "pubkey": pubkey, + "exp": time.Now().Add(time.Hour * time.Duration(expireHours)).Unix(), + } + _, tokenString, _ := TokenAuth.Encode(claims) + return tokenString + } + + createValidTribeToken := func(pubkey string) string { + timeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(timeBuf, uint32(time.Now().Unix())) + msg := append(signedMsgPrefix, timeBuf...) + digest := chainhash.DoubleHashB(msg) + sig, err := btcecdsa.SignCompact(privKey, digest, true) + assert.NoError(t, err) + token := append(timeBuf, sig...) + return base64.URLEncoding.EncodeToString(token) + } + + tests := []struct { + name string + setupToken func(r *http.Request) + setupConfig func() + expectedStatus int + expectNextCall bool + }{ + { + name: "Valid JWT Token with Super Admin Privileges", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", createValidJWT(expectedPubKeyHex, 24)) + }, + setupConfig: func() { + config.SuperAdmins = []string{expectedPubKeyHex} + }, + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Valid Tribe UUID Token with Super Admin Privileges", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", createValidTribeToken(expectedPubKeyHex)) + }, + setupConfig: func() { + config.SuperAdmins = []string{expectedPubKeyHex} + }, + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Empty Token in Request", + setupToken: func(r *http.Request) {}, + setupConfig: func() {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Expired JWT Token", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", createValidJWT(expectedPubKeyHex, -1)) + }, + setupConfig: func() {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Invalid JWT Token Format", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "invalid.jwt.token") + }, + setupConfig: func() {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Invalid Tribe UUID Token", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "invalid-tribe-token") + }, + setupConfig: func() {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "JWT Token with Non-Super Admin Pubkey", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", createValidJWT("non-admin-pubkey", 24)) + }, + setupConfig: func() { + config.SuperAdmins = []string{expectedPubKeyHex} + config.AdminDevFreePass = "" + config.AdminStrings = "non-empty" + }, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Tribe UUID Token with Non-Super Admin Pubkey", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "non.admin.tribe.uuid") + }, + setupConfig: func() { + config.SuperAdmins = []string{expectedPubKeyHex} + config.AdminDevFreePass = "" + config.AdminStrings = "non-empty" + }, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token in Both Query and Header", + setupToken: func(r *http.Request) { + r.URL.RawQuery = "token=" + createValidJWT(expectedPubKeyHex, 24) + }, + setupConfig: func() { + config.SuperAdmins = []string{expectedPubKeyHex} + }, + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Free Pass Configuration", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", createValidJWT("any-pubkey", 24)) + }, + setupConfig: func() { + config.SuperAdmins = []string{config.AdminDevFreePass} + }, + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Malformed Token in Header", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "malformed token") + }, + setupConfig: func() {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token with Special Characters", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "special!@#token") + }, + setupConfig: func() {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token with Whitespace", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", " "+createValidJWT(expectedPubKeyHex, 24)+" ") + }, + setupConfig: func() {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Case Sensitivity in Token", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", strings.ToUpper(createValidJWT(expectedPubKeyHex, 24))) + }, + setupConfig: func() {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + config.SuperAdmins = originalSuperAdmins + config.AdminDevFreePass = originalAdminDevFreePass + + if tt.setupConfig != nil { + tt.setupConfig() + } + + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + tt.setupToken(req) + + rr := httptest.NewRecorder() + + handler := PubKeyContextSuperAdmin(next) + handler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + assert.Equal(t, tt.expectNextCall, nextCalled) + }) + } + + t.Run("Null Request Object", func(t *testing.T) { + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + + handler := PubKeyContextSuperAdmin(next) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, nil) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + assert.False(t, nextCalled) + }) + + t.Run("Large Number of Requests", func(t *testing.T) { + nextCalled := 0 + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled++ + w.WriteHeader(http.StatusOK) + }) + + for i := 0; i < 1000; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if i%2 == 0 { + req.Header.Set("x-jwt", createValidJWT(expectedPubKeyHex, 24)) + } else { + req.Header.Set("x-jwt", createValidJWT("non-admin-pubkey", 24)) + } + + rr := httptest.NewRecorder() + handler := PubKeyContextSuperAdmin(next) + handler.ServeHTTP(rr, req) + + if i%2 == 0 { + assert.Equal(t, http.StatusOK, rr.Code) + } else { + assert.Equal(t, http.StatusUnauthorized, rr.Code) + } + } + + assert.Equal(t, 500, nextCalled) + }) + +}