diff --git a/auth/auth.go b/auth/auth.go index 905baf0e9..842961945 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -146,6 +146,27 @@ func PubKeyContextSuperAdmin(next http.Handler) http.Handler { }) } +// ConnectionContext parses token for connection code +func ConnectionCodeContext(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("token") + + if token == "" { + fmt.Println("[auth] no token") + http.Error(w, http.StatusText(401), 401) + return + } + + if token != config.Connection_Auth { + fmt.Println("Not a super admin : auth") + http.Error(w, http.StatusText(401), 401) + return + } + ctx := context.WithValue(r.Context(), ContextKey, token) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func AdminCheck(pubkey string) bool { for _, val := range config.SuperAdmins { if val == pubkey { diff --git a/config/config.go b/config/config.go index 520cfb5b0..5feb3f7d9 100644 --- a/config/config.go +++ b/config/config.go @@ -33,6 +33,7 @@ var S3FolderName string var S3Url string var AdminCheck string var AdminDevFreePass = "FREE_PASS" +var Connection_Auth string var S3Client *s3.Client var PresignClient *s3.PresignClient @@ -51,6 +52,7 @@ func InitConfig() { S3FolderName = os.Getenv("S3_FOLDER_NAME") S3Url = os.Getenv("S3_URL") AdminCheck = os.Getenv("ADMIN_CHECK") + Connection_Auth = os.Getenv("CONNECTION_AUTH") // Add to super admins SuperAdmins = StripSuperAdmins(AdminStrings) diff --git a/db/db.go b/db/db.go index a3f18221b..415bf5775 100644 --- a/db/db.go +++ b/db/db.go @@ -1495,10 +1495,12 @@ func (db database) GetPeopleListShort(count uint32) *[]PersonInShort { return &p } -func (db database) CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error) { - if c.DateCreated == nil { - now := time.Now() - c.DateCreated = &now +func (db database) CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) { + now := time.Now() + for _, code := range c { + if code.DateCreated.IsZero() { + code.DateCreated = &now + } } db.db.Create(&c) return c, nil diff --git a/db/interface.go b/db/interface.go index d0b96b2a1..2709abbb1 100644 --- a/db/interface.go +++ b/db/interface.go @@ -79,7 +79,7 @@ type Database interface { CountBounties() uint64 GetPeopleListShort(count uint32) *[]PersonInShort GetConnectionCode() ConnectionCodesShort - CreateConnectionCode(c ConnectionCodes) (ConnectionCodes, error) + CreateConnectionCode(c []ConnectionCodes) ([]ConnectionCodes, error) GetLnUser(lnKey string) int64 CreateLnUser(lnKey string) (Person, error) GetBountiesLeaderboard() []LeaderData diff --git a/handlers/auth.go b/handlers/auth.go index bd46a4040..858f680a4 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net/http" - "time" "github.com/form3tech-oss/jwt-go" "github.com/stakwork/sphinx-tribes/auth" @@ -54,16 +53,21 @@ func (ah *authHandler) GetIsAdmin(w http.ResponseWriter, r *http.Request) { } func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Request) { - code := db.ConnectionCodes{} - now := time.Now() + codeArr := []db.ConnectionCodes{} + codeStrArr := []string{} body, err := io.ReadAll(r.Body) r.Body.Close() - err = json.Unmarshal(body, &code) + err = json.Unmarshal(body, &codeStrArr) - code.IsUsed = false - code.DateCreated = &now + for _, code := range codeStrArr { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + } + codeArr = append(codeArr, code) + } if err != nil { fmt.Println(err) @@ -71,13 +75,15 @@ func (ah *authHandler) CreateConnectionCode(w http.ResponseWriter, r *http.Reque return } - _, err = ah.db.CreateConnectionCode(code) + _, err = ah.db.CreateConnectionCode(codeArr) if err != nil { fmt.Println("=> ERR create connection code", err) w.WriteHeader(http.StatusBadRequest) return } + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode("Codes created successfully") } func (ah *authHandler) GetConnectionCode(w http.ResponseWriter, _ *http.Request) { diff --git a/handlers/auth_test.go b/handlers/auth_test.go index da98c96ea..461adde96 100644 --- a/handlers/auth_test.go +++ b/handlers/auth_test.go @@ -18,7 +18,6 @@ import ( "github.com/stakwork/sphinx-tribes/db" mocks "github.com/stakwork/sphinx-tribes/mocks" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) func TestGetAdminPubkeys(t *testing.T) { @@ -52,16 +51,21 @@ func TestGetAdminPubkeys(t *testing.T) { } func TestCreateConnectionCode(t *testing.T) { - mockDb := mocks.NewDatabase(t) aHandler := NewAuthHandler(mockDb) t.Run("should create connection code successful", func(t *testing.T) { - codeToBeInserted := db.ConnectionCodes{ - ConnectionString: "custom connection string", + codeToBeInserted := []string{"custom connection string", "custom connection string 2"} + + codeArr := []db.ConnectionCodes{} + for _, code := range codeToBeInserted { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + } + codeArr = append(codeArr, code) } - mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool { - return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString - })).Return(codeToBeInserted, nil).Once() + + mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, nil).Once() body, _ := json.Marshal(codeToBeInserted) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) @@ -77,12 +81,18 @@ func TestCreateConnectionCode(t *testing.T) { }) t.Run("should return error if failed to add connection code", func(t *testing.T) { - codeToBeInserted := db.ConnectionCodes{ - ConnectionString: "custom connection string", + codeToBeInserted := []string{"custom connection string", "custom connection string 2"} + + codeArr := []db.ConnectionCodes{} + for _, code := range codeToBeInserted { + code := db.ConnectionCodes{ + ConnectionString: code, + IsUsed: false, + } + codeArr = append(codeArr, code) } - mockDb.On("CreateConnectionCode", mock.MatchedBy(func(code db.ConnectionCodes) bool { - return code.IsUsed == false && code.ConnectionString == codeToBeInserted.ConnectionString - })).Return(codeToBeInserted, errors.New("failed to create connection")).Once() + + mockDb.On("CreateConnectionCode", codeArr).Return(codeArr, errors.New("failed to create connection")).Once() body, _ := json.Marshal(codeToBeInserted) req, err := http.NewRequest("POST", "/connectioncodes", bytes.NewBuffer(body)) diff --git a/handlers/bounty.go b/handlers/bounty.go index f6175806e..350e5fe07 100644 --- a/handlers/bounty.go +++ b/handlers/bounty.go @@ -608,15 +608,14 @@ func formatPayError(errorMsg string) db.InvoicePayError { } } -func GetLightningInvoice(payment_request string) (db.InvoiceResult, db.InvoiceError) { +func (h *bountyHandler) GetLightningInvoice(payment_request string) (db.InvoiceResult, db.InvoiceError) { url := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, payment_request) - client := &http.Client{} req, err := http.NewRequest(http.MethodGet, url, nil) req.Header.Set("x-user-token", config.RelayAuthKey) req.Header.Set("Content-Type", "application/json") - res, _ := client.Do(req) + res, _ := h.httpClient.Do(req) if err != nil { log.Printf("Request Failed: %s", err) @@ -695,9 +694,9 @@ func (h *bountyHandler) PayLightningInvoice(payment_request string) (db.InvoiceP } } -func GetInvoiceData(w http.ResponseWriter, r *http.Request) { +func (h *bountyHandler) GetInvoiceData(w http.ResponseWriter, r *http.Request) { paymentRequest := chi.URLParam(r, "paymentRequest") - invoiceData, invoiceErr := GetLightningInvoice(paymentRequest) + invoiceData, invoiceErr := h.GetLightningInvoice(paymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -709,7 +708,7 @@ func GetInvoiceData(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(invoiceData) } -func PollInvoice(w http.ResponseWriter, r *http.Request) { +func (h *bountyHandler) PollInvoice(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) paymentRequest := chi.URLParam(r, "paymentRequest") @@ -721,7 +720,7 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { return } - invoiceRes, invoiceErr := GetLightningInvoice(paymentRequest) + invoiceRes, invoiceErr := h.GetLightningInvoice(paymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -731,14 +730,14 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { if invoiceRes.Response.Settled { // Todo if an invoice is settled - invoice := db.DB.GetInvoice(paymentRequest) - invData := db.DB.GetUserInvoiceData(paymentRequest) - dbInvoice := db.DB.GetInvoice(paymentRequest) + invoice := h.db.GetInvoice(paymentRequest) + invData := h.db.GetUserInvoiceData(paymentRequest) + dbInvoice := h.db.GetInvoice(paymentRequest) // Make any change only if the invoice has not been settled if !dbInvoice.Status { if invoice.Type == "BUDGET" { - db.DB.AddAndUpdateBudget(invoice) + h.db.AddAndUpdateBudget(invoice) } else if invoice.Type == "KEYSEND" { url := fmt.Sprintf("%s/payment", config.RelayUrl) @@ -748,12 +747,11 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { jsonBody := []byte(bodyData) - client := &http.Client{} req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) req.Header.Set("x-user-token", config.RelayAuthKey) req.Header.Set("Content-Type", "application/json") - res, _ := client.Do(req) + res, _ := h.httpClient.Do(req) if err != nil { log.Printf("Request Failed: %s", err) @@ -769,13 +767,13 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { keysendRes := db.KeysendSuccess{} err = json.Unmarshal(body, &keysendRes) - bounty, err := db.DB.GetBountyByCreated(uint(invData.Created)) + bounty, err := h.db.GetBountyByCreated(uint(invData.Created)) if err == nil { bounty.Paid = true } - db.DB.UpdateBounty(bounty) + h.db.UpdateBounty(bounty) } else { // Unmarshal result keysendError := db.KeysendError{} @@ -784,7 +782,7 @@ func PollInvoice(w http.ResponseWriter, r *http.Request) { } } // Update the invoice status - db.DB.UpdateInvoice(paymentRequest) + h.db.UpdateInvoice(paymentRequest) } } diff --git a/handlers/bounty_test.go b/handlers/bounty_test.go index 19277a34f..934862f92 100644 --- a/handlers/bounty_test.go +++ b/handlers/bounty_test.go @@ -1568,3 +1568,121 @@ func TestBountyBudgetWithdraw(t *testing.T) { mockHttpClient.AssertCalled(t, "Do", mock.AnythingOfType("*http.Request")) }) } + +func TestPollInvoice(t *testing.T) { + ctx := context.Background() + mockDb := &dbMocks.Database{} + mockHttpClient := &mocks.HttpClient{} + bHandler := NewBountyHandler(mockHttpClient, mockDb) + + unauthorizedCtx := context.WithValue(ctx, auth.ContextKey, "") + authorizedCtx := context.WithValue(ctx, auth.ContextKey, "valid-key") + + t.Run("Should test that a 401 error is returned if a user is unauthorized", func(t *testing.T) { + r := chi.NewRouter() + r.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(unauthorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + r.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code, "Expected 401 error if a user is unauthorized") + }) + + t.Run("Should test that a 403 error is returned if there is an invoice error", func(t *testing.T) { + expectedUrl := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, "1") + + r := io.NopCloser(bytes.NewReader([]byte(`{"success": false, "error": "Internel server error"}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + return req.Method == http.MethodGet && expectedUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey + })).Return(&http.Response{ + StatusCode: 500, + Body: r, + }, nil).Once() + + ro := chi.NewRouter() + ro.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + ro.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusForbidden, rr.Code, "Expected 403 error if there is an invoice error") + mockHttpClient.AssertExpectations(t) + }) + + t.Run("Should mock relay payment is successful update the bounty associated with the invoice and set the paid as true", func(t *testing.T) { + expectedUrl := fmt.Sprintf("%s/invoice?payment_request=%s", config.RelayUrl, "1") + + r := io.NopCloser(bytes.NewReader([]byte(`{"success": true, "response": { "settled": true, "payment_request": "1", "payment_hash": "payment_hash", "preimage": "preimage", "Amount": "1000"}}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + return req.Method == http.MethodGet && expectedUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey + })).Return(&http.Response{ + StatusCode: 200, + Body: r, + }, nil).Once() + + bountyID := uint(1) + bounty := db.Bounty{ + ID: bountyID, + OrgUuid: "org-1", + Assignee: "assignee-1", + Price: uint(1000), + } + + now := time.Now() + expectedBounty := db.Bounty{ + ID: bountyID, + OrgUuid: "org-1", + Assignee: "assignee-1", + Price: uint(1000), + Paid: true, + PaidDate: &now, + CompletionDate: &now, + } + + mockDb.On("GetInvoice", "1").Return(db.InvoiceList{Type: "KEYSEND"}) + mockDb.On("GetUserInvoiceData", "1").Return(db.UserInvoiceData{Amount: 1000, UserPubkey: "UserPubkey", RouteHint: "RouteHint", Created: 1234}) + mockDb.On("GetInvoice", "1").Return(db.InvoiceList{Status: false}) + mockDb.On("GetBountyByCreated", uint(1234)).Return(bounty, nil) + mockDb.On("UpdateBounty", mock.AnythingOfType("db.Bounty")).Run(func(args mock.Arguments) { + updatedBounty := args.Get(0).(db.Bounty) + assert.True(t, updatedBounty.Paid) + }).Return(expectedBounty, nil).Once() + mockDb.On("UpdateInvoice", "1").Return(db.InvoiceList{}).Once() + + expectedPaymentUrl := fmt.Sprintf("%s/payment", config.RelayUrl) + expectedPaymentBody := `{"amount": 1000, "destination_key": "UserPubkey", "route_hint": "RouteHint", "text": "memotext added for notification"}` + + r2 := io.NopCloser(bytes.NewReader([]byte(`{"success": true, "response": { "sumAmount": "1"}}`))) + mockHttpClient.On("Do", mock.MatchedBy(func(req *http.Request) bool { + bodyByt, _ := io.ReadAll(req.Body) + return req.Method == http.MethodPost && expectedPaymentUrl == req.URL.String() && req.Header.Get("x-user-token") == config.RelayAuthKey && expectedPaymentBody == string(bodyByt) + })).Return(&http.Response{ + StatusCode: 200, + Body: r2, + }, nil).Once() + + ro := chi.NewRouter() + ro.Post("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) + + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/poll/invoice/1", bytes.NewBufferString(`{}`)) + if err != nil { + t.Fatal(err) + } + + ro.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + mockHttpClient.AssertExpectations(t) + }) +} diff --git a/handlers/metrics_test.go b/handlers/metrics_test.go index 7fc0de3f4..c00de596d 100644 --- a/handlers/metrics_test.go +++ b/handlers/metrics_test.go @@ -16,6 +16,7 @@ import ( "github.com/stakwork/sphinx-tribes/db" mocks "github.com/stakwork/sphinx-tribes/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestBountyMetrics(t *testing.T) { @@ -345,3 +346,69 @@ func TestConvertMetricsToCSV(t *testing.T) { }) } + +func TestMetricsBountiesProviders(t *testing.T) { + ctx := context.Background() + mockDb := mocks.NewDatabase(t) + mh := NewMetricHandler(mockDb) + unauthorizedCtx := context.WithValue(context.Background(), auth.ContextKey, "") + authorizedCtx := context.WithValue(ctx, auth.ContextKey, "valid-key") + + t.Run("should return 401 error if user is unauthorized", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(mh.MetricsBountiesProviders) + + req, err := http.NewRequestWithContext(unauthorizedCtx, http.MethodPost, "/bounties/providers", nil) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("should return 406 error if wrong data is passed", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(mh.MetricsBountiesProviders) + + invalidJson := []byte(`{"start_date": "2021-01-01"`) + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/bounties/providers", bytes.NewReader(invalidJson)) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotAcceptable, rr.Code) + }) + + t.Run("should return bounty providers and 200 status code if there is no error", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(mh.MetricsBountiesProviders) + + validJson := []byte(`{"start_date": "2021-01-01", "end_date": "2021-12-31"}`) + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/bounties/providers", bytes.NewReader(validJson)) + if err != nil { + t.Fatal(err) + } + + expectedProviders := []db.Person{ + {ID: 1, OwnerAlias: "Provider One"}, + {ID: 2, OwnerAlias: "Provider Two"}, + } + + mockDb.On("GetBountiesProviders", mock.Anything, req).Return(expectedProviders).Once() + + handler.ServeHTTP(rr, req) + + var actualProviders []db.Person + err = json.Unmarshal(rr.Body.Bytes(), &actualProviders) + if err != nil { + t.Fatal("Failed to unmarshal response:", err) + } + + assert.Equal(t, http.StatusOK, rr.Code) + assert.EqualValues(t, expectedProviders, actualProviders) + }) +} diff --git a/handlers/organizations.go b/handlers/organizations.go index 87a5dcdfe..7022c7eb5 100644 --- a/handlers/organizations.go +++ b/handlers/organizations.go @@ -18,6 +18,7 @@ import ( type organizationHandler struct { db db.Database generateBountyHandler func(bounties []db.Bounty) []db.BountyResponse + getLightningInvoice func(payment_request string) (db.InvoiceResult, db.InvoiceError) } func NewOrganizationHandler(db db.Database) *organizationHandler { @@ -25,6 +26,7 @@ func NewOrganizationHandler(db db.Database) *organizationHandler { return &organizationHandler{ db: db, generateBountyHandler: bHandler.GenerateBountyResponse, + getLightningInvoice: bHandler.GetLightningInvoice, } } @@ -630,7 +632,7 @@ func GetPaymentHistory(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(paymentHistoryData) } -func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { +func (oh *organizationHandler) PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) uuid := chi.URLParam(r, "uuid") @@ -641,10 +643,10 @@ func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { return } - orgInvoices := db.DB.GetOrganizationInvoices(uuid) + orgInvoices := oh.db.GetOrganizationInvoices(uuid) for _, inv := range orgInvoices { - invoiceRes, invoiceErr := GetLightningInvoice(inv.PaymentRequest) + invoiceRes, invoiceErr := oh.getLightningInvoice(inv.PaymentRequest) if invoiceErr.Error != "" { w.WriteHeader(http.StatusForbidden) @@ -654,9 +656,9 @@ func PollBudgetInvoices(w http.ResponseWriter, r *http.Request) { if invoiceRes.Response.Settled { if !inv.Status && inv.Type == "BUDGET" { - db.DB.AddAndUpdateBudget(inv) + oh.db.AddAndUpdateBudget(inv) // Update the invoice status - db.DB.UpdateInvoice(inv.PaymentRequest) + oh.db.UpdateInvoice(inv.PaymentRequest) } } } diff --git a/handlers/tribes.go b/handlers/tribes.go index 61f90f495..467feb60b 100644 --- a/handlers/tribes.go +++ b/handlers/tribes.go @@ -543,7 +543,7 @@ func GenerateInvoice(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(invoiceRes) } -func GenerateBudgetInvoice(w http.ResponseWriter, r *http.Request) { +func (th *tribeHandler) GenerateBudgetInvoice(w http.ResponseWriter, r *http.Request) { invoice := db.BudgetInvoiceRequest{} body, err := io.ReadAll(r.Body) @@ -612,8 +612,8 @@ func GenerateBudgetInvoice(w http.ResponseWriter, r *http.Request) { Status: false, } - db.DB.AddPaymentHistory(paymentHistory) - db.DB.AddInvoice(newInvoice) + th.db.AddPaymentHistory(paymentHistory) + th.db.AddInvoice(newInvoice) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(invoiceRes) diff --git a/handlers/tribes_test.go b/handlers/tribes_test.go index f9654e721..8b1a6e96a 100644 --- a/handlers/tribes_test.go +++ b/handlers/tribes_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "github.com/stakwork/sphinx-tribes/config" "net/http" "net/http/httptest" "strings" @@ -604,3 +605,136 @@ func TestGetListedTribes(t *testing.T) { }) } + +func TestGenerateBudgetInvoice(t *testing.T) { + ctx := context.Background() + mockDb := mocks.NewDatabase(t) + tHandler := NewTribeHandler(mockDb) + authorizedCtx := context.WithValue(ctx, auth.ContextKey, "valid-key") + + userAmount := uint(1000) + invoiceResponse := db.InvoiceResponse{ + Succcess: true, + Response: db.Invoice{ + Invoice: "example_invoice", + }, + } + + t.Run("Should test that a wrong Post body returns a 406 error", func(t *testing.T) { + invalidBody := []byte(`"key": "value"`) + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/budgetinvoices", bytes.NewBuffer(invalidBody)) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GenerateBudgetInvoice) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusNotAcceptable, rr.Code) + }) + + t.Run("Should mock a call to relay /invoices with the correct body", func(t *testing.T) { + + mockDb.On("AddPaymentHistory", mock.AnythingOfType("db.PaymentHistory")).Return(db.PaymentHistory{}, nil) + mockDb.On("AddInvoice", mock.AnythingOfType("db.InvoiceList")).Return(db.InvoiceList{}, nil) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + expectedBody := map[string]interface{}{"amount": float64(0), "memo": "Budget Invoice"} + var body map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&body) + assert.NoError(t, err) + + assert.Equal(t, expectedBody, body) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{"result": "success"}) + })) + defer ts.Close() + + config.RelayUrl = ts.URL + + reqBody := map[string]interface{}{"amount": 0} + bodyBytes, _ := json.Marshal(reqBody) + + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/budgetinvoices", bytes.NewBuffer(bodyBytes)) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GenerateBudgetInvoice) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + }) + + t.Run("Should test that the amount passed by the user is equal to the amount sent for invoice generation", func(t *testing.T) { + + userAmount := float64(1000) + + mockDb.On("AddPaymentHistory", mock.AnythingOfType("db.PaymentHistory")).Return(db.PaymentHistory{}, nil) + mockDb.On("AddInvoice", mock.AnythingOfType("db.InvoiceList")).Return(db.InvoiceList{}, nil) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]interface{} + err := json.NewDecoder(r.Body).Decode(&body) + assert.NoError(t, err) + + assert.Equal(t, userAmount, body["amount"]) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{"result": "success"}) + })) + defer ts.Close() + + config.RelayUrl = ts.URL + + reqBody := map[string]interface{}{"amount": userAmount} + bodyBytes, _ := json.Marshal(reqBody) + + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/budgetinvoices", bytes.NewBuffer(bodyBytes)) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GenerateBudgetInvoice) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + }) + + t.Run("Should add payments to the payment history and invoice to the invoice list upon successful relay call", func(t *testing.T) { + expectedPaymentHistory := db.PaymentHistory{Amount: userAmount} + expectedInvoice := db.InvoiceList{PaymentRequest: invoiceResponse.Response.Invoice} + + mockDb.On("AddPaymentHistory", mock.AnythingOfType("db.PaymentHistory")).Return(expectedPaymentHistory, nil) + mockDb.On("AddInvoice", mock.AnythingOfType("db.InvoiceList")).Return(expectedInvoice, nil) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(invoiceResponse) + })) + defer ts.Close() + + config.RelayUrl = ts.URL + + reqBody := map[string]interface{}{"amount": userAmount} + bodyBytes, _ := json.Marshal(reqBody) + req, err := http.NewRequestWithContext(authorizedCtx, http.MethodPost, "/budgetinvoices", bytes.NewBuffer(bodyBytes)) + assert.NoError(t, err) + + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GenerateBudgetInvoice) + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + var response db.InvoiceResponse + err = json.Unmarshal(rr.Body.Bytes(), &response) + assert.NoError(t, err) + assert.True(t, response.Succcess, "Invoice generation should be successful") + assert.Equal(t, "example_invoice", response.Response.Invoice, "The invoice in the response should match the mock") + + mockDb.AssertCalled(t, "AddPaymentHistory", mock.AnythingOfType("db.PaymentHistory")) + mockDb.AssertCalled(t, "AddInvoice", mock.AnythingOfType("db.InvoiceList")) + }) +} diff --git a/mocks/Database.go b/mocks/Database.go index bc189ad35..b44f69f67 100644 --- a/mocks/Database.go +++ b/mocks/Database.go @@ -643,25 +643,27 @@ func (_c *Database_CreateChannel_Call) RunAndReturn(run func(db.Channel) (db.Cha } // CreateConnectionCode provides a mock function with given fields: c -func (_m *Database) CreateConnectionCode(c db.ConnectionCodes) (db.ConnectionCodes, error) { +func (_m *Database) CreateConnectionCode(c []db.ConnectionCodes) ([]db.ConnectionCodes, error) { ret := _m.Called(c) if len(ret) == 0 { panic("no return value specified for CreateConnectionCode") } - var r0 db.ConnectionCodes + var r0 []db.ConnectionCodes var r1 error - if rf, ok := ret.Get(0).(func(db.ConnectionCodes) (db.ConnectionCodes, error)); ok { + if rf, ok := ret.Get(0).(func([]db.ConnectionCodes) ([]db.ConnectionCodes, error)); ok { return rf(c) } - if rf, ok := ret.Get(0).(func(db.ConnectionCodes) db.ConnectionCodes); ok { + if rf, ok := ret.Get(0).(func([]db.ConnectionCodes) []db.ConnectionCodes); ok { r0 = rf(c) } else { - r0 = ret.Get(0).(db.ConnectionCodes) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]db.ConnectionCodes) + } } - if rf, ok := ret.Get(1).(func(db.ConnectionCodes) error); ok { + if rf, ok := ret.Get(1).(func([]db.ConnectionCodes) error); ok { r1 = rf(c) } else { r1 = ret.Error(1) @@ -676,24 +678,24 @@ type Database_CreateConnectionCode_Call struct { } // CreateConnectionCode is a helper method to define mock.On call -// - c db.ConnectionCodes +// - c []db.ConnectionCodes func (_e *Database_Expecter) CreateConnectionCode(c interface{}) *Database_CreateConnectionCode_Call { return &Database_CreateConnectionCode_Call{Call: _e.mock.On("CreateConnectionCode", c)} } -func (_c *Database_CreateConnectionCode_Call) Run(run func(c db.ConnectionCodes)) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) Run(run func(c []db.ConnectionCodes)) *Database_CreateConnectionCode_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(db.ConnectionCodes)) + run(args[0].([]db.ConnectionCodes)) }) return _c } -func (_c *Database_CreateConnectionCode_Call) Return(_a0 db.ConnectionCodes, _a1 error) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) Return(_a0 []db.ConnectionCodes, _a1 error) *Database_CreateConnectionCode_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *Database_CreateConnectionCode_Call) RunAndReturn(run func(db.ConnectionCodes) (db.ConnectionCodes, error)) *Database_CreateConnectionCode_Call { +func (_c *Database_CreateConnectionCode_Call) RunAndReturn(run func([]db.ConnectionCodes) ([]db.ConnectionCodes, error)) *Database_CreateConnectionCode_Call { _c.Call.Return(run) return _c } diff --git a/routes/bounty.go b/routes/bounty.go index f25e6050e..6e5049747 100644 --- a/routes/bounty.go +++ b/routes/bounty.go @@ -25,7 +25,7 @@ func BountyRoutes() chi.Router { r.Get("/created/{created}", bountyHandler.GetBountyByCreated) r.Get("/count/{personKey}/{tabType}", handlers.GetUserBountyCount) r.Get("/count", handlers.GetBountyCount) - r.Get("/invoice/{paymentRequest}", handlers.GetInvoiceData) + r.Get("/invoice/{paymentRequest}", bountyHandler.GetInvoiceData) r.Get("/filter/count", handlers.GetFilterCount) }) diff --git a/routes/connection_codes.go b/routes/connection_codes.go index 4a5e6bee7..b151ff5e5 100644 --- a/routes/connection_codes.go +++ b/routes/connection_codes.go @@ -2,6 +2,7 @@ package routes import ( "github.com/go-chi/chi" + "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/handlers" ) @@ -10,8 +11,12 @@ func ConnectionCodesRoutes() chi.Router { r := chi.NewRouter() authHandler := handlers.NewAuthHandler(db.DB) r.Group(func(r chi.Router) { - r.Post("/", authHandler.CreateConnectionCode) r.Get("/", authHandler.GetConnectionCode) }) + + r.Group(func(r chi.Router) { + r.Use(auth.ConnectionCodeContext) + r.Post("/", authHandler.CreateConnectionCode) + }) return r } diff --git a/routes/index.go b/routes/index.go index fc3b0cb48..1615fa30e 100644 --- a/routes/index.go +++ b/routes/index.go @@ -24,6 +24,7 @@ func NewRouter() *http.Server { authHandler := handlers.NewAuthHandler(db.DB) channelHandler := handlers.NewChannelHandler(db.DB) botHandler := handlers.NewBotHandler(db.DB) + bHandler := handlers.NewBountyHandler(http.DefaultClient, db.DB) r.Mount("/tribes", TribeRoutes()) r.Mount("/bots", BotsRoutes()) @@ -75,7 +76,7 @@ func NewRouter() *http.Server { r.Post("/badges", handlers.AddOrRemoveBadge) r.Delete("/channel/{id}", channelHandler.DeleteChannel) r.Delete("/ticket/{pubKey}/{created}", handlers.DeleteTicketByAdmin) - r.Get("/poll/invoice/{paymentRequest}", handlers.PollInvoice) + r.Get("/poll/invoice/{paymentRequest}", bHandler.PollInvoice) r.Post("/meme_upload", handlers.MemeImageUpload) r.Get("/admin/auth", authHandler.GetIsAdmin) }) @@ -85,7 +86,7 @@ func NewRouter() *http.Server { r.Get("/lnauth", handlers.GetLnurlAuth) r.Get("/refresh_jwt", authHandler.RefreshToken) r.Post("/invoices", handlers.GenerateInvoice) - r.Post("/budgetinvoices", handlers.GenerateBudgetInvoice) + r.Post("/budgetinvoices", tribeHandlers.GenerateBudgetInvoice) }) PORT := os.Getenv("PORT") diff --git a/routes/organizations.go b/routes/organizations.go index feddba0df..66041c99a 100644 --- a/routes/organizations.go +++ b/routes/organizations.go @@ -35,7 +35,7 @@ func OrganizationRoutes() chi.Router { r.Get("/budget/{uuid}", organizationHandlers.GetOrganizationBudget) r.Get("/budget/history/{uuid}", organizationHandlers.GetOrganizationBudgetHistory) r.Get("/payments/{uuid}", handlers.GetPaymentHistory) - r.Get("/poll/invoices/{uuid}", handlers.PollBudgetInvoices) + r.Get("/poll/invoices/{uuid}", organizationHandlers.PollBudgetInvoices) r.Get("/invoices/count/{uuid}", handlers.GetInvoicesCount) r.Delete("/delete/{uuid}", organizationHandlers.DeleteOrganization) }) diff --git a/utils/utils.go b/utils/utils.go index 95f41417a..af3f0a51d 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -7,10 +7,9 @@ import ( ) func GetPaginationParams(r *http.Request) (int, int, string, string, string) { - // there are cases when the request is not passed in if r == nil { - return 0, -1, "updated", "asc", "" + return 0, 1, "updated", "asc", "" } keys := r.URL.Query()