diff --git a/handlers/bounty.go b/handlers/bounty.go index 885391c93..5a1e5d13a 100644 --- a/handlers/bounty.go +++ b/handlers/bounty.go @@ -606,15 +606,14 @@ func formatPayError(errorMsg string) db.InvoicePayError { } } -func GetLightningInvoice(payment_request string) (db.InvoiceResult, db.InvoiceError) { +func (h *bountyHandler) GetLightningInvoice(payment_request string) (db.InvoiceResult, db.InvoiceError) { url := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, payment_request) - client := &http.Client{} req, err := http.NewRequest(http.MethodGet, url, nil) req.Header.Set("x-user-token", config.RelayAuthKey) req.Header.Set("Content-Type", "application/json") - res, _ := client.Do(req) + res, _ := h.httpClient.Do(req) if err != nil { log.Printf("Request Failed: %s", err) @@ -693,9 +692,9 @@ func (h *bountyHandler) PayLightningInvoice(payment_request string) (db.InvoiceP } } -func GetInvoiceData(w http.ResponseWriter, r *http.Request) { +func (h *bountyHandler) GetInvoiceData(w http.ResponseWriter, r *http.Request) { paymentRequest := chi.URLParam(r, "paymentRequest") - invoiceData, invoiceErr := GetLightningInvoice(paymentRequest) + invoiceData, invoiceErr := h.GetLightningInvoice(paymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -707,7 +706,7 @@ func GetInvoiceData(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(invoiceData) } -func PollInvoice(w http.ResponseWriter, r *http.Request) { +func (h *bountyHandler) PollInvoice(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) paymentRequest := chi.URLParam(r, "paymentRequest") @@ -719,7 +718,7 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { return } - invoiceRes, invoiceErr := GetLightningInvoice(paymentRequest) + invoiceRes, invoiceErr := h.GetLightningInvoice(paymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -729,14 +728,14 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { if invoiceRes.Response.Settled { // Todo if an invoice is settled - invoice := db.DB.GetInvoice(paymentRequest) - invData := db.DB.GetUserInvoiceData(paymentRequest) - dbInvoice := db.DB.GetInvoice(paymentRequest) + invoice := h.db.GetInvoice(paymentRequest) + invData := h.db.GetUserInvoiceData(paymentRequest) + dbInvoice := h.db.GetInvoice(paymentRequest) // Make any change only if the invoice has not been settled if !dbInvoice.Status { if invoice.Type == "BUDGET" { - db.DB.AddAndUpdateBudget(invoice) + h.db.AddAndUpdateBudget(invoice) } else if invoice.Type == "KEYSEND" { url := fmt.Sprintf("%s/payment", config.RelayUrl) @@ -746,12 +745,11 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { jsonBody := []byte(bodyData) - client := &http.Client{} req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) req.Header.Set("x-user-token", config.RelayAuthKey) req.Header.Set("Content-Type", "application/json") - res, _ := client.Do(req) + res, _ := h.httpClient.Do(req) if err != nil { log.Printf("Request Failed: %s", err) @@ -767,13 +765,13 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { keysendRes := db.KeysendSuccess{} err = json.Unmarshal(body, &keysendRes) - bounty, err := db.DB.GetBountyByCreated(uint(invData.Created)) + bounty, err := h.db.GetBountyByCreated(uint(invData.Created)) if err == nil { bounty.Paid = true } - db.DB.UpdateBounty(bounty) + h.db.UpdateBounty(bounty) } else { // Unmarshal result keysendError := db.KeysendError{} @@ -782,7 +780,7 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { } } // Update the invoice status - db.DB.UpdateInvoice(paymentRequest) + h.db.UpdateInvoice(paymentRequest) } } diff --git a/handlers/bounty_test.go b/handlers/bounty_test.go index 410deb099..8a9faad20 100644 --- a/handlers/bounty_test.go +++ b/handlers/bounty_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/stakwork/sphinx-tribes/utils" "io" "net/http" "net/http/httptest" @@ -16,6 +15,8 @@ import ( "testing" "time" + "github.com/stakwork/sphinx-tribes/utils" + "github.com/go-chi/chi" "github.com/lib/pq" "github.com/stakwork/sphinx-tribes/auth" @@ -1431,3 +1432,121 @@ func TestBountyBudgetWithdraw(t *testing.T) { mockHttpClient.AssertCalled(t, "Do", mock.AnythingOfType("*http.Request")) }) } + +func TestPollInvoice(t *testing.T) { + ctx := context.Background() + mockDb := &dbMocks.Database{} + mockHttpClient := &mocks.HttpClient{} + bHandler := NewBountyHandler(mockHttpClient, mockDb) + + unauthorizedCtx := context.WithValue(ctx, auth.ContextKey, "") + authorizedCtx := context.WithValue(ctx, auth.ContextKey, "valid-key") + + t.Run("Should test that a 401 error is returned if a user is unauthorized", func(t *testing.T) { + r := chi.NewRouter() + r.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(unauthorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code, "Expected 401 error if a user is unauthorized") + }) + + t.Run("Should test that a 403 error is returned if there is an invoice error", func(t *testing.T) { + expectedUrl := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, "1") + + r := io.NopCloser(bytes.NewReader([]byte(`{"success": false, "error": "Internel server error"}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + return req.Method == http.MethodGet && expectedUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey + })).Return(&http.Response{ + StatusCode: 500, + Body: r, + }, nil).Once() + + ro := chi.NewRouter() + ro.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + ro.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code, "Expected 403 error if there is an invoice error") + mockHttpClient.AssertExpectations(t) + }) + + t.Run("Should mock relay payment is successful update the bounty associated with the invoice and set the paid as true", func(t *testing.T) { + expectedUrl := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, "1") + + r := io.NopCloser(bytes.NewReader([]byte(`{"success": true, "response": { "settled": true, "payment_request": "1", "payment_hash": "payment_hash", "preimage": "preimage", "Amount": "1000"}}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + return req.Method == http.MethodGet && expectedUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey + })).Return(&http.Response{ + StatusCode: 200, + Body: r, + }, nil).Once() + + bountyID := uint(1) + bounty := db.Bounty{ + ID: bountyID, + OrgUuid: "org-1", + Assignee: "assignee-1", + Price: uint(1000), + } + + now := time.Now() + expectedBounty := db.Bounty{ + ID: bountyID, + OrgUuid: "org-1", + Assignee: "assignee-1", + Price: uint(1000), + Paid: true, + PaidDate: &now, + CompletionDate: &now, + } + + mockDb.On("GetInvoice", "1").Return(db.InvoiceList{Type: "KEYSEND"}) + mockDb.On("GetUserInvoiceData", "1").Return(db.UserInvoiceData{Amount: 1000, UserPubkey: "UserPubkey", RouteHint: "RouteHint", Created: 1234}) + mockDb.On("GetInvoice", "1").Return(db.InvoiceList{Status: false}) + mockDb.On("GetBountyByCreated", uint(1234)).Return(bounty, nil) + mockDb.On("UpdateBounty", mock.AnythingOfType("db.Bounty")).Run(func(args mock.Arguments) { + updatedBounty := args.Get(0).(db.Bounty) + assert.True(t, updatedBounty.Paid) + }).Return(expectedBounty, nil).Once() + mockDb.On("UpdateInvoice", "1").Return(db.InvoiceList{}).Once() + + expectedPaymentUrl := fmt.Sprintf("%s/payment", config.RelayUrl) + expectedPaymentBody := `{"amount": 1000, "destination_key": "UserPubkey", "route_hint": "RouteHint", "text": "memotext added for notification"}` + + r2 := io.NopCloser(bytes.NewReader([]byte(`{"success": true, "response": { "sumAmount": "1"}}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + bodyByt, _ := io.ReadAll(req.Body) + return req.Method == http.MethodPost && expectedPaymentUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey && expectedPaymentBody == string(bodyByt) + })).Return(&http.Response{ + StatusCode: 200, + Body: r2, + }, nil).Once() + + ro := chi.NewRouter() + ro.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + ro.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + mockHttpClient.AssertExpectations(t) + }) +} diff --git a/handlers/organizations.go b/handlers/organizations.go index 87a5dcdfe..7022c7eb5 100644 --- a/handlers/organizations.go +++ b/handlers/organizations.go @@ -18,6 +18,7 @@ import ( type organizationHandler struct { db db.Database generateBountyHandler func(bounties []db.Bounty) []db.BountyResponse + getLightningInvoice func(payment_request string) (db.InvoiceResult, db.InvoiceError) } func NewOrganizationHandler(db db.Database) *organizationHandler { @@ -25,6 +26,7 @@ func NewOrganizationHandler(db db.Database) *organizationHandler { return &organizationHandler{ db: db, generateBountyHandler: bHandler.GenerateBountyResponse, + getLightningInvoice: bHandler.GetLightningInvoice, } } @@ -630,7 +632,7 @@ func GetPaymentHistory(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(paymentHistoryData) } -func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { +func (oh *organizationHandler) PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) uuid := chi.URLParam(r, "uuid") @@ -641,10 +643,10 @@ func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { return } - orgInvoices := db.DB.GetOrganizationInvoices(uuid) + orgInvoices := oh.db.GetOrganizationInvoices(uuid) for _, inv := range orgInvoices { - invoiceRes, invoiceErr := GetLightningInvoice(inv.PaymentRequest) + invoiceRes, invoiceErr := oh.getLightningInvoice(inv.PaymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -654,9 +656,9 @@ func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { if invoiceRes.Response.Settled { if !inv.Status && inv.Type == "BUDGET" { - db.DB.AddAndUpdateBudget(inv) + oh.db.AddAndUpdateBudget(inv) // Update the invoice status - db.DB.UpdateInvoice(inv.PaymentRequest) + oh.db.UpdateInvoice(inv.PaymentRequest) } } } diff --git a/routes/bounty.go b/routes/bounty.go index f25e6050e..6e5049747 100644 --- a/routes/bounty.go +++ b/routes/bounty.go @@ -25,7 +25,7 @@ func BountyRoutes() chi.Router { r.Get("/created/{created}", bountyHandler.GetBountyByCreated) r.Get("/count/{personKey}/{tabType}", handlers.GetUserBountyCount) r.Get("/count", handlers.GetBountyCount) - r.Get("/invoice/{paymentRequest}", handlers.GetInvoiceData) + r.Get("/invoice/{paymentRequest}", bountyHandler.GetInvoiceData) r.Get("/filter/count", handlers.GetFilterCount) }) diff --git a/routes/index.go b/routes/index.go index 3efe324fa..1615fa30e 100644 --- a/routes/index.go +++ b/routes/index.go @@ -24,6 +24,7 @@ func NewRouter() *http.Server { authHandler := handlers.NewAuthHandler(db.DB) channelHandler := handlers.NewChannelHandler(db.DB) botHandler := handlers.NewBotHandler(db.DB) + bHandler := handlers.NewBountyHandler(http.DefaultClient, db.DB) r.Mount("/tribes", TribeRoutes()) r.Mount("/bots", BotsRoutes()) @@ -75,7 +76,7 @@ func NewRouter() *http.Server { r.Post("/badges", handlers.AddOrRemoveBadge) r.Delete("/channel/{id}", channelHandler.DeleteChannel) r.Delete("/ticket/{pubKey}/{created}", handlers.DeleteTicketByAdmin) - r.Get("/poll/invoice/{paymentRequest}", handlers.PollInvoice) + r.Get("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) r.Post("/meme_upload", handlers.MemeImageUpload) r.Get("/admin/auth", authHandler.GetIsAdmin) }) diff --git a/routes/organizations.go b/routes/organizations.go index feddba0df..66041c99a 100644 --- a/routes/organizations.go +++ b/routes/organizations.go @@ -35,7 +35,7 @@ func OrganizationRoutes() chi.Router { r.Get("/budget/{uuid}", organizationHandlers.GetOrganizationBudget) r.Get("/budget/history/{uuid}", organizationHandlers.GetOrganizationBudgetHistory) r.Get("/payments/{uuid}", handlers.GetPaymentHistory) - r.Get("/poll/invoices/{uuid}", handlers.PollBudgetInvoices) + r.Get("/poll/invoices/{uuid}", organizationHandlers.PollBudgetInvoices) r.Get("/invoices/count/{uuid}", handlers.GetInvoicesCount) r.Delete("/delete/{uuid}", organizationHandlers.DeleteOrganization) })