diff --git a/httpapi/httpapi.go b/httpapi/httpapi.go index e0a5e92..da0a643 100644 --- a/httpapi/httpapi.go +++ b/httpapi/httpapi.go @@ -4,7 +4,6 @@ import ( "net/http" "strconv" - "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus" "github.com/rs/zerolog/log" "github.com/semafind/semadb/cluster" @@ -34,39 +33,35 @@ type HttpApiConfig struct { // --------------------------- -func setupRouter(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) *gin.Engine { - router := gin.New() +func setupRouter(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) http.Handler { // --------------------------- var metrics *httpMetrics if cfg.EnableMetrics && reg != nil { metrics = setupAndListenMetrics(cfg, reg) } // --------------------------- - router.Use(ZerologLoggerMetrics(metrics), gin.Recovery()) + mux := http.NewServeMux() + mux.Handle("/v1/", http.StripPrefix("/v1", httpv1.SetupV1Handlers(cnode))) + mux.Handle("/v2/", http.StripPrefix("/v2", httpv2.SetupV2Handlers(cnode))) // --------------------------- - if len(cfg.ProxySecret) > 0 { - log.Info().Msg("ProxySecretMiddleware is enabled") - router.Use(ProxySecretMiddleware(cfg.ProxySecret)) - } + var handler http.Handler = mux + handler = middleware.AppHeaderMiddleware(cfg.UserPlans, handler) if cfg.WhiteListIPs == nil || (len(cfg.WhiteListIPs) == 1 && cfg.WhiteListIPs[0] == "*") { log.Warn().Strs("whiteListIPs", cfg.WhiteListIPs).Msg("WhiteListIPMiddleware is disabled") } else { - router.Use(WhiteListIPMiddleware(cfg.WhiteListIPs)) + handler = WhiteListIPMiddleware(cfg.WhiteListIPs, handler) } + if len(cfg.ProxySecret) > 0 { + log.Info().Msg("ProxySecretMiddleware is enabled") + handler = ProxySecretMiddleware(cfg.ProxySecret, handler) + } + handler = ZeroLoggerMetrics(metrics, handler) + handler = RecoverMiddleware(handler) // --------------------------- - v1 := router.Group("/v1", middleware.AppHeaderMiddleware(cfg.UserPlans)) - httpv1.SetupV1Handlers(cnode, v1) - // --------------------------- - v2 := router.Group("/v2", middleware.AppHeaderMiddleware(cfg.UserPlans)) - httpv2.SetupV2Handlers(cnode, v2) - return router + return handler } func RunHTTPServer(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) *http.Server { - // --------------------------- - if !cfg.Debug { - gin.SetMode(gin.ReleaseMode) - } // --------------------------- server := &http.Server{ Addr: cfg.HttpHost + ":" + strconv.Itoa(cfg.HttpPort), diff --git a/httpapi/metrics.go b/httpapi/metrics.go index adce472..ce18a81 100644 --- a/httpapi/metrics.go +++ b/httpapi/metrics.go @@ -14,7 +14,6 @@ type httpMetrics struct { requestCount *prometheus.CounterVec requestDuration *prometheus.HistogramVec requestSize *prometheus.HistogramVec - responseSize *prometheus.HistogramVec // --------------------------- } @@ -44,19 +43,10 @@ func setupAndListenMetrics(cfg HttpApiConfig, reg *prometheus.Registry) *httpMet }, []string{"code", "method", "handler"}, ), - responseSize: prometheus.NewHistogramVec( - prometheus.HistogramOpts{ - Name: "http_response_size_bytes", - Help: "HTTP request sizes in bytes.", - Buckets: []float64{0, 1 << 10, 1 << 15, 1 << 20}, - }, - []string{"code", "method", "handler"}, - ), } reg.MustRegister(metrics.requestCount) reg.MustRegister(metrics.requestDuration) reg.MustRegister(metrics.requestSize) - reg.MustRegister(metrics.responseSize) // --------------------------- mux := http.NewServeMux() mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{Registry: reg})) diff --git a/httpapi/middleware.go b/httpapi/middleware.go index 5ea498d..74b9eaa 100644 --- a/httpapi/middleware.go +++ b/httpapi/middleware.go @@ -2,102 +2,79 @@ package httpapi import ( "net/http" + "regexp" + "runtime/debug" "slices" "strconv" - "strings" "time" - "github.com/gin-gonic/gin" - "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" "github.com/rs/zerolog/log" - "github.com/semafind/semadb/httpapi/middleware" + "github.com/semafind/semadb/httpapi/utils" ) // --------------------------- - -func ZerologLoggerMetrics(metrics *httpMetrics) gin.HandlerFunc { - return func(c *gin.Context) { - // --------------------------- - // Start timer - start := time.Now() - path := c.Request.URL.Path - raw := c.Request.URL.RawQuery - reqSize := c.Request.ContentLength - // --------------------------- - // Process request - c.Next() - // --------------------------- - // Stop timer and gather information - latency := time.Since(start) - - method := c.Request.Method - statusCode := c.Writer.Status() - lastError := c.Errors.ByType(gin.ErrorTypePrivate).Last() - - bodySize := c.Writer.Size() - - if raw != "" { - path = path + "?" + raw - } - // --------------------------- - var logEvent *zerolog.Event - if statusCode == 500 || lastError != nil { - logEvent = log.Error() - } else { - logEvent = log.Info() - } - logEvent.Err(lastError). - Dur("latency", latency). - Str("clientIP", c.ClientIP()). - Str("remoteIP", c.RemoteIP()). - Str("method", method).Str("path", path). - Int("statusCode", statusCode). - Int64("requestSize", reqSize). - Int("bodySize", bodySize). - Str("path", path) - // Extract app headers if any - appH, ok := c.Keys["appHeaders"] - if ok { - appHeaders := appH.(middleware.AppHeaders) - // We are not logging the user ID for privacy reasons - logEvent = logEvent.Str("planId", appHeaders.PlanId) - } - logEvent.Msg("HTTPAPI") - // --------------------------- +// Zerolog based middleware for logging HTTP requests +func ZeroLoggerMetrics(metrics *httpMetrics, next http.Handler) http.Handler { + handler := hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) { + hlog.FromRequest(r).Info(). + Str("method", r.Method). + Stringer("url", r.URL). + Int("status", status). + Int("size", size). + Dur("duration", duration). + Msg("") if metrics != nil { - // Example handler names - // github.com/semafind/semadb/httpapi.(*SemaDBHandlers).ListCollections-fm - // github.com/semafind/semadb/httpapi.(*SemaDBHandlers).CreateCollection-fm - fullHName := c.HandlerName() - parts := strings.Split(fullHName, ".") - hname := parts[len(parts)-1][:len(parts[len(parts)-1])-3] - ssCode := strconv.Itoa(statusCode) - metrics.requestCount.WithLabelValues(ssCode, method, hname).Inc() - metrics.requestDuration.WithLabelValues(ssCode, method, hname).Observe(latency.Seconds()) - metrics.requestSize.WithLabelValues(ssCode, method, hname).Observe(float64(reqSize)) - metrics.responseSize.WithLabelValues(ssCode, method, hname).Observe(float64(bodySize)) + // Canonicalize the URL by removing url parameters + // Replace anything of the form collections/mycol23 with collections/:id + re := regexp.MustCompile(`collections/[a-zA-Z0-9]+`) + canonical := re.ReplaceAll([]byte(r.URL.Path), []byte("collections/{collectionId}")) + hname := string(canonical) + ssCode := strconv.Itoa(status) + metrics.requestCount.WithLabelValues(ssCode, r.Method, hname).Inc() + metrics.requestDuration.WithLabelValues(ssCode, r.Method, hname).Observe(duration.Seconds()) + metrics.requestSize.WithLabelValues(ssCode, r.Method, hname).Observe(float64(size)) + // metrics.responseSize.WithLabelValues(ssCode, r.Method, hname).Observe(float64(bodySize)) } - } + })(next) + handler = hlog.NewHandler(log.Logger)(handler) + return handler } // --------------------------- -func ProxySecretMiddleware(secret string) gin.HandlerFunc { +func ProxySecretMiddleware(secret string, next http.Handler) http.Handler { log.Debug().Str("proxySecret", secret).Msg("ProxySecretMiddleware") - return func(c *gin.Context) { - if c.GetHeader("X-Proxy-Secret") != secret { - c.AbortWithStatusJSON(http.StatusProxyAuthRequired, gin.H{"error": "forbidden"}) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("X-Proxy-Secret") != secret { + utils.Encode(w, http.StatusProxyAuthRequired, map[string]string{"error": "forbidden"}) + return } - } + next.ServeHTTP(w, r) + }) } -func WhiteListIPMiddleware(whitelist []string) gin.HandlerFunc { +func WhiteListIPMiddleware(whitelist []string, next http.Handler) http.Handler { slices.Sort(whitelist) - return func(c *gin.Context) { - remoteIP := c.RemoteIP() - _, found := slices.BinarySearch(whitelist, remoteIP) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, found := slices.BinarySearch(whitelist, r.RemoteAddr) if !found { - c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "forbidden"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "forbidden"}) + return } - } + next.ServeHTTP(w, r) + }) +} + +func RecoverMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + log.Error().Interface("error", err).Msg("panic recovered") + log.Error().Str("stack", string(debug.Stack())).Msg("stack trace") + w.WriteHeader(http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) } diff --git a/httpapi/middleware/appheaders.go b/httpapi/middleware/appheaders.go index 10c2f50..f85cefe 100644 --- a/httpapi/middleware/appheaders.go +++ b/httpapi/middleware/appheaders.go @@ -1,11 +1,12 @@ package middleware import ( + "context" "fmt" "net/http" - "github.com/gin-gonic/gin" "github.com/rs/zerolog/log" + "github.com/semafind/semadb/httpapi/utils" "github.com/semafind/semadb/models" ) @@ -14,23 +15,40 @@ type AppHeaders struct { PlanId string `header:"X-Plan-Id" binding:"required"` } -func AppHeaderMiddleware(userPlans map[string]models.UserPlan) gin.HandlerFunc { - return func(c *gin.Context) { - var appHeaders AppHeaders - if err := c.ShouldBindHeader(&appHeaders); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +type contextKey string + +const appHeadersKey contextKey = "appHeaders" +const userPlanKey contextKey = "userPlan" + +func AppHeaderMiddleware(userPlans map[string]models.UserPlan, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + appHeaders := AppHeaders{ + UserId: r.Header.Get("X-User-Id"), + PlanId: r.Header.Get("X-Plan-Id"), + } + if appHeaders.UserId == "" || appHeaders.PlanId == "" { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": "missing X-User-ID or X-Plan-Id headers"}) return } - c.Set("appHeaders", appHeaders) log.Debug().Interface("appHeaders", appHeaders).Msg("AppHeaderMiddleware") + // --------------------------- + newCtx := context.WithValue(r.Context(), appHeadersKey, appHeaders) // Extract user plan userPlan, ok := userPlans[appHeaders.PlanId] if !ok { errmsg := fmt.Sprintf("unknown user plan %s", appHeaders.PlanId) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errmsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errmsg}) return } - c.Set("userPlan", userPlan) - c.Next() - } + newCtx = context.WithValue(newCtx, userPlanKey, userPlan) + next.ServeHTTP(w, r.WithContext(newCtx)) + }) +} + +func GetAppHeaders(ctx context.Context) AppHeaders { + return ctx.Value(appHeadersKey).(AppHeaders) +} + +func GetUserPlan(ctx context.Context) models.UserPlan { + return ctx.Value(userPlanKey).(models.UserPlan) } diff --git a/httpapi/utils/encdec.go b/httpapi/utils/encdec.go new file mode 100644 index 0000000..c1ff593 --- /dev/null +++ b/httpapi/utils/encdec.go @@ -0,0 +1,50 @@ +package utils + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + + "github.com/rs/zerolog/log" +) + +// Validator is an object that can be validated. +type Validator interface { + Validate() error +} + +// Encode writes the object to the response writer. It is usually used as the +// last step in a handler. +func Encode[T any](w http.ResponseWriter, status int, v T) { + w.Header().Set("Content-Type", "application/json") + // Write to buffer first to ensure the object is json encodable + // before writing to the response writer. + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(v); err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Error().Err(err).Msg("could not encode response") + json.NewEncoder(w).Encode(map[string]string{"error": err.Error()}) + return + } + w.WriteHeader(status) + w.Write(buf.Bytes()) +} + +// DecodeValid decodes the request body into the object and then validates it. +// Look at problems to see if there are any issues. +func DecodeValid[T Validator](r *http.Request) (T, error) { + var v T + if r.Header.Get("Content-Type") != "application/json" { + return v, fmt.Errorf("expected content type application/json, got %s", r.Header.Get("Content-Type")) + } + // --------------------------- + if err := json.NewDecoder(r.Body).Decode(&v); err != nil { + return v, fmt.Errorf("decode json: %w", err) + } + if err := v.Validate(); err != nil { + return v, fmt.Errorf("validation error: %w", err) + } + // --------------------------- + return v, nil +} diff --git a/httpapi/v1/handlers.go b/httpapi/v1/handlers.go index 611d1b7..36513de 100644 --- a/httpapi/v1/handlers.go +++ b/httpapi/v1/handlers.go @@ -1,48 +1,53 @@ package v1 import ( + "context" "errors" "fmt" "net/http" "time" - "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/semafind/semadb/cluster" "github.com/semafind/semadb/httpapi/middleware" + "github.com/semafind/semadb/httpapi/utils" "github.com/semafind/semadb/models" "github.com/vmihailenco/msgpack/v5" ) -func pongHandler(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "message": "pong from semadb", - }) -} - // --------------------------- +func handlePing(w http.ResponseWriter, r *http.Request) { + utils.Encode(w, http.StatusOK, map[string]string{"message": "pong from semadb"}) +} + type SemaDBHandlers struct { clusterNode *cluster.ClusterNode } // Requires middleware.AppHeaderMiddleware to be used -func SetupV1Handlers(clusterNode *cluster.ClusterNode, rgroup *gin.RouterGroup) { - rgroup.GET("/ping", pongHandler) +func SetupV1Handlers(clusterNode *cluster.ClusterNode) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/ping", handlePing) semaDBHandlers := &SemaDBHandlers{clusterNode: clusterNode} // https://stackoverflow.blog/2020/03/02/best-practices-for-rest-api-design/ - rgroup.POST("/collections", semaDBHandlers.CreateCollection) - rgroup.GET("/collections", semaDBHandlers.ListCollections) - colRoutes := rgroup.Group("/collections/:collectionId", semaDBHandlers.CollectionURIMiddleware()) - colRoutes.GET("", semaDBHandlers.GetCollection) - colRoutes.DELETE("", semaDBHandlers.DeleteCollection) + mux.HandleFunc("GET /collections", semaDBHandlers.HandleListCollections) + mux.HandleFunc("POST /collections", semaDBHandlers.HandleCreateCollection) + // --------------------------- + withCol := func(next http.HandlerFunc) http.Handler { + return semaDBHandlers.CollectionURIMiddleware(http.HandlerFunc(next)) + } + mux.Handle("GET /collections/{collectionId}", withCol(semaDBHandlers.HandleGetCollection)) + mux.Handle("DELETE /collections/{collectionId}", withCol(semaDBHandlers.HandleDeleteCollection)) // We're batching point requests for peformance reasons. Alternatively we // can provide points/:pointId endpoint in the future. - colRoutes.POST("/points", semaDBHandlers.InsertPoints) - colRoutes.PUT("/points", semaDBHandlers.UpdatePoints) - colRoutes.DELETE("/points", semaDBHandlers.DeletePoints) - colRoutes.POST("/points/search", semaDBHandlers.SearchPoints) + mux.Handle("POST /collections/{collectionId}/points", withCol(semaDBHandlers.HandleInsertPoints)) + mux.Handle("PUT /collections/{collectionId}/points", withCol(semaDBHandlers.HandleUpdatePoints)) + mux.Handle("DELETE /collections/{collectionId}/points", withCol(semaDBHandlers.HandleDeletePoints)) + mux.Handle("POST /collections/{collectionId}/points/search", withCol(semaDBHandlers.HandleSearchPoints)) + // --------------------------- + return mux } // --------------------------- @@ -53,13 +58,33 @@ type CreateCollectionRequest struct { DistanceMetric string `json:"distanceMetric" binding:"required,oneof=euclidean cosine dot"` } -func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { - var req CreateCollectionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +func (req CreateCollectionRequest) Validate() error { + if len(req.Id) < 3 || len(req.Id) > 16 { + return fmt.Errorf("id must be between 3 and 16 characters") + } + // Check alphanum + for _, r := range req.Id { + if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) { + return fmt.Errorf("id must be alphanumeric") + } + } + if req.VectorSize < 1 || req.VectorSize > 4096 { + return fmt.Errorf("vectorSize must be between 1 and 4096, got %d", req.VectorSize) + } + if req.DistanceMetric != models.DistanceEuclidean && req.DistanceMetric != models.DistanceCosine && req.DistanceMetric != models.DistanceDot { + return fmt.Errorf("distanceMetric must be one of euclidean, cosine, dot, got %s", req.DistanceMetric) + } + return nil +} + +func (sdbh *SemaDBHandlers) HandleCreateCollection(w http.ResponseWriter, r *http.Request) { + req, err := utils.DecodeValid[CreateCollectionRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) + appHeaders := middleware.GetAppHeaders(r.Context()) + userPlan := middleware.GetUserPlan(r.Context()) // --------------------------- vamanaCollection := models.Collection{ UserId: appHeaders.UserId, @@ -67,7 +92,7 @@ func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { Replicas: 1, Timestamp: time.Now().Unix(), CreatedAt: time.Now().Unix(), - UserPlan: c.MustGet("userPlan").(models.UserPlan), + UserPlan: userPlan, IndexSchema: models.IndexSchema{ "vector": { Type: "vectorVamana", @@ -84,17 +109,15 @@ func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { } log.Debug().Interface("collection", vamanaCollection).Msg("CreateCollection") // --------------------------- - err := sdbh.clusterNode.CreateCollection(vamanaCollection) - switch err { + switch err := sdbh.clusterNode.CreateCollection(vamanaCollection); err { case nil: - c.JSON(http.StatusOK, gin.H{"message": "collection created"}) + utils.Encode(w, http.StatusOK, map[string]string{"message": "collection created"}) case cluster.ErrQuotaReached: - c.JSON(http.StatusForbidden, gin.H{"error": "quota reached"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "quota reached"}) case cluster.ErrExists: - c.JSON(http.StatusConflict, gin.H{"error": "collection exists"}) + utils.Encode(w, http.StatusConflict, map[string]string{"error": "collection exists"}) default: - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) log.Error().Err(err).Str("id", vamanaCollection.Id).Msg("CreateCollection failed") } } @@ -109,13 +132,12 @@ type ListCollectionsResponse struct { Collections []ListCollectionItem `json:"collections"` } -func (sdbh *SemaDBHandlers) ListCollections(c *gin.Context) { - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) +func (sdbh *SemaDBHandlers) HandleListCollections(w http.ResponseWriter, r *http.Request) { + appHeaders := middleware.GetAppHeaders(r.Context()) // --------------------------- collections, err := sdbh.clusterNode.ListCollections(appHeaders.UserId) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) log.Error().Err(err).Msg("ListCollections failed") return } @@ -124,33 +146,33 @@ func (sdbh *SemaDBHandlers) ListCollections(c *gin.Context) { colItems[i] = ListCollectionItem{Id: col.Id, VectorSize: col.IndexSchema["vector"].VectorVamana.VectorSize, DistanceMetric: col.IndexSchema["vector"].VectorVamana.DistanceMetric} } resp := ListCollectionsResponse{Collections: colItems} - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } // --------------------------- -type GetCollectionUri struct { - CollectionId string `uri:"collectionId" binding:"required,alphanum,min=3,max=16"` -} +type contextKey string + +const collectionContextKey contextKey = "collection" -func (sdbh *SemaDBHandlers) CollectionURIMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var uri GetCollectionUri - if err := c.ShouldBindUri(&uri); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +// Extracts collectionId from the URI and fetches the collection from the cluster. +func (sdbh *SemaDBHandlers) CollectionURIMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + collectionId := r.PathValue("collectionId") + if len(collectionId) < 3 || len(collectionId) > 16 { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": "collectionId must be between 3 and 16 characters"}) return } - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) - collection, err := sdbh.clusterNode.GetCollection(appHeaders.UserId, uri.CollectionId) + appHeaders := middleware.GetAppHeaders(r.Context()) + collection, err := sdbh.clusterNode.GetCollection(appHeaders.UserId, collectionId) if err == cluster.ErrNotFound { - errMsg := fmt.Sprintf("collection %s not found", uri.CollectionId) - c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": errMsg}) + errMsg := fmt.Sprintf("collection %s not found", collectionId) + utils.Encode(w, http.StatusNotFound, map[string]string{"error": errMsg}) return } if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -158,10 +180,11 @@ func (sdbh *SemaDBHandlers) CollectionURIMiddleware() gin.HandlerFunc { // This is because the user plan might change and we want the latest // active one rather than the one saved in the collection. This means // any downstream operation will use the latest user plan. - collection.UserPlan = c.MustGet("userPlan").(models.UserPlan) + collection.UserPlan = middleware.GetUserPlan(r.Context()) // --------------------------- - c.Set("collection", collection) - } + newCtx := context.WithValue(r.Context(), collectionContextKey, collection) + next.ServeHTTP(w, r.WithContext(newCtx)) + }) } // --------------------------- @@ -178,18 +201,17 @@ type GetCollectionResponse struct { Shards []ShardItem `json:"shards"` } -func (sdbh *SemaDBHandlers) GetCollection(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleGetCollection(w http.ResponseWriter, r *http.Request) { // --------------------------- - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- shards, err := sdbh.clusterNode.GetShardsInfo(collection) if errors.Is(err, cluster.ErrShardUnavailable) { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "one or more shards are unavailable"}) + utils.Encode(w, http.StatusServiceUnavailable, map[string]string{"error": "one or more shards are unavailable"}) return } if err != nil { - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -203,26 +225,25 @@ func (sdbh *SemaDBHandlers) GetCollection(c *gin.Context) { DistanceMetric: collection.IndexSchema["vector"].VectorVamana.DistanceMetric, Shards: shardItems, } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- -func (sdbh *SemaDBHandlers) DeleteCollection(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleDeleteCollection(w http.ResponseWriter, r *http.Request) { // --------------------------- - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- deletedShardIds, err := sdbh.clusterNode.DeleteCollection(collection) if err != nil { - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } status := http.StatusOK if len(deletedShardIds) != len(collection.ShardIds) { status = http.StatusAccepted } - c.JSON(status, gin.H{"message": "collection deleted"}) + utils.Encode(w, status, map[string]string{"message": "collection deleted"}) } // --------------------------- @@ -233,34 +254,58 @@ type InsertSinglePointRequest struct { Metadata any `json:"metadata"` } +func (req InsertSinglePointRequest) Validate() error { + if len(req.Id) > 0 { + if _, err := uuid.Parse(req.Id); err != nil { + return fmt.Errorf("id must be a valid uuid") + } + } + if len(req.Vector) < 1 || len(req.Vector) > 2000 { + return fmt.Errorf("vector size must be between 1 and 2000, got %d", len(req.Vector)) + } + return nil +} + type InsertPointsRequest struct { Points []InsertSinglePointRequest `json:"points" binding:"required,max=10000,dive"` } +func (req InsertPointsRequest) Validate() error { + if len(req.Points) < 1 || len(req.Points) > 10000 { + return fmt.Errorf("points size must be between 1 and 10000, got %d", len(req.Points)) + } + for i, point := range req.Points { + if err := point.Validate(); err != nil { + return fmt.Errorf("point at index %d: %w", i, err) + } + } + return nil +} + type InsertPointsResponse struct { Message string `json:"message"` FailedRanges []cluster.FailedRange `json:"failedRanges"` } -func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleInsertPoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req InsertPointsRequest startTime := time.Now() - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[InsertPointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } log.Debug().Str("bindTime", time.Since(startTime).String()).Msg("InsertPoints bind") // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Convert request points into internal points, doing checks along the way points := make([]models.Point, len(req.Points)) for i, point := range req.Points { if len(point.Vector) != int(collection.IndexSchema["vector"].VectorVamana.VectorSize) { errMsg := fmt.Sprintf("invalid vector dimension, expected %d got %d for point at index %d", collection.IndexSchema["vector"].VectorVamana.VectorSize, len(point.Vector), i) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } pointId := uuid.New() @@ -271,12 +316,12 @@ func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { binaryPointData, err := msgpack.Marshal(pointData) if err != nil { errMsg := fmt.Sprintf("failed to JSON encode point at index %d, please ensure all fields are JSON compatible", i) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if len(binaryPointData) > collection.UserPlan.MaxPointSize { errMsg := fmt.Sprintf("point %d exceeds maximum point size %d > %d", i, len(binaryPointData), collection.UserPlan.MaxPointSize) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i] = models.Point{ @@ -288,23 +333,22 @@ func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { // Insert points returns a range of errors for failed shards failedRanges, err := sdbh.clusterNode.InsertPoints(collection, points) if errors.Is(err, cluster.ErrQuotaReached) { - c.JSON(http.StatusForbidden, gin.H{"error": "quota reached"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "quota reached"}) return } if errors.Is(err, cluster.ErrShardUnavailable) { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "one or more shards are unavailable"}) + utils.Encode(w, http.StatusServiceUnavailable, map[string]string{"error": "one or more shards are unavailable"}) return } if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := InsertPointsResponse{Message: "success", FailedRanges: failedRanges} if len(failedRanges) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } @@ -316,32 +360,54 @@ type UpdateSinglePointRequest struct { Metadata any `json:"metadata"` } +func (req UpdateSinglePointRequest) Validate() error { + if _, err := uuid.Parse(req.Id); err != nil { + return fmt.Errorf("id must be a valid uuid, got %s", req.Id) + } + if len(req.Vector) < 1 || len(req.Vector) > 2000 { + return fmt.Errorf("vector size must be between 1 and 2000, got %d", len(req.Vector)) + } + return nil +} + type UpdatePointsRequest struct { Points []UpdateSinglePointRequest `json:"points" binding:"required,max=100,dive"` } +func (req UpdatePointsRequest) Validate() error { + if len(req.Points) < 1 || len(req.Points) > 100 { + return fmt.Errorf("points size must be between 1 and 100, got %d", len(req.Points)) + } + for i, point := range req.Points { + if err := point.Validate(); err != nil { + return fmt.Errorf("point at index %d: %w", i, err) + } + } + return nil +} + type UpdatePointsResponse struct { Message string `json:"message"` FailedPoints []cluster.FailedPoint `json:"failedPoints"` } -func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleUpdatePoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req UpdatePointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[UpdatePointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Convert request points into internal points, doing checks along the way points := make([]models.Point, len(req.Points)) for i, point := range req.Points { if len(point.Vector) != int(collection.IndexSchema["vector"].VectorVamana.VectorSize) { errMsg := fmt.Sprintf("invalid vector dimension, expected %d got %d for point at index %d", collection.IndexSchema["vector"].VectorVamana.VectorSize, len(point.Vector), i) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i] = models.Point{ @@ -351,12 +417,12 @@ func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { binaryPointData, err := msgpack.Marshal(pointData) if err != nil { errMsg := fmt.Sprintf("failed to JSON encode %d, please ensure all fields are JSON compatible", i) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if len(binaryPointData) > collection.UserPlan.MaxPointSize { errMsg := fmt.Sprintf("point %d exceeds maximum point size %d > %d", i, len(binaryPointData), collection.UserPlan.MaxPointSize) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i].Data = binaryPointData @@ -365,15 +431,14 @@ func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { // Update points returns a list of failed points failedPoints, err := sdbh.clusterNode.UpdatePoints(collection, points) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := UpdatePointsResponse{Message: "success", FailedPoints: failedPoints} if len(failedPoints) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- @@ -382,16 +447,28 @@ type DeletePointsRequest struct { Ids []string `json:"ids" binding:"required,max=100,dive,uuid"` } +func (req DeletePointsRequest) Validate() error { + if len(req.Ids) < 1 || len(req.Ids) > 100 { + return fmt.Errorf("ids size must be between 1 and 100, got %d", len(req.Ids)) + } + for i, id := range req.Ids { + if _, err := uuid.Parse(id); err != nil { + return fmt.Errorf("id at index %d must be a valid uuid", i) + } + } + return nil +} + type DeletePointsResponse struct { Message string `json:"message"` FailedPoints []cluster.FailedPoint `json:"failedPoints"` } -func (sdbh *SemaDBHandlers) DeletePoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleDeletePoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req DeletePointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[DeletePointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -402,19 +479,18 @@ func (sdbh *SemaDBHandlers) DeletePoints(c *gin.Context) { } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- failedPoints, err := sdbh.clusterNode.DeletePoints(collection, pointIds) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := DeletePointsResponse{Message: "success", FailedPoints: failedPoints} if len(failedPoints) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- @@ -424,6 +500,16 @@ type SearchPointsRequest struct { Limit int `json:"limit" binding:"min=0,max=75"` } +func (req SearchPointsRequest) Validate() error { + if len(req.Vector) < 1 || len(req.Vector) > 2000 { + return fmt.Errorf("vector size must be between 1 and 2000") + } + if req.Limit < 0 || req.Limit > 75 { + return fmt.Errorf("limit must be between 0 and 75") + } + return nil +} + type SearchPointResult struct { Id string `json:"id"` Distance float32 `json:"distance"` @@ -434,11 +520,11 @@ type SearchPointsResponse struct { Points []SearchPointResult `json:"points"` } -func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleSearchPoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req SearchPointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[SearchPointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // Default limit is 10 @@ -447,12 +533,12 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Check vector dimension if len(req.Vector) != int(collection.IndexSchema["vector"].VectorVamana.VectorSize) { errMsg := fmt.Sprintf("invalid vector dimension, expected %d got %d", collection.IndexSchema["vector"].VectorVamana.VectorSize, len(req.Vector)) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } // --------------------------- @@ -471,7 +557,7 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { } points, err := sdbh.clusterNode.SearchPoints(collection, sr) if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } results := make([]SearchPointResult, len(points)) @@ -488,6 +574,6 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { } } resp := SearchPointsResponse{Points: results} - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } diff --git a/httpapi/v2/handlers.go b/httpapi/v2/handlers.go index d122135..f0cfdfa 100644 --- a/httpapi/v2/handlers.go +++ b/httpapi/v2/handlers.go @@ -1,48 +1,53 @@ package v2 import ( + "context" "errors" "fmt" "net/http" "time" - "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/rs/zerolog/log" "github.com/semafind/semadb/cluster" "github.com/semafind/semadb/httpapi/middleware" + "github.com/semafind/semadb/httpapi/utils" "github.com/semafind/semadb/models" "github.com/vmihailenco/msgpack/v5" ) -func pongHandler(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "message": "pong from semadb", - }) -} - // --------------------------- +func handlePing(w http.ResponseWriter, r *http.Request) { + utils.Encode(w, http.StatusOK, map[string]string{"message": "pong from semadb"}) +} + type SemaDBHandlers struct { clusterNode *cluster.ClusterNode } // Requires middleware.AppHeaderMiddleware to be used -func SetupV2Handlers(clusterNode *cluster.ClusterNode, rgroup *gin.RouterGroup) { - rgroup.GET("/ping", pongHandler) +func SetupV2Handlers(clusterNode *cluster.ClusterNode) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/ping", handlePing) semaDBHandlers := &SemaDBHandlers{clusterNode: clusterNode} // https://stackoverflow.blog/2020/03/02/best-practices-for-rest-api-design/ - rgroup.POST("/collections", semaDBHandlers.CreateCollection) - rgroup.GET("/collections", semaDBHandlers.ListCollections) - colRoutes := rgroup.Group("/collections/:collectionId", semaDBHandlers.CollectionURIMiddleware()) - colRoutes.GET("", semaDBHandlers.GetCollection) - colRoutes.DELETE("", semaDBHandlers.DeleteCollection) + mux.HandleFunc("GET /collections", semaDBHandlers.HandleListCollections) + mux.HandleFunc("POST /collections", semaDBHandlers.HandleCreateCollection) + // --------------------------- + withCol := func(next http.HandlerFunc) http.Handler { + return semaDBHandlers.CollectionURIMiddleware(http.HandlerFunc(next)) + } + mux.Handle("GET /collections/{collectionId}", withCol(semaDBHandlers.HandleGetCollection)) + mux.Handle("DELETE /collections/{collectionId}", withCol(semaDBHandlers.HandleDeleteCollection)) // We're batching point requests for peformance reasons. Alternatively we // can provide points/:pointId endpoint in the future. - colRoutes.POST("/points", semaDBHandlers.InsertPoints) - colRoutes.PUT("/points", semaDBHandlers.UpdatePoints) - colRoutes.DELETE("/points", semaDBHandlers.DeletePoints) - colRoutes.POST("/points/search", semaDBHandlers.SearchPoints) + mux.Handle("POST /collections/{collectionId}/points", withCol(semaDBHandlers.HandleInsertPoints)) + mux.Handle("PUT /collections/{collectionId}/points", withCol(semaDBHandlers.HandleUpdatePoints)) + mux.Handle("DELETE /collections/{collectionId}/points", withCol(semaDBHandlers.HandleDeletePoints)) + mux.Handle("POST /collections/{collectionId}/points/search", withCol(semaDBHandlers.HandleSearchPoints)) + // --------------------------- + return mux } // --------------------------- @@ -52,18 +57,28 @@ type CreateCollectionRequest struct { IndexSchema models.IndexSchema `json:"indexSchema" binding:"required,dive"` } -func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { - var req CreateCollectionRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return +func (req CreateCollectionRequest) Validate() error { + if len(req.Id) < 3 || len(req.Id) > 24 { + return fmt.Errorf("id must be between 3 and 24 characters, got %d", len(req.Id)) } - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) - // --------------------------- - if err := req.IndexSchema.Validate(); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + // Check alpanum + for _, r := range req.Id { + if !((r >= 'a' && r <= 'z') || (r >= '0' && r <= '9')) { + return fmt.Errorf("id must be alphanumeric, got %s", req.Id) + } + } + // Validate index schema + return req.IndexSchema.Validate() +} + +func (sdbh *SemaDBHandlers) HandleCreateCollection(w http.ResponseWriter, r *http.Request) { + req, err := utils.DecodeValid[CreateCollectionRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } + appHeaders := middleware.GetAppHeaders(r.Context()) + userPlan := middleware.GetUserPlan(r.Context()) // --------------------------- vamanaCollection := models.Collection{ UserId: appHeaders.UserId, @@ -71,22 +86,20 @@ func (sdbh *SemaDBHandlers) CreateCollection(c *gin.Context) { Replicas: 1, Timestamp: time.Now().Unix(), CreatedAt: time.Now().Unix(), - UserPlan: c.MustGet("userPlan").(models.UserPlan), + UserPlan: userPlan, IndexSchema: req.IndexSchema, } log.Debug().Interface("collection", vamanaCollection).Msg("CreateCollection") // --------------------------- - err := sdbh.clusterNode.CreateCollection(vamanaCollection) - switch err { + switch err := sdbh.clusterNode.CreateCollection(vamanaCollection); err { case nil: - c.JSON(http.StatusOK, gin.H{"message": "collection created"}) + utils.Encode(w, http.StatusOK, map[string]string{"message": "collection created"}) case cluster.ErrQuotaReached: - c.JSON(http.StatusForbidden, gin.H{"error": "quota reached"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "quota reached"}) case cluster.ErrExists: - c.JSON(http.StatusConflict, gin.H{"error": "collection exists"}) + utils.Encode(w, http.StatusConflict, map[string]string{"error": "collection exists"}) default: - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) log.Error().Err(err).Str("id", vamanaCollection.Id).Msg("CreateCollection failed") } } @@ -99,13 +112,12 @@ type ListCollectionsResponse struct { Collections []ListCollectionItem `json:"collections"` } -func (sdbh *SemaDBHandlers) ListCollections(c *gin.Context) { - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) +func (sdbh *SemaDBHandlers) HandleListCollections(w http.ResponseWriter, r *http.Request) { + appHeaders := middleware.GetAppHeaders(r.Context()) // --------------------------- collections, err := sdbh.clusterNode.ListCollections(appHeaders.UserId) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) log.Error().Err(err).Msg("ListCollections failed") return } @@ -114,33 +126,32 @@ func (sdbh *SemaDBHandlers) ListCollections(c *gin.Context) { colItems[i] = ListCollectionItem{Id: col.Id} } resp := ListCollectionsResponse{Collections: colItems} - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } // --------------------------- -type GetCollectionUri struct { - CollectionId string `uri:"collectionId" binding:"required,alphanum,min=3,max=24"` -} +type contextKey string -func (sdbh *SemaDBHandlers) CollectionURIMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - var uri GetCollectionUri - if err := c.ShouldBindUri(&uri); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) +const collectionContextKey contextKey = "collection" + +func (sdbh *SemaDBHandlers) CollectionURIMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + collectionId := r.PathValue("collectionId") + if len(collectionId) < 3 || len(collectionId) > 24 { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": "collectionId must be between 3 and 24 characters"}) return } - appHeaders := c.MustGet("appHeaders").(middleware.AppHeaders) - collection, err := sdbh.clusterNode.GetCollection(appHeaders.UserId, uri.CollectionId) + appHeaders := middleware.GetAppHeaders(r.Context()) + collection, err := sdbh.clusterNode.GetCollection(appHeaders.UserId, collectionId) if err == cluster.ErrNotFound { - errMsg := fmt.Sprintf("collection %s not found", uri.CollectionId) - c.AbortWithStatusJSON(http.StatusNotFound, gin.H{"error": errMsg}) + errMsg := fmt.Sprintf("collection %s not found", collectionId) + utils.Encode(w, http.StatusNotFound, map[string]string{"error": errMsg}) return } if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -148,10 +159,11 @@ func (sdbh *SemaDBHandlers) CollectionURIMiddleware() gin.HandlerFunc { // This is because the user plan might change and we want the latest // active one rather than the one saved in the collection. This means // any downstream operation will use the latest user plan. - collection.UserPlan = c.MustGet("userPlan").(models.UserPlan) + collection.UserPlan = middleware.GetUserPlan(r.Context()) // --------------------------- - c.Set("collection", collection) - } + newReq := r.WithContext(context.WithValue(r.Context(), collectionContextKey, collection)) + next.ServeHTTP(w, newReq) + }) } // --------------------------- @@ -167,18 +179,17 @@ type GetCollectionResponse struct { Shards []ShardItem `json:"shards"` } -func (sdbh *SemaDBHandlers) GetCollection(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleGetCollection(w http.ResponseWriter, r *http.Request) { // --------------------------- - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- shards, err := sdbh.clusterNode.GetShardsInfo(collection) if errors.Is(err, cluster.ErrShardUnavailable) { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "one or more shards are unavailable"}) + utils.Encode(w, http.StatusServiceUnavailable, map[string]string{"error": "one or more shards are unavailable"}) return } if err != nil { - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -191,26 +202,25 @@ func (sdbh *SemaDBHandlers) GetCollection(c *gin.Context) { IndexSchema: collection.IndexSchema, Shards: shardItems, } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- -func (sdbh *SemaDBHandlers) DeleteCollection(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleDeleteCollection(w http.ResponseWriter, r *http.Request) { // --------------------------- - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- deletedShardIds, err := sdbh.clusterNode.DeleteCollection(collection) if err != nil { - c.Error(err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } status := http.StatusOK if len(deletedShardIds) != len(collection.ShardIds) { status = http.StatusAccepted } - c.JSON(status, gin.H{"message": "collection deleted"}) + utils.Encode(w, status, map[string]string{"message": "collection deleted"}) } // --------------------------- @@ -219,47 +229,54 @@ type InsertPointsRequest struct { Points []models.PointAsMap `json:"points" binding:"required,max=10000"` } +func (req InsertPointsRequest) Validate() error { + if len(req.Points) < 1 || len(req.Points) > 10000 { + return fmt.Errorf("number of points must be between 1 and 10000, got %d", len(req.Points)) + } + return nil +} + type InsertPointsResponse struct { Message string `json:"message"` FailedRanges []cluster.FailedRange `json:"failedRanges"` } -func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleInsertPoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req InsertPointsRequest startTime := time.Now() - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[InsertPointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } log.Debug().Str("bindTime", time.Since(startTime).String()).Msg("InsertPoints bind") // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Convert request points into internal points, doing checks along the way points := make([]models.Point, len(req.Points)) for i, point := range req.Points { if err := collection.IndexSchema.CheckCompatibleMap(point); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } pointId, err := point.ExtractIdField(true) if err != nil { errMsg := fmt.Sprintf("invalid id for point %d, %s", i, err.Error()) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i] = models.Point{Id: pointId} pointData, err := msgpack.Marshal(point) if err != nil { errMsg := fmt.Sprintf("invalid point data for point %d, %s", i, err.Error()) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if len(pointData) > collection.UserPlan.MaxPointSize { errMsg := fmt.Sprintf("point %d exceeds maximum point size %d > %d", i, len(pointData), collection.UserPlan.MaxPointSize) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i].Data = pointData @@ -268,23 +285,22 @@ func (sdbh *SemaDBHandlers) InsertPoints(c *gin.Context) { // Insert points returns a range of errors for failed shards failedRanges, err := sdbh.clusterNode.InsertPoints(collection, points) if errors.Is(err, cluster.ErrQuotaReached) { - c.JSON(http.StatusForbidden, gin.H{"error": "quota reached"}) + utils.Encode(w, http.StatusForbidden, map[string]string{"error": "quota reached"}) return } if errors.Is(err, cluster.ErrShardUnavailable) { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "one or more shards are unavailable"}) + utils.Encode(w, http.StatusServiceUnavailable, map[string]string{"error": "one or more shards are unavailable"}) return } if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := InsertPointsResponse{Message: "success", FailedRanges: failedRanges} if len(failedRanges) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- } @@ -294,21 +310,28 @@ type UpdatePointsRequest struct { Points []models.PointAsMap `json:"points" binding:"required,max=100"` } +func (req UpdatePointsRequest) Validate() error { + if len(req.Points) < 1 || len(req.Points) > 100 { + return fmt.Errorf("number of points must be between 1 and 100, got %d", len(req.Points)) + } + return nil +} + type UpdatePointsResponse struct { Message string `json:"message"` FailedPoints []cluster.FailedPoint `json:"failedPoints"` } -func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleUpdatePoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req UpdatePointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[UpdatePointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Convert request points into internal points, doing checks along the way points := make([]models.Point, len(req.Points)) @@ -316,23 +339,23 @@ func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { pointId, err := point.ExtractIdField(false) if err != nil { errMsg := fmt.Sprintf("invalid id for point %d, %s", i, err.Error()) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if err := collection.IndexSchema.CheckCompatibleMap(point); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } points[i] = models.Point{Id: pointId} pointData, err := msgpack.Marshal(point) if err != nil { errMsg := fmt.Sprintf("invalid point data for point %d, %s", i, err.Error()) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } if len(pointData) > collection.UserPlan.MaxPointSize { errMsg := fmt.Sprintf("point %d exceeds maximum point size %d > %d", i, len(pointData), collection.UserPlan.MaxPointSize) - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errMsg}) return } points[i].Data = pointData @@ -341,15 +364,14 @@ func (sdbh *SemaDBHandlers) UpdatePoints(c *gin.Context) { // Update points returns a list of failed points failedPoints, err := sdbh.clusterNode.UpdatePoints(collection, points) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := UpdatePointsResponse{Message: "success", FailedPoints: failedPoints} if len(failedPoints) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- @@ -358,16 +380,28 @@ type DeletePointsRequest struct { Ids []string `json:"ids" binding:"required,max=100,dive,uuid"` } +func (req DeletePointsRequest) Validate() error { + if len(req.Ids) < 1 || len(req.Ids) > 100 { + return fmt.Errorf("number of ids must be between 1 and 100, got %d", len(req.Ids)) + } + for i, id := range req.Ids { + if _, err := uuid.Parse(id); err != nil { + return fmt.Errorf("invalid uuid at index %d", i) + } + } + return nil +} + type DeletePointsResponse struct { Message string `json:"message"` FailedPoints []cluster.FailedPoint `json:"failedPoints"` } -func (sdbh *SemaDBHandlers) DeletePoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleDeletePoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req DeletePointsRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[DeletePointsRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- @@ -378,19 +412,18 @@ func (sdbh *SemaDBHandlers) DeletePoints(c *gin.Context) { } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- failedPoints, err := sdbh.clusterNode.DeletePoints(collection, pointIds) if err != nil { - c.Error(err) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } resp := DeletePointsResponse{Message: "success", FailedPoints: failedPoints} if len(failedPoints) > 0 { resp.Message = "partial success" } - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) } // --------------------------- @@ -399,11 +432,11 @@ type SearchPointsResponse struct { Points []models.PointAsMap `json:"points"` } -func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { +func (sdbh *SemaDBHandlers) HandleSearchPoints(w http.ResponseWriter, r *http.Request) { // --------------------------- - var req models.SearchRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + req, err := utils.DecodeValid[models.SearchRequest](r) + if err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // Default limit is 10 @@ -412,17 +445,17 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { } // --------------------------- // Get corresponding collection - collection := c.MustGet("collection").(models.Collection) + collection := r.Context().Value(collectionContextKey).(models.Collection) // --------------------------- // Validate query against schema, checks vector dimensions, query options etc. - if err := req.Query.Validate(collection.IndexSchema); err != nil { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + if err := req.Query.ValidateSchema(collection.IndexSchema); err != nil { + utils.Encode(w, http.StatusBadRequest, map[string]string{"error": err.Error()}) return } // --------------------------- points, err := sdbh.clusterNode.SearchPoints(collection, req) if err != nil { - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } results := make([]models.PointAsMap, len(points)) @@ -433,7 +466,7 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { if len(sp.Point.Data) > 0 { if err := msgpack.Unmarshal(sp.Point.Data, &pointData); err != nil { errMsg := fmt.Sprintf("could not decode point %s", sp.Point.Id.String()) - c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": errMsg}) + utils.Encode(w, http.StatusInternalServerError, map[string]string{"error": errMsg}) return } } @@ -451,6 +484,6 @@ func (sdbh *SemaDBHandlers) SearchPoints(c *gin.Context) { results[i] = pointData } resp := SearchPointsResponse{Points: results} - c.JSON(http.StatusOK, resp) + utils.Encode(w, http.StatusOK, resp) // --------------------------- }