diff --git a/handlers/bounty.go b/handlers/bounty.go index 885391c93..f6175806e 100644 --- a/handlers/bounty.go +++ b/handlers/bounty.go @@ -21,14 +21,16 @@ import ( type bountyHandler struct { httpClient HttpClient db db.Database + getSocketConnections func(host string) (db.Client, error) generateBountyResponse func(bounties []db.Bounty) []db.BountyResponse } -func NewBountyHandler(httpClient HttpClient, db db.Database) *bountyHandler { +func NewBountyHandler(httpClient HttpClient, database db.Database) *bountyHandler { return &bountyHandler{ - httpClient: httpClient, - db: db, + httpClient: httpClient, + db: database, + getSocketConnections: db.Store.GetSocketConnections, } } @@ -441,6 +443,7 @@ func (h *bountyHandler) MakeBountyPayment(w http.ResponseWriter, r *http.Request if bounty.Paid { w.WriteHeader(http.StatusMethodNotAllowed) json.NewEncoder(w).Encode("Bounty has already been paid") + return } // check if user is the admin of the organization @@ -479,11 +482,10 @@ func (h *bountyHandler) MakeBountyPayment(w http.ResponseWriter, r *http.Request jsonBody := []byte(bodyData) - client := &http.Client{} req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) req.Header.Set("x-user-token", config.RelayAuthKey) req.Header.Set("Content-Type", "application/json") - res, err := client.Do(req) + res, err := h.httpClient.Do(req) if err != nil { log.Printf("Request Failed: %s", err) @@ -513,17 +515,17 @@ func (h *bountyHandler) MakeBountyPayment(w http.ResponseWriter, r *http.Request Status: true, PaymentType: "payment", } - db.DB.AddPaymentHistory(paymentHistory) + h.db.AddPaymentHistory(paymentHistory) bounty.Paid = true bounty.PaidDate = &now bounty.CompletionDate = &now - db.DB.UpdateBounty(bounty) + h.db.UpdateBounty(bounty) msg["msg"] = "keysend_success" msg["invoice"] = "" - socket, err := db.Store.GetSocketConnections(request.Websocket_token) + socket, err := h.getSocketConnections(request.Websocket_token) if err == nil { socket.Conn.WriteJSON(msg) } @@ -531,7 +533,7 @@ func (h *bountyHandler) MakeBountyPayment(w http.ResponseWriter, r *http.Request msg["msg"] = "keysend_error" msg["invoice"] = "" - socket, err := db.Store.GetSocketConnections(request.Websocket_token) + socket, err := h.getSocketConnections(request.Websocket_token) if err == nil { socket.Conn.WriteJSON(msg) } diff --git a/handlers/bounty_test.go b/handlers/bounty_test.go index 410deb099..19277a34f 100644 --- a/handlers/bounty_test.go +++ b/handlers/bounty_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/stakwork/sphinx-tribes/utils" "io" "net/http" "net/http/httptest" @@ -16,6 +15,9 @@ import ( "testing" "time" + "github.com/gorilla/websocket" + "github.com/stakwork/sphinx-tribes/utils" + "github.com/go-chi/chi" "github.com/lib/pq" "github.com/stakwork/sphinx-tribes/auth" @@ -1122,10 +1124,45 @@ func TestGetAllBounties(t *testing.T) { }) } +func MockNewWSServer(t *testing.T) (*httptest.Server, *websocket.Conn) { + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var upgrader = websocket.Upgrader{} + + upgrader.CheckOrigin = func(r *http.Request) bool { return true } + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Println("upgrade:", err) + return + } + defer ws.Close() + })) + wsURL := "ws" + strings.TrimPrefix(s.URL, "http") + + ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatal(err) + } + + return s, ws +} + func TestMakeBountyPayment(t *testing.T) { ctx := context.Background() - mockDb := dbMocks.NewDatabase(t) - mockHttpClient := mocks.NewHttpClient(t) + mockDb := &dbMocks.Database{} + mockHttpClient := &mocks.HttpClient{} + mockGetSocketConnections := func(host string) (db.Client, error) { + s, ws := MockNewWSServer(t) + defer s.Close() + defer ws.Close() + + mockClient := db.Client{ + Host: "mocked_host", + Conn: ws, + } + + return mockClient, nil + } bHandler := NewBountyHandler(mockHttpClient, mockDb) unauthorizedCtx := context.WithValue(ctx, auth.ContextKey, "") @@ -1134,6 +1171,14 @@ func TestMakeBountyPayment(t *testing.T) { var mutex sync.Mutex var processingTimes []time.Time + bountyID := uint(1) + bounty := db.Bounty{ + ID: bountyID, + OrgUuid: "org-1", + Assignee: "assignee-1", + Price: uint(1000), + } + t.Run("mutex lock ensures sequential access", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mutex.Lock() @@ -1183,6 +1228,7 @@ func TestMakeBountyPayment(t *testing.T) { }) t.Run("405 when trying to pay an already-paid bounty", func(t *testing.T) { + mockDb.ExpectedCalls = nil mockDb.On("GetBounty", mock.AnythingOfType("uint")).Return(db.Bounty{ ID: 1, Price: 1000, @@ -1191,12 +1237,6 @@ func TestMakeBountyPayment(t *testing.T) { Paid: true, }, nil) - mockDb.On("UserHasAccess", "valid-key", "org-1", db.PayBounty).Return(true) - mockDb.On("GetOrganizationBudget", "org-1").Return(db.BountyBudget{TotalBudget: 1000}, nil) - mockDb.On("GetPersonByPubkey", "assignee-1").Return(db.Person{ - OwnerPubKey: "assignee-1", - }, nil) - r := chi.NewRouter() r.Post("/gobounties/pay/{id}", bHandler.MakeBountyPayment) @@ -1270,6 +1310,103 @@ func TestMakeBountyPayment(t *testing.T) { }) + t.Run("Should test that a successful WebSocket message is sent if the payment is successful", func(t *testing.T) { + mockDb.ExpectedCalls = nil + bHandler.getSocketConnections = mockGetSocketConnections + + now := time.Now() + expectedBounty := db.Bounty{ + ID: bountyID, + OrgUuid: "org-1", + Assignee: "assignee-1", + Price: uint(1000), + Paid: true, + PaidDate: &now, + CompletionDate: &now, + } + + mockDb.On("GetBounty", bountyID).Return(bounty, nil) + mockDb.On("UserHasAccess", "valid-key", bounty.OrgUuid, db.PayBounty).Return(true) + mockDb.On("GetOrganizationBudget", bounty.OrgUuid).Return(db.BountyBudget{TotalBudget: 2000}, nil) + mockDb.On("GetPersonByPubkey", bounty.Assignee).Return(db.Person{OwnerPubKey: "assignee-1", OwnerRouteHint: "OwnerRouteHint"}, nil) + mockDb.On("AddPaymentHistory", mock.AnythingOfType("db.PaymentHistory")).Return(db.PaymentHistory{ID: 1}) + mockDb.On("UpdateBounty", mock.AnythingOfType("db.Bounty")).Run(func(args mock.Arguments) { + updatedBounty := args.Get(0).(db.Bounty) + assert.True(t, updatedBounty.Paid) + assert.NotNil(t, updatedBounty.PaidDate) + assert.NotNil(t, updatedBounty.CompletionDate) + }).Return(expectedBounty, nil).Once() + + expectedUrl := fmt.Sprintf("%s/payment", config.RelayUrl) + expectedBody := `{"amount": 1000, "destination_key": "assignee-1", "route_hint": "OwnerRouteHint", "text": "memotext added for notification"}` + + r := io.NopCloser(bytes.NewReader([]byte(`{"success": true, "response": { "sumAmount": "1"}}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + bodyByt, _ := io.ReadAll(req.Body) + return req.Method == http.MethodPost && expectedUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey && expectedBody == string(bodyByt) + })).Return(&http.Response{ + StatusCode: 200, + Body: r, + }, nil).Once() + + ro := chi.NewRouter() + ro.Post("/gobounties/pay/{id}", bHandler.MakeBountyPayment) + + requestBody := bytes.NewBuffer([]byte("{}")) + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/gobounties/pay/1", requestBody) + if err != nil { + t.Fatal(err) + } + + ro.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + mockDb.AssertExpectations(t) + mockHttpClient.AssertExpectations(t) + }) + + t.Run("Should test that an error WebSocket message is sent if the payment fails", func(t *testing.T) { + mockDb2 := &dbMocks.Database{} + mockHttpClient2 := &mocks.HttpClient{} + mockDb2.ExpectedCalls = nil + + bHandler2 := NewBountyHandler(mockHttpClient2, mockDb2) + bHandler2.getSocketConnections = mockGetSocketConnections + + mockDb2.On("GetBounty", bountyID).Return(bounty, nil) + mockDb2.On("UserHasAccess", "valid-key", bounty.OrgUuid, db.PayBounty).Return(true) + mockDb2.On("GetOrganizationBudget", bounty.OrgUuid).Return(db.BountyBudget{TotalBudget: 2000}, nil) + mockDb2.On("GetPersonByPubkey", bounty.Assignee).Return(db.Person{OwnerPubKey: "assignee-1", OwnerRouteHint: "OwnerRouteHint"}, nil) + + expectedUrl := fmt.Sprintf("%s/payment", config.RelayUrl) + expectedBody := `{"amount": 1000, "destination_key": "assignee-1", "route_hint": "OwnerRouteHint", "text": "memotext added for notification"}` + + r := io.NopCloser(bytes.NewReader([]byte(`"internal server error"`))) + mockHttpClient2.On("Do", mock.MatchedBy(func(req *http.Request) bool { + bodyByt, _ := io.ReadAll(req.Body) + return req.Method == http.MethodPost && expectedUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey && expectedBody == string(bodyByt) + })).Return(&http.Response{ + StatusCode: 500, + Body: r, + }, nil).Once() + + ro := chi.NewRouter() + ro.Post("/gobounties/pay/{id}", bHandler2.MakeBountyPayment) + + requestBody := bytes.NewBuffer([]byte("{}")) + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/gobounties/pay/1", requestBody) + if err != nil { + t.Fatal(err) + } + + ro.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + mockDb2.AssertExpectations(t) + mockHttpClient2.AssertExpectations(t) + }) } func TestBountyBudgetWithdraw(t *testing.T) {