diff --git a/auth/auth_test.go b/auth/auth_test.go index 8326a8979..a443cd641 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -10,10 +10,12 @@ import ( "net/http/httptest" "strings" "testing" + "time" btcec "github.com/btcsuite/btcd/btcec/v2" btcecdsa "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/form3tech-oss/jwt-go" "github.com/stakwork/sphinx-tribes/config" "github.com/stretchr/testify/assert" ) @@ -971,3 +973,185 @@ func TestConnectionCodeContext(t *testing.T) { assert.Equal(t, 500, nextCalled) }) } + +func TestDecodeJwt(t *testing.T) { + config.InitConfig() + InitJwt() + + mockJwtKey := "testsecretkey" + config.JwtKey = mockJwtKey + + createToken := func(claims jwt.MapClaims) string { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, _ := token.SignedString([]byte(mockJwtKey)) + return tokenString + } + + tests := []struct { + name string + token string + expectedClaims jwt.MapClaims + expectedError error + }{ + { + name: "Valid JWT Token", + token: createToken(jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }), + expectedClaims: jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }, + expectedError: nil, + }, + { + name: "Empty Token String", + token: "", + expectedClaims: nil, + expectedError: errors.New("token contains an invalid number of segments"), + }, + { + name: "Token with Only Header and Payload", + token: "header.payload", + expectedClaims: nil, + expectedError: errors.New("token contains an invalid number of segments"), + }, + { + name: "Token with Invalid Signature", + token: func() string { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenString, _ := token.SignedString([]byte("wrongkey")) + return tokenString + }(), + expectedClaims: nil, + expectedError: errors.New("signature is invalid"), + }, + { + name: "Malformed Token", + token: "randomstring", + expectedClaims: nil, + expectedError: errors.New("token contains an invalid number of segments"), + }, + { + name: "Token with Unsupported Algorithm", + token: func() string { + token := jwt.New(jwt.SigningMethodNone) + token.Claims = jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": time.Now().Add(time.Hour).Unix(), + } + tokenString, _ := token.SignedString(jwt.UnsafeAllowNoneSignatureType) + return tokenString + }(), + expectedClaims: nil, + expectedError: errors.New("'none' signature type is not allowed"), + }, + { + name: "Token with Expired Claims", + token: createToken(jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": float64(time.Now().Add(-time.Hour).Unix()), + }), + expectedClaims: nil, + expectedError: errors.New("Token is expired"), + }, + { + name: "Token with Future Not Before (nbf) Claim", + token: createToken(jwt.MapClaims{ + "pubkey": "testpubkey", + "nbf": float64(time.Now().Add(time.Hour).Unix()), + }), + expectedClaims: nil, + expectedError: errors.New("Token is not valid yet"), + }, + { + name: "Token with Non-String Claims", + token: createToken(jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": float64(time.Now().Add(time.Hour).Unix()), + "roles": []interface{}{"admin", "user"}, + }), + expectedClaims: jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": float64(time.Now().Add(time.Hour).Unix()), + "roles": []interface{}{"admin", "user"}, + }, + expectedError: nil, + }, + { + name: "Null Request Object", + token: "null", + expectedClaims: nil, + expectedError: errors.New("token contains an invalid number of segments"), + }, + { + name: "Token with Missing Key", + token: func() string { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": time.Now().Add(time.Hour).Unix(), + }) + tokenString, _ := token.SignedString([]byte("differentkey")) + return tokenString + }(), + expectedClaims: nil, + expectedError: errors.New("signature is invalid"), + }, + { + name: "Token with Additional Unrecognized Claims", + token: createToken(jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": float64(time.Now().Add(time.Hour).Unix()), + "extra": "value", + }), + expectedClaims: jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": float64(time.Now().Add(time.Hour).Unix()), + "extra": "value", + }, + expectedError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims, err := DecodeJwt(tt.token) + + if tt.expectedError != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } else { + assert.NoError(t, err) + if claims != nil && tt.expectedClaims != nil { + assert.Equal(t, tt.expectedClaims["pubkey"], claims["pubkey"]) + if tt.expectedClaims["roles"] != nil { + assert.ElementsMatch(t, tt.expectedClaims["roles"], claims["roles"]) + } + if tt.expectedClaims["extra"] != nil { + assert.Equal(t, tt.expectedClaims["extra"], claims["extra"]) + } + if tt.expectedClaims["data"] != nil { + assert.Equal(t, tt.expectedClaims["data"], claims["data"]) + } + } + } + }) + } + + t.Run("Large Number of Requests", func(t *testing.T) { + validToken := createToken(jwt.MapClaims{ + "pubkey": "testpubkey", + "exp": float64(time.Now().Add(time.Hour).Unix()), + }) + + for i := 0; i < 1000; i++ { + claims, err := DecodeJwt(validToken) + assert.NoError(t, err) + assert.Equal(t, "testpubkey", claims["pubkey"]) + } + }) +}