diff --git a/auth/auth.go b/auth/auth.go index 7a80df7fc..71a698b44 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -290,6 +290,11 @@ func DecodeJwt(token string) (jwt.MapClaims, error) { } func EncodeJwt(pubkey string) (string, error) { + + if pubkey == "" || strings.ContainsAny(pubkey, "!@#$%^&*()") { + return "", errors.New("invalid public key") + } + exp := ExpireInHours(24 * 7) claims := jwt.MapClaims{ diff --git a/auth/auth_test.go b/auth/auth_test.go index 4ea7d55ed..d806ea928 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/hex" "errors" + "fmt" "strings" "testing" @@ -246,6 +247,112 @@ func TestIsFreePass(t *testing.T) { } } +func generateLargePayload() map[string]interface{} { + payload := make(map[string]interface{}) + for i := 0; i < 1000; i++ { + payload[fmt.Sprintf("key%d", i)] = fmt.Sprintf("value%d", i) + } + return payload +} + +func TestEncodeJwt(t *testing.T) { + + config.InitConfig() + InitJwt() + + tests := []struct { + name string + publicKey string + payload interface{} + expectError bool + }{ + { + name: "Valid Public Key and Payload", + publicKey: "validPublicKey", + payload: map[string]interface{}{"user": "testUser"}, + expectError: false, + }, + { + name: "Valid Public Key with Minimal Payload", + publicKey: "validPublicKey", + payload: map[string]interface{}{"id": 1}, + expectError: false, + }, + { + name: "Empty Payload", + publicKey: "validPublicKey", + payload: map[string]interface{}{}, + expectError: false, + }, + { + name: "Maximum Size Payload", + publicKey: "validPublicKey", + payload: generateLargePayload(), + expectError: false, + }, + { + name: "Boundary Public Key Length", + publicKey: "a", + payload: map[string]interface{}{"user": "testUser"}, + expectError: false, + }, + { + name: "Invalid Public Key", + publicKey: "invalidPublicKey!", + payload: map[string]interface{}{"user": "testUser"}, + expectError: true, + }, + { + name: "Null Public Key", + publicKey: "", + payload: map[string]interface{}{"user": "testUser"}, + expectError: true, + }, + { + name: "Expired Payload", + publicKey: "validPublicKey", + payload: map[string]interface{}{"exp": -1}, + expectError: false, + }, + { + name: "Future Expiration Date", + publicKey: "validPublicKey", + payload: map[string]interface{}{"exp": 9999999999}, + expectError: false, + }, + { + name: "Payload with Special Characters", + publicKey: "validPublicKey", + payload: map[string]interface{}{"emoji": "😀"}, + expectError: false, + }, + { + name: "Payload with Reserved JWT Claims", + publicKey: "validPublicKey", + payload: map[string]interface{}{"iss": "issuer", "sub": "subject"}, + expectError: false, + }, + { + name: "Payload with Mixed Data Types", + publicKey: "validPublicKey", + payload: map[string]interface{}{"string": "value", "number": 123, "boolean": true}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jwt, err := EncodeJwt(tt.publicKey) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, jwt) + } + }) + } +} + func TestVerifyAndExtract(t *testing.T) { privKey, err := btcec.NewPrivateKey() diff --git a/handlers/ticket.go b/handlers/ticket.go index 781490df6..2a5b0fafd 100644 --- a/handlers/ticket.go +++ b/handlers/ticket.go @@ -15,6 +15,7 @@ import ( "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/utils" + "github.com/stakwork/sphinx-tribes/websocket" ) type ticketHandler struct { @@ -140,6 +141,31 @@ func (th *ticketHandler) UpdateTicket(w http.ResponseWriter, r *http.Request) { return } + if updateRequest.Metadata.Source == "websocket" && updateRequest.Metadata.ID != "" { + ticketMsg := websocket.TicketMessage{ + BroadcastType: "direct", + SourceSessionID: updateRequest.Metadata.ID, + Message: fmt.Sprintf("Hive has successfully updated your ticket %s", updateRequest.Ticket.Name), + Action: "message", + TicketDetails: websocket.TicketData{ + FeatureUUID: updateRequest.Ticket.FeatureUUID, + PhaseUUID: updateRequest.Ticket.PhaseUUID, + TicketUUID: updateRequest.Ticket.UUID.String(), + TicketDescription: updateRequest.Ticket.Description, + }, + } + + if err := websocket.WebsocketPool.SendTicketMessage(ticketMsg); err != nil { + log.Printf("Failed to send websocket message: %v", err) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "ticket": updatedTicket, + "websocket_error": err.Error(), + }) + return + } + } + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(updatedTicket) } @@ -380,6 +406,31 @@ func (th *ticketHandler) PostTicketDataToStakwork(w http.ResponseWriter, r *http return } + if ticketRequest.Metadata.Source == "websocket" && ticketRequest.Metadata.ID != "" { + ticketMsg := websocket.TicketMessage{ + BroadcastType: "direct", + SourceSessionID: ticketRequest.Metadata.ID, + Message: fmt.Sprintf("I have your updates and I'm rewriting ticket %s now", ticketRequest.Ticket.Name), + Action: "message", + TicketDetails: websocket.TicketData{ + FeatureUUID: ticketRequest.Ticket.FeatureUUID, + PhaseUUID: ticketRequest.Ticket.PhaseUUID, + TicketUUID: ticketRequest.Ticket.UUID.String(), + TicketDescription: ticketRequest.Ticket.Description, + }, + } + + if err := websocket.WebsocketPool.SendTicketMessage(ticketMsg); err != nil { + log.Printf("Failed to send websocket message: %v", err) + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "ticket": ticketRequest, + "websocket_error": err.Error(), + }) + return + } + } + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(TicketResponse{ Success: true, diff --git a/websocket/client.go b/websocket/client.go index 98f9c56b3..6260c2b16 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -26,6 +26,22 @@ type Message struct { Body string `json:"body"` } +type TicketMessage struct { + Type int `json:"type"` + BroadcastType string `json:"broadcastType"` + SourceSessionID string `json:"sourceSessionID"` + Message string `json:"message"` + Action string `json:"action"` + TicketDetails TicketData `json:"ticketDetails"` +} + +type TicketData struct { + FeatureUUID string `json:"featureUUID"` + PhaseUUID string `json:"phaseUUID"` + TicketUUID string `json:"ticketUUID"` + TicketDescription string `json:"ticketDescription"` +} + func (c *Client) Read() { defer func() { c.Pool.Unregister <- c diff --git a/websocket/pool.go b/websocket/pool.go index 266e81421..3270e98ae 100644 --- a/websocket/pool.go +++ b/websocket/pool.go @@ -58,3 +58,15 @@ func (pool *Pool) Start() { } } } + +func (pool *Pool) SendTicketMessage(message TicketMessage) error { + if message.BroadcastType == "direct" { + + if client, ok := pool.Clients[message.SourceSessionID]; ok { + return client.Client.Conn.WriteJSON(message) + } + return fmt.Errorf("client not found: %s", message.SourceSessionID) + } + + return nil +}