From 78e609b2bd4923ed4fb28cb22e72ca777927313f Mon Sep 17 00:00:00 2001 From: nuric Date: Sat, 30 Nov 2024 16:09:55 +0000 Subject: [PATCH] Remove gin framework from httpapi for more control The Gin framework is good but restricts validation, JSON inputs and middleware control. By moving to the standard library, we will have more validation control, non-JSON input and standard library compatibility. --- httpapi/httpapi.go | 33 ++-- httpapi/metrics.go | 10 -- httpapi/middleware.go | 131 ++++++-------- httpapi/middleware/appheaders.go | 40 +++-- httpapi/utils/encdec.go | 50 ++++++ httpapi/v1/handlers.go | 294 ++++++++++++++++++++----------- httpapi/v2/handlers.go | 257 +++++++++++++++------------ 7 files changed, 482 insertions(+), 333 deletions(-) create mode 100644 httpapi/utils/encdec.go diff --git a/httpapi/httpapi.go b/httpapi/httpapi.go index e0a5e92..da0a643 100644 --- a/httpapi/httpapi.go +++ b/httpapi/httpapi.go @@ -4,7 +4,6 @@ import ( "net/http" "strconv" - "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" "github.com/rs/zerolog/log" "github.com/semafind/semadb/cluster" @@ -34,39 +33,35 @@ type HttpApiConfig struct { // --------------------------- -func setupRouter(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) *gin.Engine { - router := gin.New() +func setupRouter(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) http.Handler { // --------------------------- var metrics *httpMetrics if cfg.EnableMetrics && reg != nil { metrics = setupAndListenMetrics(cfg, reg) } // --------------------------- - router.Use(ZerologLoggerMetrics(metrics), gin.Recovery()) + mux := http.NewServeMux() + mux.Handle("/v1/", http.StripPrefix("/v1", httpv1.SetupV1Handlers(cnode))) + mux.Handle("/v2/", http.StripPrefix("/v2", httpv2.SetupV2Handlers(cnode))) // --------------------------- - if len(cfg.ProxySecret) > 0 { - log.Info().Msg("ProxySecretMiddleware is enabled") - 
router.Use(ProxySecretMiddleware(cfg.ProxySecret)) - } + var handler http.Handler = mux + handler = middleware.AppHeaderMiddleware(cfg.UserPlans, handler) if cfg.WhiteListIPs == nil || (len(cfg.WhiteListIPs) == 1 && cfg.WhiteListIPs[0] == "*") { log.Warn().Strs("whiteListIPs", cfg.WhiteListIPs).Msg("WhiteListIPMiddleware is disabled") } else { - router.Use(WhiteListIPMiddleware(cfg.WhiteListIPs)) + handler = WhiteListIPMiddleware(cfg.WhiteListIPs, handler) } + if len(cfg.ProxySecret) > 0 { + log.Info().Msg("ProxySecretMiddleware is enabled") + handler = ProxySecretMiddleware(cfg.ProxySecret, handler) + } + handler = ZeroLoggerMetrics(metrics, handler) + handler = RecoverMiddleware(handler) // --------------------------- - v1 := router.Group("/v1", middleware.AppHeaderMiddleware(cfg.UserPlans)) - httpv1.SetupV1Handlers(cnode, v1) - // --------------------------- - v2 := router.Group("/v2", middleware.AppHeaderMiddleware(cfg.UserPlans)) - httpv2.SetupV2Handlers(cnode, v2) - return router + return handler } func RunHTTPServer(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) *http.Server { - // --------------------------- - if !cfg.Debug { - gin.SetMode(gin.ReleaseMode) - } // --------------------------- server := &http.Server{ Addr: cfg.HttpHost + ":" + strconv.Itoa(cfg.HttpPort), diff --git a/httpapi/metrics.go b/httpapi/metrics.go index adce472..ce18a81 100644 --- a/httpapi/metrics.go +++ b/httpapi/metrics.go @@ -14,7 +14,6 @@ type httpMetrics struct { requestCount *prometheus.CounterVec requestDuration *prometheus.HistogramVec requestSize *prometheus.HistogramVec - responseSize *prometheus.HistogramVec // --------------------------- } @@ -44,19 +43,10 @@ func setupAndListenMetrics(cfg HttpApiConfig, reg *prometheus.Registry) *httpMet }, []string{"code", "method", "handler"}, ), - responseSize: prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Name: "http_response_size_bytes", - Help: "HTTP request sizes in bytes.", - Buckets: 
[]float64{0, 1 << 10, 1 << 15, 1 << 20}, - }, - []string{"code", "method", "handler"}, - ), } reg.MustRegister(metrics.requestCount) reg.MustRegister(metrics.requestDuration) reg.MustRegister(metrics.requestSize) - reg.MustRegister(metrics.responseSize) // --------------------------- mux := http.NewServeMux() mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{Registry: reg})) diff --git a/httpapi/middleware.go b/httpapi/middleware.go index 5ea498d..74b9eaa 100644 --- a/httpapi/middleware.go +++ b/httpapi/middleware.go @@ -2,102 +2,79 @@ package httpapi import ( "net/http" + "regexp" + "runtime/debug" "slices" "strconv" - "strings" "time" - "github.com/gin-gonic/gin" - "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" "github.com/rs/zerolog/log" - "github.com/semafind/semadb/httpapi/middleware" + "github.com/semafind/semadb/httpapi/utils" ) // --------------------------- - -func ZerologLoggerMetrics(metrics *httpMetrics) gin.HandlerFunc { - return func(c *gin.Context) { - // --------------------------- - // Start timer - start := time.Now() - path := c.Request.URL.Path - raw := c.Request.URL.RawQuery - reqSize := c.Request.ContentLength - // --------------------------- - // Process request - c.Next() - // --------------------------- - // Stop timer and gather information - latency := time.Since(start) - - method := c.Request.Method - statusCode := c.Writer.Status() - lastError := c.Errors.ByType(gin.ErrorTypePrivate).Last() - - bodySize := c.Writer.Size() - - if raw != "" { - path = path + "?" + raw - } - // --------------------------- - var logEvent *zerolog.Event - if statusCode == 500 || lastError != nil { - logEvent = log.Error() - } else { - logEvent = log.Info() - } - logEvent.Err(lastError). - Dur("latency", latency). - Str("clientIP", c.ClientIP()). - Str("remoteIP", c.RemoteIP()). - Str("method", method).Str("path", path). - Int("statusCode", statusCode). - Int64("requestSize", reqSize). - Int("bodySize", bodySize). 
- Str("path", path) - // Extract app headers if any - appH, ok := c.Keys["appHeaders"] - if ok { - appHeaders := appH.(middleware.AppHeaders) - // We are not logging the user ID for privacy reasons - logEvent = logEvent.Str("planId", appHeaders.PlanId) - } - logEvent.Msg("HTTPAPI") - // --------------------------- +// Zerolog based middleware for logging HTTP requests +func ZeroLoggerMetrics(metrics *httpMetrics, next http.Handler) http.Handler { + handler := hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { + hlog.FromRequest(r).Info(). + Str("method", r.Method). + Stringer("url", r.URL). + Int("status", status). + Int("size", size). + Dur("duration", duration). + Msg("") if metrics != nil { - // Example handler names - // github.com/semafind/semadb/httpapi.(*SemaDBHandlers).ListCollections-fm - // github.com/semafind/semadb/httpapi.(*SemaDBHandlers).CreateCollection-fm - fullHName := c.HandlerName() - parts := strings.Split(fullHName, ".") - hname := parts[len(parts)-1][:len(parts[len(parts)-1])-3] - ssCode := strconv.Itoa(statusCode) - metrics.requestCount.WithLabelValues(ssCode, method, hname).Inc() - metrics.requestDuration.WithLabelValues(ssCode, method, hname).Observe(latency.Seconds()) - metrics.requestSize.WithLabelValues(ssCode, method, hname).Observe(float64(reqSize)) - metrics.responseSize.WithLabelValues(ssCode, method, hname).Observe(float64(bodySize)) + // Canonicalize the URL by removing url parameters + // Replace anything of the form collections/mycol23 with collections/:id + re := regexp.MustCompile(`collections/[a-zA-Z0-9]+`) + canonical := re.ReplaceAll([]byte(r.URL.Path), []byte("collections/{collectionId}")) + hname := string(canonical) + ssCode := strconv.Itoa(status) + metrics.requestCount.WithLabelValues(ssCode, r.Method, hname).Inc() + metrics.requestDuration.WithLabelValues(ssCode, r.Method, hname).Observe(duration.Seconds()) + metrics.requestSize.WithLabelValues(ssCode, r.Method, 
hname).Observe(float64(size)) + // metrics.responseSize.WithLabelValues(ssCode, r.Method, hname).Observe(float64(bodySize)) } - } + })(next) + handler = hlog.NewHandler(log.Logger)(handler) + return handler } // --------------------------- -func ProxySecretMiddleware(secret string) gin.HandlerFunc { +func ProxySecretMiddleware(secret string, next http.Handler) http.Handler { log.Debug().Str("proxySecret", secret).Msg("ProxySecretMiddleware") - return func(c *gin.Context) { - if c.GetHeader("X-Proxy-Secret") != secret { - c.AbortWithStatusJSON(http.StatusProxyAuthRequired, gin.H{"error": "forbidden"}) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Proxy-Secret") != secret { + utils.Encode(w, http.StatusProxyAuthRequired, map[string]string{"error": "forbidden"}) + return } - } + next.ServeHTTP(w, r) + }) } -func WhiteListIPMiddleware(whitelist []string) gin.HandlerFunc { +func WhiteListIPMiddleware(whitelist []string, next http.Handler) http.Handler { slices.Sort(whitelist) - return func(c *gin.Context) { - remoteIP := c.RemoteIP() - _, found := slices.BinarySearch(whitelist, remoteIP) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, found := slices.BinarySearch(whitelist, r.RemoteAddr) if !found { - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "forbidden"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "forbidden"}) + return } - } + next.ServeHTTP(w, r) + }) +} + +func RecoverMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + log.Error().Interface("error", err).Msg("panic recovered") + log.Error().Str("stack", string(debug.Stack())).Msg("stack trace") + w.WriteHeader(http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) } diff --git a/httpapi/middleware/appheaders.go b/httpapi/middleware/appheaders.go index 
10c2f50..f85cefe 100644 --- a/httpapi/middleware/appheaders.go +++ b/httpapi/middleware/appheaders.go @@ -1,11 +1,12 @@ package middleware import ( + "context" "fmt" "net/http" - "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" + "github.com/semafind/semadb/httpapi/utils" "github.com/semafind/semadb/models" ) @@ -14,23 +15,40 @@ type AppHeaders struct { PlanId string `header:"X-Plan-Id" binding:"required"` } -func AppHeaderMiddleware(userPlans map[string]models.UserPlan) gin.HandlerFunc { - return func(c *gin.Context) { - var appHeaders AppHeaders - if err := c.ShouldBindHeader(&appHeaders); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +type contextKey string + +const appHeadersKey contextKey = "appHeaders" +const userPlanKey contextKey = "userPlan" + +func AppHeaderMiddleware(userPlans map[string]models.UserPlan, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + appHeaders := AppHeaders{ + UserId: r.Header.Get("X-User-Id"), + PlanId: r.Header.Get("X-Plan-Id"), + } + if appHeaders.UserId == "" || appHeaders.PlanId == "" { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": "missing X-User-ID or X-Plan-Id headers"}) return } - c.Set("appHeaders", appHeaders) log.Debug().Interface("appHeaders", appHeaders).Msg("AppHeaderMiddleware") + // --------------------------- + newCtx := context.WithValue(r.Context(), appHeadersKey, appHeaders) // Extract user plan userPlan, ok := userPlans[appHeaders.PlanId] if !ok { errmsg := fmt.Sprintf("unknown user plan %s", appHeaders.PlanId) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errmsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errmsg}) return } - c.Set("userPlan", userPlan) - c.Next() - } + newCtx = context.WithValue(newCtx, userPlanKey, userPlan) + next.ServeHTTP(w, r.WithContext(newCtx)) + }) +} + +func GetAppHeaders(ctx context.Context) AppHeaders { + return 
ctx.Value(appHeadersKey).(AppHeaders) +} + +func GetUserPlan(ctx context.Context) models.UserPlan { + return ctx.Value(userPlanKey).(models.UserPlan) } diff --git a/httpapi/utils/encdec.go b/httpapi/utils/encdec.go new file mode 100644 index 0000000..c1ff593 --- /dev/null +++ b/httpapi/utils/encdec.go @@ -0,0 +1,50 @@ +package utils + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + + "github.com/rs/zerolog/log" +) + +// Validator is an object that can be validated. +type Validator interface { + Validate() error +} + +// Encode writes the object to the response writer. It is usually used as the +// last step in a handler. +func Encode[T any](w http.ResponseWriter, status int, v T) { + w.Header().Set("Content-Type", "application/json") + // Write to buffer first to ensure the object is json encodable + // before writing to the response writer. + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(v); err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Error().Err(err).Msg("could not encode response") + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + w.WriteHeader(status) + w.Write(buf.Bytes()) +} + +// DecodeValid decodes the request body into the object and then validates it. +// Look at problems to see if there are any issues. 
+func DecodeValid[T Validator](r *http.Request) (T, error) { + var v T + if r.Header.Get("Content-Type") != "application/json" { + return v, fmt.Errorf("expected content type application/json, got %s", r.Header.Get("Content-Type")) + } + // --------------------------- + if err := json.NewDecoder(r.Body).Decode(&v); err != nil { + return v, fmt.Errorf("decode json: %w", err) + } + if err := v.Validate(); err != nil { + return v, fmt.Errorf("validation error: %w", err) + } + // --------------------------- + return v, nil +} diff --git a/httpapi/v1/handlers.go b/httpapi/v1/handlers.go index 611d1b7..36513de 100644 --- a/httpapi/v1/handlers.go +++ b/httpapi/v1/handlers.go @@ -1,48 +1,53 @@ package v1 import ( + "context" "errors" "fmt" "net/http" "time" - "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/semafind/semadb/cluster" "github.com/semafind/semadb/httpapi/middleware" + "github.com/semafind/semadb/httpapi/utils" "github.com/semafind/semadb/models" "github.com/vmihailenco/msgpack/v5" ) -func pongHandler(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "message": "pong from semadb", - }) -} - // --------------------------- +func handlePing(w http.ResponseWriter, r *http.Request) { + utils.Encode(w, http.StatusOK, map[string]string{"message": "pong from semadb"}) +} + type SemaDBHandlers struct { clusterNode *cluster.ClusterNode } // Requires middleware.AppHeaderMiddleware to be used -func SetupV1Handlers(clusterNode *cluster.ClusterNode, rgroup *gin.RouterGroup) { - rgroup.GET("/ping", pongHandler) +func SetupV1Handlers(clusterNode *cluster.ClusterNode) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/ping", handlePing) semaDBHandlers := &SemaDBHandlers{clusterNode: clusterNode} // https://stackoverflow.blog/2020/03/02/best-practices-for-rest-api-design/ - rgroup.POST("/collections", semaDBHandlers.CreateCollection) - rgroup.GET("/collections", semaDBHandlers.ListCollections) - colRoutes := 
rgroup.Group("/collections/:collectionId", semaDBHandlers.CollectionURIMiddleware()) - colRoutes.GET("", semaDBHandlers.GetCollection) - colRoutes.DELETE("", semaDBHandlers.DeleteCollection) + mux.HandleFunc("GET /collections", semaDBHandlers.HandleListCollections) + mux.HandleFunc("POST /collections", semaDBHandlers.HandleCreateCollection) + // --------------------------- + withCol := func(next http.HandlerFunc) http.Handler { + return semaDBHandlers.CollectionURIMiddleware(http.HandlerFunc(next)) + } + mux.Handle("GET /collections/{collectionId}", withCol(semaDBHandlers.HandleGetCollection)) + mux.Handle("DELETE /collections/{collectionId}", withCol(semaDBHandlers.HandleDeleteCollection)) // We're batching point requests for peformance reasons. Alternatively we // can provide points/:pointId endpoint in the future. - colRoutes.POST("/points", semaDBHandlers.InsertPoints) - colRoutes.PUT("/points", semaDBHandlers.UpdatePoints) - colRoutes.DELETE("/points", semaDBHandlers.DeletePoints) - colRoutes.POST("/points/search", semaDBHandlers.SearchPoints) + mux.Handle("POST /collections/{collectionId}/points", withCol(semaDBHandlers.HandleInsertPoints)) + mux.Handle("PUT /collections/{collectionId}/points", withCol(semaDBHandlers.HandleUpdatePoints)) + mux.Handle("DELETE /collections/{collectionId}/points", withCol(semaDBHandlers.HandleDeletePoints)) + mux.Handle("POST /collections/{collectionId}/points/search", withCol(semaDBHandlers.HandleSearchPoints)) + // --------------------------- + return mux } // --------------------------- @@ -53,13 +58,33 @@ type CreateCollectionRequest struct { DistanceMetric string `json:"distanceMetric" binding:"required,oneof=euclidean cosine dot"` } -func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { - var req CreateCollectionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +func (req CreateCollectionRequest) Validate() error { + if len(req.Id) 
< 3 || len(req.Id) > 16 { + return fmt.Errorf("id must be between 3 and 16 characters") + } + // Check alphanum + for _, r := range req.Id { + if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) { + return fmt.Errorf("id must be alphanumeric") + } + } + if req.VectorSize < 1 || req.VectorSize > 4096 { + return fmt.Errorf("vectorSize must be between 1 and 4096, got %d", req.VectorSize) + } + if req.DistanceMetric != models.DistanceEuclidean && req.DistanceMetric != models.DistanceCosine && req.DistanceMetric != models.DistanceDot { + return fmt.Errorf("distanceMetric must be one of euclidean, cosine, dot, got %s", req.DistanceMetric) + } + return nil +} + +func (sdbh *SemaDBHandlers) HandleCreateCollection(w http.ResponseWriter, r *http.Request) { + req, err := utils.DecodeValid[CreateCollectionRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) + appHeaders := middleware.GetAppHeaders(r.Context()) + userPlan := middleware.GetUserPlan(r.Context()) // --------------------------- vamanaCollection := models.Collection{ UserId: appHeaders.UserId, @@ -67,7 +92,7 @@ func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { Replicas: 1, Timestamp: time.Now().Unix(), CreatedAt: time.Now().Unix(), - UserPlan: c.MustGet("userPlan").(models.UserPlan), + UserPlan: userPlan, IndexSchema: models.IndexSchema{ "vector": { Type: "vectorVamana", @@ -84,17 +109,15 @@ func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { } log.Debug().Interface("collection", vamanaCollection).Msg("CreateCollection") // --------------------------- - err := sdbh.clusterNode.CreateCollection(vamanaCollection) - switch err { + switch err := sdbh.clusterNode.CreateCollection(vamanaCollection); err { case nil: - c.JSON(http.StatusOK, gin.H{"message": "collection created"}) + utils.Encode(w, http.StatusOK, map[string]string{"message": 
"collection created"}) case cluster.ErrQuotaReached: - c.JSON(http.StatusForbidden, gin.H{"error": "quota reached"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "quota reached"}) case cluster.ErrExists: - c.JSON(http.StatusConflict, gin.H{"error": "collection exists"}) + utils.Encode(w, http.StatusConflict, map[string]string{"error": "collection exists"}) default: - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) log.Error().Err(err).Str("id", vamanaCollection.Id).Msg("CreateCollection failed") } } @@ -109,13 +132,12 @@ type ListCollectionsResponse struct { Collections []ListCollectionItem `json:"collections"` } -func (sdbh *SemaDBHandlers) ListCollections(c *gin.Context) { - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) +func (sdbh *SemaDBHandlers) HandleListCollections(w http.ResponseWriter, r *http.Request) { + appHeaders := middleware.GetAppHeaders(r.Context()) // --------------------------- collections, err := sdbh.clusterNode.ListCollections(appHeaders.UserId) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) log.Error().Err(err).Msg("ListCollections failed") return } @@ -124,33 +146,33 @@ func (sdbh *SemaDBHandlers) ListCollections(c *gin.Context) { colItems[i] = ListCollectionItem{Id: col.Id, VectorSize: col.IndexSchema["vector"].VectorVamana.VectorSize, DistanceMetric: col.IndexSchema["vector"].VectorVamana.DistanceMetric} } resp := ListCollectionsResponse{Collections: colItems} - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } // --------------------------- -type GetCollectionUri struct { - CollectionId string `uri:"collectionId" binding:"required,alphanum,min=3,max=16"` -} +type 
contextKey string + +const collectionContextKey contextKey = "collection" -func (sdbh *SemaDBHandlers) CollectionURIMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var uri GetCollectionUri - if err := c.ShouldBindUri(&uri); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +// Extracts collectionId from the URI and fetches the collection from the cluster. +func (sdbh *SemaDBHandlers) CollectionURIMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + collectionId := r.PathValue("collectionId") + if len(collectionId) < 3 || len(collectionId) > 16 { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": "collectionId must be between 3 and 16 characters"}) return } - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) - collection, err := sdbh.clusterNode.GetCollection(appHeaders.UserId, uri.CollectionId) + appHeaders := middleware.GetAppHeaders(r.Context()) + collection, err := sdbh.clusterNode.GetCollection(appHeaders.UserId, collectionId) if err == cluster.ErrNotFound { - errMsg := fmt.Sprintf("collection %s not found", uri.CollectionId) - c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": errMsg}) + errMsg := fmt.Sprintf("collection %s not found", collectionId) + utils.Encode(w, http.StatusNotFound, map[string]string{"error": errMsg}) return } if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -158,10 +180,11 @@ func (sdbh *SemaDBHandlers) CollectionURIMiddleware() gin.HandlerFunc { // This is because the user plan might change and we want the latest // active one rather than the one saved in the collection. This means // any downstream operation will use the latest user plan. 
- collection.UserPlan = c.MustGet("userPlan").(models.UserPlan) + collection.UserPlan = middleware.GetUserPlan(r.Context()) // --------------------------- - c.Set("collection", collection) - } + newCtx := context.WithValue(r.Context(), collectionContextKey, collection) + next.ServeHTTP(w, r.WithContext(newCtx)) + }) } // --------------------------- @@ -178,18 +201,17 @@ type GetCollectionResponse struct { Shards []ShardItem `json:"shards"` } -func (sdbh *SemaDBHandlers) GetCollection(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleGetCollection(w http.ResponseWriter, r *http.Request) { // --------------------------- - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- shards, err := sdbh.clusterNode.GetShardsInfo(collection) if errors.Is(err, cluster.ErrShardUnavailable) { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "one or more shards are unavailable"}) + utils.Encode(w, http.StatusServiceUnavailable, map[string]string{"error": "one or more shards are unavailable"}) return } if err != nil { - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -203,26 +225,25 @@ func (sdbh *SemaDBHandlers) GetCollection(c *gin.Context) { DistanceMetric: collection.IndexSchema["vector"].VectorVamana.DistanceMetric, Shards: shardItems, } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- -func (sdbh *SemaDBHandlers) DeleteCollection(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleDeleteCollection(w http.ResponseWriter, r *http.Request) { // --------------------------- - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- 
deletedShardIds, err := sdbh.clusterNode.DeleteCollection(collection) if err != nil { - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } status := http.StatusOK if len(deletedShardIds) != len(collection.ShardIds) { status = http.StatusAccepted } - c.JSON(status, gin.H{"message": "collection deleted"}) + utils.Encode(w, status, map[string]string{"message": "collection deleted"}) } // --------------------------- @@ -233,34 +254,58 @@ type InsertSinglePointRequest struct { Metadata any `json:"metadata"` } +func (req InsertSinglePointRequest) Validate() error { + if len(req.Id) > 0 { + if _, err := uuid.Parse(req.Id); err != nil { + return fmt.Errorf("id must be a valid uuid") + } + } + if len(req.Vector) < 1 || len(req.Vector) > 2000 { + return fmt.Errorf("vector size must be between 1 and 2000, got %d", len(req.Vector)) + } + return nil +} + type InsertPointsRequest struct { Points []InsertSinglePointRequest `json:"points" binding:"required,max=10000,dive"` } +func (req InsertPointsRequest) Validate() error { + if len(req.Points) < 1 || len(req.Points) > 10000 { + return fmt.Errorf("points size must be between 1 and 10000, got %d", len(req.Points)) + } + for i, point := range req.Points { + if err := point.Validate(); err != nil { + return fmt.Errorf("point at index %d: %w", i, err) + } + } + return nil +} + type InsertPointsResponse struct { Message string `json:"message"` FailedRanges []cluster.FailedRange `json:"failedRanges"` } -func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleInsertPoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req InsertPointsRequest startTime := time.Now() - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := 
utils.DecodeValid[InsertPointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } log.Debug().Str("bindTime", time.Since(startTime).String()).Msg("InsertPoints bind") // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Convert request points into internal points, doing checks along the way points := make([]models.Point, len(req.Points)) for i, point := range req.Points { if len(point.Vector) != int(collection.IndexSchema["vector"].VectorVamana.VectorSize) { errMsg := fmt.Sprintf("invalid vector dimension, expected %d got %d for point at index %d", collection.IndexSchema["vector"].VectorVamana.VectorSize, len(point.Vector), i) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } pointId := uuid.New() @@ -271,12 +316,12 @@ func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { binaryPointData, err := msgpack.Marshal(pointData) if err != nil { errMsg := fmt.Sprintf("failed to JSON encode point at index %d, please ensure all fields are JSON compatible", i) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if len(binaryPointData) > collection.UserPlan.MaxPointSize { errMsg := fmt.Sprintf("point %d exceeds maximum point size %d > %d", i, len(binaryPointData), collection.UserPlan.MaxPointSize) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i] = models.Point{ @@ -288,23 +333,22 @@ func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { // Insert points returns a range of errors for failed shards 
failedRanges, err := sdbh.clusterNode.InsertPoints(collection, points) if errors.Is(err, cluster.ErrQuotaReached) { - c.JSON(http.StatusForbidden, gin.H{"error": "quota reached"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "quota reached"}) return } if errors.Is(err, cluster.ErrShardUnavailable) { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "one or more shards are unavailable"}) + utils.Encode(w, http.StatusServiceUnavailable, map[string]string{"error": "one or more shards are unavailable"}) return } if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := InsertPointsResponse{Message: "success", FailedRanges: failedRanges} if len(failedRanges) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } @@ -316,32 +360,54 @@ type UpdateSinglePointRequest struct { Metadata any `json:"metadata"` } +func (req UpdateSinglePointRequest) Validate() error { + if _, err := uuid.Parse(req.Id); err != nil { + return fmt.Errorf("id must be a valid uuid, got %s", req.Id) + } + if len(req.Vector) < 1 || len(req.Vector) > 2000 { + return fmt.Errorf("vector size must be between 1 and 2000, got %d", len(req.Vector)) + } + return nil +} + type UpdatePointsRequest struct { Points []UpdateSinglePointRequest `json:"points" binding:"required,max=100,dive"` } +func (req UpdatePointsRequest) Validate() error { + if len(req.Points) < 1 || len(req.Points) > 100 { + return fmt.Errorf("points size must be between 1 and 100, got %d", len(req.Points)) + } + for i, point := range req.Points { + if err := point.Validate(); err != nil { + return fmt.Errorf("point at index %d: %w", i, err) + } + } + return nil +} + type UpdatePointsResponse struct { Message string `json:"message"` FailedPoints []cluster.FailedPoint 
`json:"failedPoints"` } -func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleUpdatePoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req UpdatePointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[UpdatePointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Convert request points into internal points, doing checks along the way points := make([]models.Point, len(req.Points)) for i, point := range req.Points { if len(point.Vector) != int(collection.IndexSchema["vector"].VectorVamana.VectorSize) { errMsg := fmt.Sprintf("invalid vector dimension, expected %d got %d for point at index %d", collection.IndexSchema["vector"].VectorVamana.VectorSize, len(point.Vector), i) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i] = models.Point{ @@ -351,12 +417,12 @@ func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { binaryPointData, err := msgpack.Marshal(pointData) if err != nil { errMsg := fmt.Sprintf("failed to JSON encode %d, please ensure all fields are JSON compatible", i) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if len(binaryPointData) > collection.UserPlan.MaxPointSize { errMsg := fmt.Sprintf("point %d exceeds maximum point size %d > %d", i, len(binaryPointData), collection.UserPlan.MaxPointSize) - c.AbortWithStatusJSON(http.StatusBadRequest, 
gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i].Data = binaryPointData @@ -365,15 +431,14 @@ func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { // Update points returns a list of failed points failedPoints, err := sdbh.clusterNode.UpdatePoints(collection, points) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := UpdatePointsResponse{Message: "success", FailedPoints: failedPoints} if len(failedPoints) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- @@ -382,16 +447,28 @@ type DeletePointsRequest struct { Ids []string `json:"ids" binding:"required,max=100,dive,uuid"` } +func (req DeletePointsRequest) Validate() error { + if len(req.Ids) < 1 || len(req.Ids) > 100 { + return fmt.Errorf("ids size must be between 1 and 100, got %d", len(req.Ids)) + } + for i, id := range req.Ids { + if _, err := uuid.Parse(id); err != nil { + return fmt.Errorf("id at index %d must be a valid uuid", i) + } + } + return nil +} + type DeletePointsResponse struct { Message string `json:"message"` FailedPoints []cluster.FailedPoint `json:"failedPoints"` } -func (sdbh *SemaDBHandlers) DeletePoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleDeletePoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req DeletePointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[DeletePointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -402,19 +479,18 @@ func (sdbh *SemaDBHandlers) DeletePoints(c 
*gin.Context) { } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- failedPoints, err := sdbh.clusterNode.DeletePoints(collection, pointIds) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := DeletePointsResponse{Message: "success", FailedPoints: failedPoints} if len(failedPoints) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- @@ -424,6 +500,16 @@ type SearchPointsRequest struct { Limit int `json:"limit" binding:"min=0,max=75"` } +func (req SearchPointsRequest) Validate() error { + if len(req.Vector) < 1 || len(req.Vector) > 2000 { + return fmt.Errorf("vector size must be between 1 and 2000") + } + if req.Limit < 0 || req.Limit > 75 { + return fmt.Errorf("limit must be between 0 and 75") + } + return nil +} + type SearchPointResult struct { Id string `json:"id"` Distance float32 `json:"distance"` @@ -434,11 +520,11 @@ type SearchPointsResponse struct { Points []SearchPointResult `json:"points"` } -func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleSearchPoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req SearchPointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[SearchPointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // Default limit is 10 @@ -447,12 +533,12 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { } // --------------------------- // 
Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Check vector dimension if len(req.Vector) != int(collection.IndexSchema["vector"].VectorVamana.VectorSize) { errMsg := fmt.Sprintf("invalid vector dimension, expected %d got %d", collection.IndexSchema["vector"].VectorVamana.VectorSize, len(req.Vector)) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } // --------------------------- @@ -471,7 +557,7 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { } points, err := sdbh.clusterNode.SearchPoints(collection, sr) if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } results := make([]SearchPointResult, len(points)) @@ -488,6 +574,6 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { } } resp := SearchPointsResponse{Points: results} - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } diff --git a/httpapi/v2/handlers.go b/httpapi/v2/handlers.go index d122135..f0cfdfa 100644 --- a/httpapi/v2/handlers.go +++ b/httpapi/v2/handlers.go @@ -1,48 +1,53 @@ package v2 import ( + "context" "errors" "fmt" "net/http" "time" - "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/semafind/semadb/cluster" "github.com/semafind/semadb/httpapi/middleware" + "github.com/semafind/semadb/httpapi/utils" "github.com/semafind/semadb/models" "github.com/vmihailenco/msgpack/v5" ) -func pongHandler(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "message": "pong from semadb", - }) -} - // --------------------------- +func handlePing(w http.ResponseWriter, r *http.Request) { + 
utils.Encode(w, http.StatusOK, map[string]string{"message": "pong from semadb"}) +} + type SemaDBHandlers struct { clusterNode *cluster.ClusterNode } // Requires middleware.AppHeaderMiddleware to be used -func SetupV2Handlers(clusterNode *cluster.ClusterNode, rgroup *gin.RouterGroup) { - rgroup.GET("/ping", pongHandler) +func SetupV2Handlers(clusterNode *cluster.ClusterNode) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/ping", handlePing) semaDBHandlers := &SemaDBHandlers{clusterNode: clusterNode} // https://stackoverflow.blog/2020/03/02/best-practices-for-rest-api-design/ - rgroup.POST("/collections", semaDBHandlers.CreateCollection) - rgroup.GET("/collections", semaDBHandlers.ListCollections) - colRoutes := rgroup.Group("/collections/:collectionId", semaDBHandlers.CollectionURIMiddleware()) - colRoutes.GET("", semaDBHandlers.GetCollection) - colRoutes.DELETE("", semaDBHandlers.DeleteCollection) + mux.HandleFunc("GET /collections", semaDBHandlers.HandleListCollections) + mux.HandleFunc("POST /collections", semaDBHandlers.HandleCreateCollection) + // --------------------------- + withCol := func(next http.HandlerFunc) http.Handler { + return semaDBHandlers.CollectionURIMiddleware(http.HandlerFunc(next)) + } + mux.Handle("GET /collections/{collectionId}", withCol(semaDBHandlers.HandleGetCollection)) + mux.Handle("DELETE /collections/{collectionId}", withCol(semaDBHandlers.HandleDeleteCollection)) // We're batching point requests for peformance reasons. Alternatively we // can provide points/:pointId endpoint in the future. 
- colRoutes.POST("/points", semaDBHandlers.InsertPoints) - colRoutes.PUT("/points", semaDBHandlers.UpdatePoints) - colRoutes.DELETE("/points", semaDBHandlers.DeletePoints) - colRoutes.POST("/points/search", semaDBHandlers.SearchPoints) + mux.Handle("POST /collections/{collectionId}/points", withCol(semaDBHandlers.HandleInsertPoints)) + mux.Handle("PUT /collections/{collectionId}/points", withCol(semaDBHandlers.HandleUpdatePoints)) + mux.Handle("DELETE /collections/{collectionId}/points", withCol(semaDBHandlers.HandleDeletePoints)) + mux.Handle("POST /collections/{collectionId}/points/search", withCol(semaDBHandlers.HandleSearchPoints)) + // --------------------------- + return mux } // --------------------------- @@ -52,18 +57,28 @@ type CreateCollectionRequest struct { IndexSchema models.IndexSchema `json:"indexSchema" binding:"required,dive"` } -func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { - var req CreateCollectionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return +func (req CreateCollectionRequest) Validate() error { + if len(req.Id) < 3 || len(req.Id) > 24 { + return fmt.Errorf("id must be between 3 and 24 characters, got %d", len(req.Id)) } - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) - // --------------------------- - if err := req.IndexSchema.Validate(); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + // Check alpanum + for _, r := range req.Id { + if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9')) { + return fmt.Errorf("id must be alphanumeric, got %s", req.Id) + } + } + // Validate index schema + return req.IndexSchema.Validate() +} + +func (sdbh *SemaDBHandlers) HandleCreateCollection(w http.ResponseWriter, r *http.Request) { + req, err := utils.DecodeValid[CreateCollectionRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) 
return } + appHeaders := middleware.GetAppHeaders(r.Context()) + userPlan := middleware.GetUserPlan(r.Context()) // --------------------------- vamanaCollection := models.Collection{ UserId: appHeaders.UserId, @@ -71,22 +86,20 @@ func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { Replicas: 1, Timestamp: time.Now().Unix(), CreatedAt: time.Now().Unix(), - UserPlan: c.MustGet("userPlan").(models.UserPlan), + UserPlan: userPlan, IndexSchema: req.IndexSchema, } log.Debug().Interface("collection", vamanaCollection).Msg("CreateCollection") // --------------------------- - err := sdbh.clusterNode.CreateCollection(vamanaCollection) - switch err { + switch err := sdbh.clusterNode.CreateCollection(vamanaCollection); err { case nil: - c.JSON(http.StatusOK, gin.H{"message": "collection created"}) + utils.Encode(w, http.StatusOK, map[string]string{"message": "collection created"}) case cluster.ErrQuotaReached: - c.JSON(http.StatusForbidden, gin.H{"error": "quota reached"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "quota reached"}) case cluster.ErrExists: - c.JSON(http.StatusConflict, gin.H{"error": "collection exists"}) + utils.Encode(w, http.StatusConflict, map[string]string{"error": "collection exists"}) default: - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) log.Error().Err(err).Str("id", vamanaCollection.Id).Msg("CreateCollection failed") } } @@ -99,13 +112,12 @@ type ListCollectionsResponse struct { Collections []ListCollectionItem `json:"collections"` } -func (sdbh *SemaDBHandlers) ListCollections(c *gin.Context) { - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) +func (sdbh *SemaDBHandlers) HandleListCollections(w http.ResponseWriter, r *http.Request) { + appHeaders := middleware.GetAppHeaders(r.Context()) // --------------------------- collections, err := 
sdbh.clusterNode.ListCollections(appHeaders.UserId) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) log.Error().Err(err).Msg("ListCollections failed") return } @@ -114,33 +126,32 @@ func (sdbh *SemaDBHandlers) ListCollections(c *gin.Context) { colItems[i] = ListCollectionItem{Id: col.Id} } resp := ListCollectionsResponse{Collections: colItems} - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } // --------------------------- -type GetCollectionUri struct { - CollectionId string `uri:"collectionId" binding:"required,alphanum,min=3,max=24"` -} +type contextKey string -func (sdbh *SemaDBHandlers) CollectionURIMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var uri GetCollectionUri - if err := c.ShouldBindUri(&uri); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +const collectionContextKey contextKey = "collection" + +func (sdbh *SemaDBHandlers) CollectionURIMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + collectionId := r.PathValue("collectionId") + if len(collectionId) < 3 || len(collectionId) > 24 { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": "collectionId must be between 3 and 24 characters"}) return } - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) - collection, err := sdbh.clusterNode.GetCollection(appHeaders.UserId, uri.CollectionId) + appHeaders := middleware.GetAppHeaders(r.Context()) + collection, err := sdbh.clusterNode.GetCollection(appHeaders.UserId, collectionId) if err == cluster.ErrNotFound { - errMsg := fmt.Sprintf("collection %s not found", uri.CollectionId) - c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": errMsg}) + errMsg := fmt.Sprintf("collection %s not found", 
collectionId) + utils.Encode(w, http.StatusNotFound, map[string]string{"error": errMsg}) return } if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -148,10 +159,11 @@ func (sdbh *SemaDBHandlers) CollectionURIMiddleware() gin.HandlerFunc { // This is because the user plan might change and we want the latest // active one rather than the one saved in the collection. This means // any downstream operation will use the latest user plan. - collection.UserPlan = c.MustGet("userPlan").(models.UserPlan) + collection.UserPlan = middleware.GetUserPlan(r.Context()) // --------------------------- - c.Set("collection", collection) - } + newReq := r.WithContext(context.WithValue(r.Context(), collectionContextKey, collection)) + next.ServeHTTP(w, newReq) + }) } // --------------------------- @@ -167,18 +179,17 @@ type GetCollectionResponse struct { Shards []ShardItem `json:"shards"` } -func (sdbh *SemaDBHandlers) GetCollection(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleGetCollection(w http.ResponseWriter, r *http.Request) { // --------------------------- - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- shards, err := sdbh.clusterNode.GetShardsInfo(collection) if errors.Is(err, cluster.ErrShardUnavailable) { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "one or more shards are unavailable"}) + utils.Encode(w, http.StatusServiceUnavailable, map[string]string{"error": "one or more shards are unavailable"}) return } if err != nil { - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } // --------------------------- @@ 
-191,26 +202,25 @@ func (sdbh *SemaDBHandlers) GetCollection(c *gin.Context) { IndexSchema: collection.IndexSchema, Shards: shardItems, } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- -func (sdbh *SemaDBHandlers) DeleteCollection(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleDeleteCollection(w http.ResponseWriter, r *http.Request) { // --------------------------- - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- deletedShardIds, err := sdbh.clusterNode.DeleteCollection(collection) if err != nil { - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } status := http.StatusOK if len(deletedShardIds) != len(collection.ShardIds) { status = http.StatusAccepted } - c.JSON(status, gin.H{"message": "collection deleted"}) + utils.Encode(w, status, map[string]string{"message": "collection deleted"}) } // --------------------------- @@ -219,47 +229,54 @@ type InsertPointsRequest struct { Points []models.PointAsMap `json:"points" binding:"required,max=10000"` } +func (req InsertPointsRequest) Validate() error { + if len(req.Points) < 1 || len(req.Points) > 10000 { + return fmt.Errorf("number of points must be between 1 and 10000, got %d", len(req.Points)) + } + return nil +} + type InsertPointsResponse struct { Message string `json:"message"` FailedRanges []cluster.FailedRange `json:"failedRanges"` } -func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleInsertPoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req InsertPointsRequest startTime := time.Now() - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := 
utils.DecodeValid[InsertPointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } log.Debug().Str("bindTime", time.Since(startTime).String()).Msg("InsertPoints bind") // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Convert request points into internal points, doing checks along the way points := make([]models.Point, len(req.Points)) for i, point := range req.Points { if err := collection.IndexSchema.CheckCompatibleMap(point); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } pointId, err := point.ExtractIdField(true) if err != nil { errMsg := fmt.Sprintf("invalid id for point %d, %s", i, err.Error()) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i] = models.Point{Id: pointId} pointData, err := msgpack.Marshal(point) if err != nil { errMsg := fmt.Sprintf("invalid point data for point %d, %s", i, err.Error()) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if len(pointData) > collection.UserPlan.MaxPointSize { errMsg := fmt.Sprintf("point %d exceeds maximum point size %d > %d", i, len(pointData), collection.UserPlan.MaxPointSize) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i].Data = pointData @@ -268,23 +285,22 @@ func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { // Insert points returns a range of errors for failed shards failedRanges, err 
:= sdbh.clusterNode.InsertPoints(collection, points) if errors.Is(err, cluster.ErrQuotaReached) { - c.JSON(http.StatusForbidden, gin.H{"error": "quota reached"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "quota reached"}) return } if errors.Is(err, cluster.ErrShardUnavailable) { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "one or more shards are unavailable"}) + utils.Encode(w, http.StatusServiceUnavailable, map[string]string{"error": "one or more shards are unavailable"}) return } if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := InsertPointsResponse{Message: "success", FailedRanges: failedRanges} if len(failedRanges) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } @@ -294,21 +310,28 @@ type UpdatePointsRequest struct { Points []models.PointAsMap `json:"points" binding:"required,max=100"` } +func (req UpdatePointsRequest) Validate() error { + if len(req.Points) < 1 || len(req.Points) > 100 { + return fmt.Errorf("number of points must be between 1 and 100, got %d", len(req.Points)) + } + return nil +} + type UpdatePointsResponse struct { Message string `json:"message"` FailedPoints []cluster.FailedPoint `json:"failedPoints"` } -func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleUpdatePoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req UpdatePointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[UpdatePointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- // Get 
corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Convert request points into internal points, doing checks along the way points := make([]models.Point, len(req.Points)) @@ -316,23 +339,23 @@ func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { pointId, err := point.ExtractIdField(false) if err != nil { errMsg := fmt.Sprintf("invalid id for point %d, %s", i, err.Error()) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if err := collection.IndexSchema.CheckCompatibleMap(point); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } points[i] = models.Point{Id: pointId} pointData, err := msgpack.Marshal(point) if err != nil { errMsg := fmt.Sprintf("invalid point data for point %d, %s", i, err.Error()) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if len(pointData) > collection.UserPlan.MaxPointSize { errMsg := fmt.Sprintf("point %d exceeds maximum point size %d > %d", i, len(pointData), collection.UserPlan.MaxPointSize) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i].Data = pointData @@ -341,15 +364,14 @@ func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { // Update points returns a list of failed points failedPoints, err := sdbh.clusterNode.UpdatePoints(collection, points) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, 
map[string]string{"error": err.Error()}) return } resp := UpdatePointsResponse{Message: "success", FailedPoints: failedPoints} if len(failedPoints) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- @@ -358,16 +380,28 @@ type DeletePointsRequest struct { Ids []string `json:"ids" binding:"required,max=100,dive,uuid"` } +func (req DeletePointsRequest) Validate() error { + if len(req.Ids) < 1 || len(req.Ids) > 100 { + return fmt.Errorf("number of ids must be between 1 and 100, got %d", len(req.Ids)) + } + for i, id := range req.Ids { + if _, err := uuid.Parse(id); err != nil { + return fmt.Errorf("invalid uuid at index %d", i) + } + } + return nil +} + type DeletePointsResponse struct { Message string `json:"message"` FailedPoints []cluster.FailedPoint `json:"failedPoints"` } -func (sdbh *SemaDBHandlers) DeletePoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleDeletePoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req DeletePointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[DeletePointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -378,19 +412,18 @@ func (sdbh *SemaDBHandlers) DeletePoints(c *gin.Context) { } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- failedPoints, err := sdbh.clusterNode.DeletePoints(collection, pointIds) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": 
err.Error()}) return } resp := DeletePointsResponse{Message: "success", FailedPoints: failedPoints} if len(failedPoints) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- @@ -399,11 +432,11 @@ type SearchPointsResponse struct { Points []models.PointAsMap `json:"points"` } -func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleSearchPoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req models.SearchRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[models.SearchRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // Default limit is 10 @@ -412,17 +445,17 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Validate query against schema, checks vector dimensions, query options etc. 
- if err := req.Query.Validate(collection.IndexSchema); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + if err := req.Query.ValidateSchema(collection.IndexSchema); err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- points, err := sdbh.clusterNode.SearchPoints(collection, req) if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } results := make([]models.PointAsMap, len(points)) @@ -433,7 +466,7 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { if len(sp.Point.Data) > 0 { if err := msgpack.Unmarshal(sp.Point.Data, &pointData); err != nil { errMsg := fmt.Sprintf("could not decode point %s", sp.Point.Id.String()) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": errMsg}) return } } @@ -451,6 +484,6 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { results[i] = pointData } resp := SearchPointsResponse{Points: results} - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- }