diff --git a/handlers/channel.go b/handlers/channel.go index 627fe1e97..b969d2a42 100644 --- a/handlers/channel.go +++ b/handlers/channel.go @@ -12,7 +12,17 @@ import ( "github.com/stakwork/sphinx-tribes/db" ) -func DeleteChannel(w http.ResponseWriter, r *http.Request) { +type channelHandler struct { + db db.Database +} + +func NewChannelHandler(db db.Database) *channelHandler { + return &channelHandler{ + db: db, + } +} + +func (ch *channelHandler) DeleteChannel(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) @@ -30,8 +40,8 @@ func DeleteChannel(w http.ResponseWriter, r *http.Request) { return } - existing := db.DB.GetChannel(uint(id)) - existingTribe := db.DB.GetTribe(existing.TribeUUID) + existing := ch.db.GetChannel(uint(id)) + existingTribe := ch.db.GetTribe(existing.TribeUUID) if existing.ID == 0 { fmt.Println("existing id is 0") w.WriteHeader(http.StatusUnauthorized) @@ -43,7 +53,7 @@ func DeleteChannel(w http.ResponseWriter, r *http.Request) { return } - db.DB.UpdateChannel(uint(id), map[string]interface{}{ + ch.db.UpdateChannel(uint(id), map[string]interface{}{ "deleted": true, }) @@ -51,7 +61,7 @@ func DeleteChannel(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(true) } -func CreateChannel(w http.ResponseWriter, r *http.Request) { +func (ch *channelHandler) CreateChannel(w http.ResponseWriter, r *http.Request) { ctx := r.Context() pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) @@ -66,14 +76,14 @@ func CreateChannel(w http.ResponseWriter, r *http.Request) { } //check that the tribe has the same pubKeyFromAuth - tribe := db.DB.GetTribe(channel.TribeUUID) + tribe := ch.db.GetTribe(channel.TribeUUID) if tribe.OwnerPubKey != pubKeyFromAuth { fmt.Println(err) - w.WriteHeader(http.StatusNotAcceptable) + w.WriteHeader(http.StatusUnauthorized) return } - tribeChannels := db.DB.GetChannelsByTribe(channel.TribeUUID) + tribeChannels := ch.db.GetChannelsByTribe(channel.TribeUUID) for _, tribeChannel := range tribeChannels { if tribeChannel.Name == channel.Name { fmt.Println("Channel name already in use") @@ -83,7 +93,7 @@ func CreateChannel(w http.ResponseWriter, r *http.Request) { } } - channel, err = db.DB.CreateChannel(channel) + channel, err = ch.db.CreateChannel(channel) if err != nil { fmt.Println(err) w.WriteHeader(http.StatusNotAcceptable) diff --git a/handlers/channel_test.go b/handlers/channel_test.go new file mode 100644 index 000000000..066d67101 --- /dev/null +++ b/handlers/channel_test.go @@ -0,0 +1,137 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi" + "github.com/stakwork/sphinx-tribes/auth" + "github.com/stakwork/sphinx-tribes/db" + mocks "github.com/stakwork/sphinx-tribes/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestCreateChannel(t *testing.T) { + mockDb := mocks.NewDatabase(t) + cHandler := NewChannelHandler(mockDb) + + // Mock data for testing + mockPubKey := "mock_pubkey" + mockTribeUUID := "mock_tribe_uuid" + mockChannelName := "mock_channel" + mockRequestBody := map[string]interface{}{ + "tribe_uuid": mockTribeUUID, + "name": mockChannelName, + } + + // Mock request body + requestBodyBytes, err := json.Marshal(mockRequestBody) + assert.NoError(t, err) + + t.Run("Should test that a user that is not authenticated cannot create a channel", func(t *testing.T) { + req, err := http.NewRequest("POST", "/channel", bytes.NewBuffer(requestBodyBytes)) + assert.NoError(t, err) + rr := httptest.NewRecorder() + + mockDb.On("GetTribe", mockTribeUUID).Return(db.Tribe{OwnerPubKey: mockPubKey}) + + cHandler.CreateChannel(rr, req) + + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + + t.Run("Should test that an authenticated user can create a channel", func(t *testing.T) { + req, err := http.NewRequest("POST", "/channel", bytes.NewBuffer(requestBodyBytes)) + assert.NoError(t, err) + req = req.WithContext(context.WithValue(req.Context(), auth.ContextKey, mockPubKey)) + rr := httptest.NewRecorder() + + mockDb.On("GetTribe", mockTribeUUID).Return(db.Tribe{OwnerPubKey: mockPubKey}) + mockDb.On("GetChannelsByTribe", mockTribeUUID).Return([]db.Channel{}) + mockDb.On("CreateChannel", mock.Anything).Return(db.Channel{}, nil) + + cHandler.CreateChannel(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + }) + + t.Run("Should test that a user cannot create a channel with a name that already exists", func(t *testing.T) { + mockDb.ExpectedCalls = nil + + req, err := http.NewRequest("POST", "/channel", bytes.NewBuffer(requestBodyBytes)) + assert.NoError(t, err) + req = req.WithContext(context.WithValue(req.Context(), auth.ContextKey, mockPubKey)) + rr := httptest.NewRecorder() + + mockDb.On("GetTribe", mockTribeUUID).Return(db.Tribe{OwnerPubKey: mockPubKey}) + mockDb.On("GetChannelsByTribe", mockTribeUUID).Return([]db.Channel{{Name: mockChannelName}}) + + cHandler.CreateChannel(rr, req) + + assert.Equal(t, http.StatusNotAcceptable, rr.Code) + + // Ensure that the expected methods were called + mockDb.AssertExpectations(t) + }) +} + +func TestDeleteChannel(t *testing.T) { + ctx := context.WithValue(context.Background(), auth.ContextKey, "mock_pubkey") + mockDb := mocks.NewDatabase(t) + cHandler := NewChannelHandler(mockDb) + + // Mock data for testing + mockPubKey := "mock_pubkey" + mockChannelID := uint(1) + + t.Run("Should test that the owner of a channel can delete the channel", func(t *testing.T) { + mockDb.On("GetChannel", mockChannelID).Return(db.Channel{ID: mockChannelID, TribeUUID: "mock_tribe_uuid"}) + mockDb.On("GetTribe", "mock_tribe_uuid").Return(db.Tribe{OwnerPubKey: mockPubKey}) + mockDb.On("UpdateChannel", mockChannelID, mock.Anything).Return(true) + + // Create and Serve request + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, "DELETE", "/channel/1", nil) + if err != nil { + t.Fatal(err) + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("id", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chiCtx)) + + handler := http.HandlerFunc(cHandler.DeleteChannel) + handler.ServeHTTP(rr, req) + + // Verify response + assert.Equal(t, http.StatusOK, rr.Code) + }) + + t.Run("Should test that non-channel owners cannot delete the channel, it should return a 401 error", func(t *testing.T) { + mockPubKey := "other_pubkey" + + mockDb.ExpectedCalls = nil + mockDb.On("GetChannel", mockChannelID).Return(db.Channel{ID: mockChannelID, TribeUUID: "mock_tribe_uuid"}) + mockDb.On("GetTribe", "mock_tribe_uuid").Return(db.Tribe{OwnerPubKey: mockPubKey}) + + // Create and Serve request + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(ctx, "DELETE", "/channel/1", nil) + if err != nil { + t.Fatal(err) + } + chiCtx := chi.NewRouteContext() + chiCtx.URLParams.Add("id", "1") + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, chiCtx)) + + handler := http.HandlerFunc(cHandler.DeleteChannel) + handler.ServeHTTP(rr, req) + + // Verify response + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) +} diff --git a/routes/index.go b/routes/index.go index d489f4923..d7f71c749 100644 --- a/routes/index.go +++ b/routes/index.go @@ -22,6 +22,7 @@ func NewRouter() *http.Server { r := initChi() tribeHandlers := handlers.NewTribeHandler(db.DB) authHandler := handlers.NewAuthHandler(db.DB) + channelHandler := handlers.NewChannelHandler(db.DB) r.Mount("/tribes", TribeRoutes()) r.Mount("/bots", BotsRoutes()) @@ -61,7 +62,7 @@ func NewRouter() *http.Server { r.Group(func(r chi.Router) { r.Use(auth.PubKeyContext) - r.Post("/channel", handlers.CreateChannel) + r.Post("/channel", channelHandler.CreateChannel) r.Post("/leaderboard/{tribe_uuid}", handlers.CreateLeaderBoard) r.Put("/leaderboard/{tribe_uuid}", handlers.UpdateLeaderBoard) r.Put("/tribe", tribeHandlers.CreateOrEditTribe) @@ -71,7 +72,7 @@ func NewRouter() *http.Server { r.Put("/tribepreview/{uuid}", tribeHandlers.SetTribePreview) r.Post("/verify/{challenge}", db.Verify) r.Post("/badges", handlers.AddOrRemoveBadge) - r.Delete("/channel/{id}", handlers.DeleteChannel) + r.Delete("/channel/{id}", channelHandler.DeleteChannel) r.Delete("/ticket/{pubKey}/{created}", handlers.DeleteTicketByAdmin) r.Get("/poll/invoice/{paymentRequest}", handlers.PollInvoice) r.Post("/meme_upload", handlers.MemeImageUpload)