From 5d6361f0ac6fa104ad0f26085987cce77ba87c58 Mon Sep 17 00:00:00 2001 From: aliraza556 Date: Fri, 29 Nov 2024 03:20:59 +0500 Subject: [PATCH] refactor: standardize ticket handler with http client and auth context and ticket status --- db/structs.go | 11 +++++-- db/tickets.go | 17 +++++++++++ handlers/ticket.go | 55 ++++++++++++++++++++++++++++++--- handlers/ticket_test.go | 67 +++++++++++++++++++++++++++++++++++------ routes/ticket_routes.go | 12 ++++++-- 5 files changed, 142 insertions(+), 20 deletions(-) diff --git a/db/structs.go b/db/structs.go index a73000c5e..4dd24946d 100644 --- a/db/structs.go +++ b/db/structs.go @@ -948,8 +948,13 @@ type WfRequest struct { type TicketStatus string const ( - DraftTicket TicketStatus = "draft" - CompletedTicket TicketStatus = "completed" + DraftTicket TicketStatus = "DRAFT" + ReadyTicket TicketStatus = "READY" + InProgressTicket TicketStatus = "IN_PROGRESS" + TestTicket TicketStatus = "TEST" + DeployTicket TicketStatus = "DEPLOY" + PayTicket TicketStatus = "PAY" + CompletedTicket TicketStatus = "COMPLETED" ) type Tickets struct { @@ -962,7 +967,7 @@ type Tickets struct { Sequence int `gorm:"type:integer;not null;index:composite_index"` Dependency []int `gorm:"type:integer[]"` Description string `gorm:"type:text"` - Status TicketStatus `gorm:"type:varchar(50);not null;default:'draft'"` + Status TicketStatus `gorm:"type:varchar(50);not null;default:'DRAFT'"` Version int `gorm:"type:integer" json:"version"` CreatedAt time.Time `gorm:"type:timestamp;not null;default:current_timestamp" json:"created_at"` UpdatedAt time.Time `gorm:"type:timestamp;not null;default:current_timestamp" json:"updated_at"` diff --git a/db/tickets.go b/db/tickets.go index 2d0375483..5e12c83b8 100644 --- a/db/tickets.go +++ b/db/tickets.go @@ -52,6 +52,15 @@ func (db database) GetTicket(uuid string) (Tickets, error) { return ticket, nil } +func IsValidTicketStatus(status TicketStatus) bool { + switch status { + case DraftTicket, ReadyTicket, InProgressTicket, TestTicket, DeployTicket, PayTicket, CompletedTicket: + return true + default: + return false + } +} + func (db database) UpdateTicket(ticket Tickets) (Tickets, error) { if ticket.UUID == uuid.Nil { return Tickets{}, errors.New("ticket UUID is required") @@ -61,6 +70,10 @@ func (db database) UpdateTicket(ticket Tickets) (Tickets, error) { return Tickets{}, errors.New("feature_uuid, phase_uuid, and name are required") } + if ticket.Status != "" && !IsValidTicketStatus(ticket.Status) { + return Tickets{}, errors.New("invalid ticket status") + } + var existingTicket Tickets result := db.db.Where("uuid = ?", ticket.UUID).First(&existingTicket) @@ -70,6 +83,10 @@ func (db database) UpdateTicket(ticket Tickets) (Tickets, error) { if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { ticket.CreatedAt = now + + if ticket.Status == "" { + ticket.Status = DraftTicket + } if err := db.db.Create(&ticket).Error; err != nil { return Tickets{}, fmt.Errorf("failed to create ticket: %w", err) } diff --git a/handlers/ticket.go b/handlers/ticket.go index 62deec014..5687defd4 100644 --- a/handlers/ticket.go +++ b/handlers/ticket.go @@ -9,12 +9,14 @@ import ( "github.com/go-chi/chi" "github.com/google/uuid" + "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stakwork/sphinx-tribes/utils" ) type ticketHandler struct { - db db.Database + httpClient HttpClient + db db.Database } type TicketResponse struct { @@ -24,9 +26,10 @@ type TicketResponse struct { Errors []string `json:"errors,omitempty"` } -func NewTicketHandler(database db.Database) *ticketHandler { +func NewTicketHandler(httpClient HttpClient, database db.Database) *ticketHandler { return &ticketHandler{ - db: database, + httpClient: httpClient, + db: database, } } @@ -51,6 +54,15 @@ func (th *ticketHandler) GetTicket(w http.ResponseWriter, r *http.Request) { } func (th *ticketHandler) UpdateTicket(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) + + if pubKeyFromAuth == "" { + fmt.Println("[ticket] no pubkey from auth") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + uuidStr := chi.URLParam(r, "uuid") if uuidStr == "" { http.Error(w, "UUID is required", http.StatusBadRequest) @@ -68,6 +80,7 @@ func (th *ticketHandler) UpdateTicket(w http.ResponseWriter, r *http.Request) { http.Error(w, "Error reading request body", http.StatusBadRequest) return } + defer r.Body.Close() var ticket db.Tickets if err := json.Unmarshal(body, &ticket); err != nil { @@ -77,6 +90,11 @@ func (th *ticketHandler) UpdateTicket(w http.ResponseWriter, r *http.Request) { ticket.UUID = ticketUUID + if ticket.Status != "" && !db.IsValidTicketStatus(ticket.Status) { + http.Error(w, "Invalid ticket status", http.StatusBadRequest) + return + } + updatedTicket, err := th.db.UpdateTicket(ticket) if err != nil { if err.Error() == "feature_uuid, phase_uuid, and name are required" { @@ -89,8 +107,16 @@ func (th *ticketHandler) UpdateTicket(w http.ResponseWriter, r *http.Request) { utils.RespondWithJSON(w, http.StatusOK, updatedTicket) } - func (th *ticketHandler) DeleteTicket(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) + + if pubKeyFromAuth == "" { + fmt.Println("[ticket] no pubkey from auth") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + uuid := chi.URLParam(r, "uuid") if uuid == "" { http.Error(w, "UUID is required", http.StatusBadRequest) @@ -111,6 +137,15 @@ func (th *ticketHandler) DeleteTicket(w http.ResponseWriter, r *http.Request) { } func (th *ticketHandler) PostTicketDataToStakwork(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) + + if pubKeyFromAuth == "" { + fmt.Println("[ticket] no pubkey from auth") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { utils.RespondWithJSON(w, http.StatusBadRequest, TicketResponse{ @@ -120,6 +155,7 @@ func (th *ticketHandler) PostTicketDataToStakwork(w http.ResponseWriter, r *http }) return } + defer r.Body.Close() var ticket db.Tickets if err := json.Unmarshal(body, &ticket); err != nil { @@ -167,12 +203,22 @@ func (th *ticketHandler) PostTicketDataToStakwork(w http.ResponseWriter, r *http } func (th *ticketHandler) ProcessTicketReview(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + pubKeyFromAuth, _ := ctx.Value(auth.ContextKey).(string) + + if pubKeyFromAuth == "" { + fmt.Println("[ticket] no pubkey from auth") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { log.Printf("Error reading request body: %v", err) http.Error(w, "Error reading request body", http.StatusBadRequest) return } + defer r.Body.Close() var reviewReq utils.TicketReviewRequest if err := json.Unmarshal(body, &reviewReq); err != nil { @@ -204,6 +250,5 @@ func (th *ticketHandler) ProcessTicketReview(w http.ResponseWriter, r *http.Requ } log.Printf("Successfully updated ticket %s", reviewReq.TicketUUID) - utils.RespondWithJSON(w, http.StatusOK, updatedTicket) } diff --git a/handlers/ticket_test.go b/handlers/ticket_test.go index e4d83d656..4fcb481e4 100644 --- a/handlers/ticket_test.go +++ b/handlers/ticket_test.go @@ -10,6 +10,7 @@ import ( "github.com/go-chi/chi" "github.com/google/uuid" + "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" "github.com/stretchr/testify/assert" ) @@ -18,7 +19,7 @@ func TestGetTicket(t *testing.T) { teardownSuite := SetupSuite(t) defer teardownSuite(t) - tHandler := NewTicketHandler(db.TestDB) + tHandler := NewTicketHandler(&http.Client{}, db.TestDB) person := db.Person{ Uuid: uuid.New().String(), @@ -129,7 +130,7 @@ func TestUpdateTicket(t *testing.T) { teardownSuite := SetupSuite(t) defer teardownSuite(t) - tHandler := NewTicketHandler(db.TestDB) + tHandler := NewTicketHandler(&http.Client{}, db.TestDB) person := db.Person{ Uuid: uuid.New().String(), @@ -179,6 +180,19 @@ func TestUpdateTicket(t *testing.T) { } createdTicket, _ := db.TestDB.UpdateTicket(ticket) + t.Run("should return 401 if no auth token", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.UpdateTicket) + + req, err := http.NewRequest(http.MethodPut, "/tickets/", nil) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + t.Run("should return 400 if UUID is empty", func(t *testing.T) { rr := httptest.NewRecorder() handler := http.HandlerFunc(tHandler.UpdateTicket) @@ -188,6 +202,9 @@ func TestUpdateTicket(t *testing.T) { t.Fatal(err) } + ctx := context.WithValue(req.Context(), auth.ContextKey, person.OwnerPubKey) + req = req.WithContext(ctx) + handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) }) @@ -198,11 +215,14 @@ func TestUpdateTicket(t *testing.T) { rctx := chi.NewRouteContext() rctx.URLParams.Add("uuid", "invalid-uuid") + req, err := http.NewRequest(http.MethodPut, "/tickets/invalid-uuid", bytes.NewReader([]byte("{}"))) if err != nil { t.Fatal(err) } - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + ctx := context.WithValue(req.Context(), auth.ContextKey, person.OwnerPubKey) + req = req.WithContext(context.WithValue(ctx, chi.RouteCtxKey, rctx)) handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -219,7 +239,9 @@ func TestUpdateTicket(t *testing.T) { if err != nil { t.Fatal(err) } - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + ctx := context.WithValue(req.Context(), auth.ContextKey, person.OwnerPubKey) + req = req.WithContext(context.WithValue(ctx, chi.RouteCtxKey, rctx)) handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -241,7 +263,9 @@ func TestUpdateTicket(t *testing.T) { if err != nil { t.Fatal(err) } - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + ctx := context.WithValue(req.Context(), auth.ContextKey, person.OwnerPubKey) + req = req.WithContext(context.WithValue(ctx, chi.RouteCtxKey, rctx)) handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) @@ -259,11 +283,14 @@ func TestUpdateTicket(t *testing.T) { requestBody, _ := json.Marshal(updatedTicket) rctx := chi.NewRouteContext() rctx.URLParams.Add("uuid", createdTicket.UUID.String()) + req, err := http.NewRequest(http.MethodPut, "/tickets/"+createdTicket.UUID.String(), bytes.NewReader(requestBody)) if err != nil { t.Fatal(err) } - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + ctx := context.WithValue(req.Context(), auth.ContextKey, person.OwnerPubKey) + req = req.WithContext(context.WithValue(ctx, chi.RouteCtxKey, rctx)) handler.ServeHTTP(rr, req) @@ -284,7 +311,7 @@ func TestDeleteTicket(t *testing.T) { teardownSuite := SetupSuite(t) defer teardownSuite(t) - tHandler := NewTicketHandler(db.TestDB) + tHandler := NewTicketHandler(&http.Client{}, db.TestDB) person := db.Person{ Uuid: uuid.New().String(), @@ -334,6 +361,19 @@ func TestDeleteTicket(t *testing.T) { } createdTicket, _ := db.TestDB.UpdateTicket(ticket) + t.Run("should return 401 if no auth token", func(t *testing.T) { + rr := httptest.NewRecorder() + handler := http.HandlerFunc(tHandler.DeleteTicket) + + req, err := http.NewRequest(http.MethodDelete, "/tickets/", nil) + if err != nil { + t.Fatal(err) + } + + handler.ServeHTTP(rr, req) + assert.Equal(t, http.StatusUnauthorized, rr.Code) + }) + t.Run("should return 400 if UUID is empty", func(t *testing.T) { rr := httptest.NewRecorder() handler := http.HandlerFunc(tHandler.DeleteTicket) @@ -343,6 +383,9 @@ func TestDeleteTicket(t *testing.T) { t.Fatal(err) } + ctx := context.WithValue(req.Context(), auth.ContextKey, person.OwnerPubKey) + req = req.WithContext(ctx) + handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusBadRequest, rr.Code) }) @@ -354,11 +397,14 @@ func TestDeleteTicket(t *testing.T) { nonExistentUUID := uuid.New() rctx := chi.NewRouteContext() rctx.URLParams.Add("uuid", nonExistentUUID.String()) + req, err := http.NewRequest(http.MethodDelete, "/tickets/"+nonExistentUUID.String(), nil) if err != nil { t.Fatal(err) } - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + ctx := context.WithValue(req.Context(), auth.ContextKey, person.OwnerPubKey) + req = req.WithContext(context.WithValue(ctx, chi.RouteCtxKey, rctx)) handler.ServeHTTP(rr, req) assert.Equal(t, http.StatusNotFound, rr.Code) @@ -370,11 +416,14 @@ func TestDeleteTicket(t *testing.T) { rctx := chi.NewRouteContext() rctx.URLParams.Add("uuid", createdTicket.UUID.String()) + req, err := http.NewRequest(http.MethodDelete, "/tickets/"+createdTicket.UUID.String(), nil) if err != nil { t.Fatal(err) } - req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + + ctx := context.WithValue(req.Context(), auth.ContextKey, person.OwnerPubKey) + req = req.WithContext(context.WithValue(ctx, chi.RouteCtxKey, rctx)) handler.ServeHTTP(rr, req) diff --git a/routes/ticket_routes.go b/routes/ticket_routes.go index c921349e3..90b7bd6a8 100644 --- a/routes/ticket_routes.go +++ b/routes/ticket_routes.go @@ -1,6 +1,8 @@ package routes import ( + "net/http" + "github.com/go-chi/chi" "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/db" @@ -9,16 +11,20 @@ import ( func TicketRoutes() chi.Router { r := chi.NewRouter() - ticketHandler := handlers.NewTicketHandler(db.DB) + ticketHandler := handlers.NewTicketHandler(http.DefaultClient, db.DB) + + r.Group(func(r chi.Router) { + r.Get("/{uuid}", ticketHandler.GetTicket) + }) r.Group(func(r chi.Router) { r.Use(auth.PubKeyContext) r.Post("/review/send", ticketHandler.PostTicketDataToStakwork) - r.Get("/{uuid}", ticketHandler.GetTicket) + r.Post("/review", ticketHandler.ProcessTicketReview) + r.Put("/{uuid}", ticketHandler.UpdateTicket) r.Delete("/{uuid}", ticketHandler.DeleteTicket) - r.Post("/review", ticketHandler.ProcessTicketReview) }) return r