diff --git a/handlers/tribes.go b/handlers/tribes.go index 067a25e24..041698fb2 100644 --- a/handlers/tribes.go +++ b/handlers/tribes.go @@ -17,6 +17,14 @@ import ( "github.com/stakwork/sphinx-tribes/utils" ) +type tribeHandler struct { + db db.Database +} + +func NewTribeHandler(db db.Database) *tribeHandler { + return &tribeHandler{db: db} +} + func GetAllTribes(w http.ResponseWriter, r *http.Request) { tribes := db.DB.GetAllTribes() w.WriteHeader(http.StatusOK) @@ -35,23 +43,23 @@ func GetListedTribes(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(tribes) } -func GetTribesByOwner(w http.ResponseWriter, r *http.Request) { +func (th *tribeHandler) GetTribesByOwner(w http.ResponseWriter, r *http.Request) { all := r.URL.Query().Get("all") tribes := []db.Tribe{} pubkey := chi.URLParam(r, "pubkey") if all == "true" { - tribes = db.DB.GetAllTribesByOwner(pubkey) + tribes = th.db.GetAllTribesByOwner(pubkey) } else { - tribes = db.DB.GetTribesByOwner(pubkey) + tribes = th.db.GetTribesByOwner(pubkey) } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(tribes) } -func GetTribesByAppUrl(w http.ResponseWriter, r *http.Request) { +func (th *tribeHandler) GetTribesByAppUrl(w http.ResponseWriter, r *http.Request) { tribes := []db.Tribe{} app_url := chi.URLParam(r, "app_url") - tribes = db.DB.GetTribesByAppUrl(app_url) + tribes = th.db.GetTribesByAppUrl(app_url) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(tribes) } @@ -144,15 +152,15 @@ func DeleteTribe(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(true) } -func GetTribe(w http.ResponseWriter, r *http.Request) { +func (th *tribeHandler) GetTribe(w http.ResponseWriter, r *http.Request) { uuid := chi.URLParam(r, "uuid") - tribe := db.DB.GetTribe(uuid) + tribe := th.db.GetTribe(uuid) var theTribe map[string]interface{} j, _ := json.Marshal(tribe) json.Unmarshal(j, &theTribe) - theTribe["channels"] = db.DB.GetChannelsByTribe(uuid) + theTribe["channels"] = th.db.GetChannelsByTribe(uuid) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(theTribe) diff --git a/handlers/tribes_test.go b/handlers/tribes_test.go new file mode 100644 index 000000000..bfc409c68 --- /dev/null +++ b/handlers/tribes_test.go @@ -0,0 +1,185 @@ +package handlers + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi" + "github.com/stakwork/sphinx-tribes/db" + mocks "github.com/stakwork/sphinx-tribes/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestGetTribesByOwner(t *testing.T) { + mockDb := mocks.NewDatabase(t) + tHandler := NewTribeHandler(mockDb) + + t.Run("Should test that all tribes that an owner did not delete are returned if all=true is added to the request query", func(t *testing.T) { + // Mock data + mockPubkey := "mock_pubkey" + mockTribes := []db.Tribe{ + {UUID: "uuid", OwnerPubKey: mockPubkey, Deleted: false}, + {UUID: "uuid", OwnerPubKey: mockPubkey, Deleted: false}, + } + mockDb.On("GetAllTribesByOwner", mock.Anything).Return(mockTribes).Once() + + // Create request with "all=true" query parameter + req, err := http.NewRequest("GET", "/tribes_by_owner/"+mockPubkey+"?all=true", nil) + if err != nil { + t.Fatal(err) + } + + // Serve request + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GetTribesByOwner) + handler.ServeHTTP(rr, req) + + // Verify response + assert.Equal(t, http.StatusOK, rr.Code) + var responseData []db.Tribe + err = json.Unmarshal(rr.Body.Bytes(), &responseData) + if err != nil { + t.Fatalf("Error decoding JSON response: %s", err) + } + assert.ElementsMatch(t, mockTribes, responseData) + }) + + t.Run("Should test that all tribes that are not unlisted by an owner are returned", func(t *testing.T) { + // Mock data + mockPubkey := "mock_pubkey" + mockTribes := []db.Tribe{ + {UUID: "uuid", OwnerPubKey: mockPubkey, Unlisted: false}, + {UUID: "uuid", OwnerPubKey: mockPubkey, Unlisted: false}, + } + mockDb.On("GetTribesByOwner", mock.Anything).Return(mockTribes) + + // Create request without "all=true" query parameter + req, err := http.NewRequest("GET", "/tribes/"+mockPubkey, nil) + if err != nil { + t.Fatal(err) + } + + // Serve request + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.GetTribesByOwner) + handler.ServeHTTP(rr, req) + + // Verify response + assert.Equal(t, http.StatusOK, rr.Code) + var responseData []db.Tribe + err = json.Unmarshal(rr.Body.Bytes(), &responseData) + if err != nil { + t.Fatalf("Error decoding JSON response: %s", err) + } + assert.ElementsMatch(t, mockTribes, responseData) + }) +} + +func TestGetTribe(t *testing.T) { + mockDb := mocks.NewDatabase(t) + tHandler := NewTribeHandler(mockDb) + + t.Run("Should test that a tribe can be returned when the right UUID is passed to the request parameter", func(t *testing.T) { + // Mock data + mockUUID := "valid_uuid" + mockTribe := db.Tribe{ + UUID: mockUUID, + } + mockChannels := []db.Channel{ + {ID: 1, TribeUUID: mockUUID}, + {ID: 2, TribeUUID: mockUUID}, + } + mockDb.On("GetTribe", mock.Anything).Return(mockTribe).Once() + mockDb.On("GetChannelsByTribe", mock.Anything).Return(mockChannels).Once() + + // Serve request + rr := httptest.NewRecorder() + rctx := chi.NewRouteContext() + rctx.URLParams.Add("uuid", mockUUID) + req, err := http.NewRequestWithContext(context.WithValue(context.Background(), chi.RouteCtxKey, rctx), http.MethodGet, "/"+mockUUID, nil) + if err != nil { + t.Fatal(err) + } + + handler := http.HandlerFunc(tHandler.GetTribe) + handler.ServeHTTP(rr, req) + + // Verify response + assert.Equal(t, http.StatusOK, rr.Code) + var responseData map[string]interface{} + err = json.Unmarshal(rr.Body.Bytes(), &responseData) + if err != nil { + t.Fatalf("Error decoding JSON response: %s", err) + } + assert.Equal(t, mockTribe.UUID, responseData["uuid"]) + }) + + t.Run("Should test that no tribe is returned when a nonexistent UUID is passed", func(t *testing.T) { + // Mock data + mockDb.ExpectedCalls = nil + nonexistentUUID := "nonexistent_uuid" + mockDb.On("GetTribe", nonexistentUUID).Return(db.Tribe{}).Once() + mockDb.On("GetChannelsByTribe", mock.Anything).Return([]db.Channel{}).Once() + + // Serve request + rr := httptest.NewRecorder() + rctx := chi.NewRouteContext() + rctx.URLParams.Add("uuid", nonexistentUUID) + req, err := http.NewRequestWithContext(context.WithValue(context.Background(), chi.RouteCtxKey, rctx), http.MethodGet, "/"+nonexistentUUID, nil) + if err != nil { + t.Fatal(err) + } + + handler := http.HandlerFunc(tHandler.GetTribe) + handler.ServeHTTP(rr, req) + + // Verify response + assert.Equal(t, http.StatusOK, rr.Code) + var responseData map[string]interface{} + err = json.Unmarshal(rr.Body.Bytes(), &responseData) + if err != nil { + t.Fatalf("Error decoding JSON response: %s", err) + } + assert.Equal(t, "", responseData["uuid"]) + }) +} + +func TestGetTribesByAppUrl(t *testing.T) { + mockDb := mocks.NewDatabase(t) + tHandler := NewTribeHandler(mockDb) + + t.Run("Should test that a tribe is returned when the right app URL is passed", func(t *testing.T) { + // Mock data + mockAppURL := "valid_app_url" + mockTribes := []db.Tribe{ + {UUID: "uuid", AppURL: mockAppURL}, + {UUID: "uuid", AppURL: mockAppURL}, + } + mockDb.On("GetTribesByAppUrl", mockAppURL).Return(mockTribes).Once() + + // Serve request + rr := httptest.NewRecorder() + rctx := chi.NewRouteContext() + rctx.URLParams.Add("app_url", mockAppURL) + req, err := http.NewRequestWithContext(context.WithValue(context.Background(), chi.RouteCtxKey, rctx), http.MethodGet, "/app_url/"+mockAppURL, nil) + if err != nil { + t.Fatal(err) + } + + handler := http.HandlerFunc(tHandler.GetTribesByAppUrl) + handler.ServeHTTP(rr, req) + + // Verify response + assert.Equal(t, http.StatusOK, rr.Code) + var responseData []db.Tribe + err = json.Unmarshal(rr.Body.Bytes(), &responseData) + if err != nil { + t.Fatalf("Error decoding JSON response: %s", err) + } + assert.ElementsMatch(t, mockTribes, responseData) + }) +} diff --git a/routes/index.go b/routes/index.go index a149286ef..4cb494c62 100644 --- a/routes/index.go +++ b/routes/index.go @@ -20,6 +20,7 @@ import ( // NewRouter creates a chi router func NewRouter() *http.Server { r := initChi() + tribeHandlers := handlers.NewTribeHandler(db.DB) r.Mount("/tribes", TribeRoutes()) r.Mount("/bots", BotsRoutes()) @@ -36,7 +37,7 @@ func NewRouter() *http.Server { r.Get("/tribe_by_feed", handlers.GetFirstTribeByFeed) r.Get("/leaderboard/{tribe_uuid}", handlers.GetLeaderBoard) r.Get("/tribe_by_un/{un}", handlers.GetTribeByUniqueName) - r.Get("/tribes_by_owner/{pubkey}", handlers.GetTribesByOwner) + r.Get("/tribes_by_owner/{pubkey}", tribeHandlers.GetTribesByOwner) r.Get("/search/bots/{query}", handlers.SearchBots) r.Get("/podcast", handlers.GetPodcast) diff --git a/routes/tribes.go b/routes/tribes.go index 6c65dfaac..72ce7cb1b 100644 --- a/routes/tribes.go +++ b/routes/tribes.go @@ -2,16 +2,18 @@ package routes import ( "github.com/go-chi/chi" + "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/handlers" ) func TribeRoutes() chi.Router { r := chi.NewRouter() + tribeHandlers := handlers.NewTribeHandler(db.DB) r.Group(func(r chi.Router) { r.Get("/", handlers.GetListedTribes) - r.Get("/app_url/{app_url}", handlers.GetTribesByAppUrl) + r.Get("/app_url/{app_url}", tribeHandlers.GetTribesByAppUrl) r.Get("/app_urls/{app_urls}", handlers.GetTribesByAppUrls) - r.Get("/{uuid}", handlers.GetTribe) + r.Get("/{uuid}", tribeHandlers.GetTribe) r.Get("/total", handlers.GetTotalribes) r.Post("/", handlers.CreateOrEditTribe) })