diff --git a/auth/auth_test.go b/auth/auth_test.go index c5b7bfc84..466716bb2 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -1864,3 +1864,162 @@ func TestVerifyTribeUUID(t *testing.T) { }) } } + +func TestPubKeyContext(t *testing.T) { + config.InitConfig() + InitJwt() + privKey, err := btcec.NewPrivateKey() + assert.NoError(t, err) + expectedPubKeyHex := hex.EncodeToString(privKey.PubKey().SerializeCompressed()) + createValidJWT := func(pubkey string, expireHours int) string { + claims := map[string]interface{}{ + "pubkey": pubkey, + "exp": time.Now().Add(time.Hour * time.Duration(expireHours)).Unix(), + } + _, tokenString, _ := TokenAuth.Encode(claims) + return tokenString + } + createValidTribeToken := func(_ string) string { + timeBuf := make([]byte, 4) + binary.BigEndian.PutUint32(timeBuf, uint32(time.Now().Unix())) + msg := append(signedMsgPrefix, timeBuf...) + digest := chainhash.DoubleHashB(msg) + sig, err := btcecdsa.SignCompact(privKey, digest, true) + assert.NoError(t, err) + token := append(timeBuf, sig...) + return base64.URLEncoding.EncodeToString(token) + } + tests := []struct { + name string + setupToken func(r *http.Request) + expectedStatus int + expectNextCall bool + }{ + { + name: "Valid JWT Token in Header", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", createValidJWT(expectedPubKeyHex, 24)) + }, + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Valid JWT Token in Query", + setupToken: func(r *http.Request) { + r.URL.RawQuery = "token=" + createValidJWT(expectedPubKeyHex, 24) + }, + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Valid Tribe UUID Token in Header", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", createValidTribeToken(expectedPubKeyHex)) + }, + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Valid Tribe UUID Token in Query", + setupToken: func(r *http.Request) { + r.URL.RawQuery = "token=" + createValidTribeToken(expectedPubKeyHex) + }, + expectedStatus: http.StatusOK, + expectNextCall: true, + }, + { + name: "Empty Token", + setupToken: func(r *http.Request) {}, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Expired JWT Token", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", createValidJWT(expectedPubKeyHex, -1)) + }, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Invalid JWT Format", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "invalid.jwt.token") + }, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Invalid Tribe UUID Format", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "invalid-tribe-token") + }, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Malformed Token", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "malformed token") + }, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token with Special Characters", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", "special!@#token") + }, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + { + name: "Token with Whitespace", + setupToken: func(r *http.Request) { + r.Header.Set("x-jwt", " "+createValidJWT(expectedPubKeyHex, 24)+" ") + }, + expectedStatus: http.StatusUnauthorized, + expectNextCall: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nextCalled := false + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + tt.setupToken(req) + rr := httptest.NewRecorder() + handler := PubKeyContext(next) + handler.ServeHTTP(rr, req) + assert.Equal(t, tt.expectedStatus, rr.Code) + assert.Equal(t, tt.expectNextCall, nextCalled) + }) + } + t.Run("Large Number of Requests", func(t *testing.T) { + nextCalled := 0 + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled++ + w.WriteHeader(http.StatusOK) + }) + for i := 0; i < 1000; i++ { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if i%2 == 0 { + req.Header.Set("x-jwt", createValidJWT(expectedPubKeyHex, 24)) + } else { + req.Header.Set("x-jwt", "invalid-token") + } + rr := httptest.NewRecorder() + handler := PubKeyContext(next) + handler.ServeHTTP(rr, req) + if i%2 == 0 { + assert.Equal(t, http.StatusOK, rr.Code) + } else { + assert.Equal(t, http.StatusUnauthorized, rr.Code) + } + } + assert.Equal(t, 500, nextCalled) + }) +}