Skip to content

Commit

Permalink
Remove gin framework from httpapi for more control
Browse files Browse the repository at this point in the history
The Gin framework is good but restricts validation, JSON inputs and middleware control. By going standard library, we will have more validation control, non-JSON input and standard library compatability.
  • Loading branch information
nuric authored Nov 30, 2024
1 parent 46b0046 commit 78e609b
Show file tree
Hide file tree
Showing 7 changed files with 482 additions and 333 deletions.
33 changes: 14 additions & 19 deletions httpapi/httpapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"net/http"
"strconv"

"github.com/gin-gonic/gin"
"github.com/prometheus/client_golang/prometheus"
"github.com/rs/zerolog/log"
"github.com/semafind/semadb/cluster"
Expand Down Expand Up @@ -34,39 +33,35 @@ type HttpApiConfig struct {

// ---------------------------

func setupRouter(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) *gin.Engine {
router := gin.New()
func setupRouter(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) http.Handler {
// ---------------------------
var metrics *httpMetrics
if cfg.EnableMetrics && reg != nil {
metrics = setupAndListenMetrics(cfg, reg)
}
// ---------------------------
router.Use(ZerologLoggerMetrics(metrics), gin.Recovery())
mux := http.NewServeMux()
mux.Handle("/v1/", http.StripPrefix("/v1", httpv1.SetupV1Handlers(cnode)))
mux.Handle("/v2/", http.StripPrefix("/v2", httpv2.SetupV2Handlers(cnode)))
// ---------------------------
if len(cfg.ProxySecret) > 0 {
log.Info().Msg("ProxySecretMiddleware is enabled")
router.Use(ProxySecretMiddleware(cfg.ProxySecret))
}
var handler http.Handler = mux
handler = middleware.AppHeaderMiddleware(cfg.UserPlans, handler)
if cfg.WhiteListIPs == nil || (len(cfg.WhiteListIPs) == 1 && cfg.WhiteListIPs[0] == "*") {
log.Warn().Strs("whiteListIPs", cfg.WhiteListIPs).Msg("WhiteListIPMiddleware is disabled")
} else {
router.Use(WhiteListIPMiddleware(cfg.WhiteListIPs))
handler = WhiteListIPMiddleware(cfg.WhiteListIPs, handler)
}
if len(cfg.ProxySecret) > 0 {
log.Info().Msg("ProxySecretMiddleware is enabled")
handler = ProxySecretMiddleware(cfg.ProxySecret, handler)
}
handler = ZeroLoggerMetrics(metrics, handler)
handler = RecoverMiddleware(handler)
// ---------------------------
v1 := router.Group("/v1", middleware.AppHeaderMiddleware(cfg.UserPlans))
httpv1.SetupV1Handlers(cnode, v1)
// ---------------------------
v2 := router.Group("/v2", middleware.AppHeaderMiddleware(cfg.UserPlans))
httpv2.SetupV2Handlers(cnode, v2)
return router
return handler
}

func RunHTTPServer(cnode *cluster.ClusterNode, cfg HttpApiConfig, reg *prometheus.Registry) *http.Server {
// ---------------------------
if !cfg.Debug {
gin.SetMode(gin.ReleaseMode)
}
// ---------------------------
server := &http.Server{
Addr: cfg.HttpHost + ":" + strconv.Itoa(cfg.HttpPort),
Expand Down
10 changes: 0 additions & 10 deletions httpapi/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ type httpMetrics struct {
requestCount *prometheus.CounterVec
requestDuration *prometheus.HistogramVec
requestSize *prometheus.HistogramVec
responseSize *prometheus.HistogramVec
// ---------------------------
}

Expand Down Expand Up @@ -44,19 +43,10 @@ func setupAndListenMetrics(cfg HttpApiConfig, reg *prometheus.Registry) *httpMet
},
[]string{"code", "method", "handler"},
),
responseSize: prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_response_size_bytes",
Help: "HTTP request sizes in bytes.",
Buckets: []float64{0, 1 << 10, 1 << 15, 1 << 20},
},
[]string{"code", "method", "handler"},
),
}
reg.MustRegister(metrics.requestCount)
reg.MustRegister(metrics.requestDuration)
reg.MustRegister(metrics.requestSize)
reg.MustRegister(metrics.responseSize)
// ---------------------------
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{Registry: reg}))
Expand Down
131 changes: 54 additions & 77 deletions httpapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,102 +2,79 @@ package httpapi

import (
"net/http"
"regexp"
"runtime/debug"
"slices"
"strconv"
"strings"
"time"

"github.com/gin-gonic/gin"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"github.com/rs/zerolog/log"
"github.com/semafind/semadb/httpapi/middleware"
"github.com/semafind/semadb/httpapi/utils"
)

// ---------------------------

func ZerologLoggerMetrics(metrics *httpMetrics) gin.HandlerFunc {
return func(c *gin.Context) {
// ---------------------------
// Start timer
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
reqSize := c.Request.ContentLength
// ---------------------------
// Process request
c.Next()
// ---------------------------
// Stop timer and gather information
latency := time.Since(start)

method := c.Request.Method
statusCode := c.Writer.Status()
lastError := c.Errors.ByType(gin.ErrorTypePrivate).Last()

bodySize := c.Writer.Size()

if raw != "" {
path = path + "?" + raw
}
// ---------------------------
var logEvent *zerolog.Event
if statusCode == 500 || lastError != nil {
logEvent = log.Error()
} else {
logEvent = log.Info()
}
logEvent.Err(lastError).
Dur("latency", latency).
Str("clientIP", c.ClientIP()).
Str("remoteIP", c.RemoteIP()).
Str("method", method).Str("path", path).
Int("statusCode", statusCode).
Int64("requestSize", reqSize).
Int("bodySize", bodySize).
Str("path", path)
// Extract app headers if any
appH, ok := c.Keys["appHeaders"]
if ok {
appHeaders := appH.(middleware.AppHeaders)
// We are not logging the user ID for privacy reasons
logEvent = logEvent.Str("planId", appHeaders.PlanId)
}
logEvent.Msg("HTTPAPI")
// ---------------------------
// Zerolog based middleware for logging HTTP requests
func ZeroLoggerMetrics(metrics *httpMetrics, next http.Handler) http.Handler {
handler := hlog.AccessHandler(func(r *http.Request, status, size int, duration time.Duration) {
hlog.FromRequest(r).Info().
Str("method", r.Method).
Stringer("url", r.URL).
Int("status", status).
Int("size", size).
Dur("duration", duration).
Msg("")
if metrics != nil {
// Example handler names
// github.com/semafind/semadb/httpapi.(*SemaDBHandlers).ListCollections-fm
// github.com/semafind/semadb/httpapi.(*SemaDBHandlers).CreateCollection-fm
fullHName := c.HandlerName()
parts := strings.Split(fullHName, ".")
hname := parts[len(parts)-1][:len(parts[len(parts)-1])-3]
ssCode := strconv.Itoa(statusCode)
metrics.requestCount.WithLabelValues(ssCode, method, hname).Inc()
metrics.requestDuration.WithLabelValues(ssCode, method, hname).Observe(latency.Seconds())
metrics.requestSize.WithLabelValues(ssCode, method, hname).Observe(float64(reqSize))
metrics.responseSize.WithLabelValues(ssCode, method, hname).Observe(float64(bodySize))
// Canonicalize the URL by removing url parameters
// Replace anything of the form collections/mycol23 with collections/:id
re := regexp.MustCompile(`collections/[a-zA-Z0-9]+`)
canonical := re.ReplaceAll([]byte(r.URL.Path), []byte("collections/{collectionId}"))
hname := string(canonical)
ssCode := strconv.Itoa(status)
metrics.requestCount.WithLabelValues(ssCode, r.Method, hname).Inc()
metrics.requestDuration.WithLabelValues(ssCode, r.Method, hname).Observe(duration.Seconds())
metrics.requestSize.WithLabelValues(ssCode, r.Method, hname).Observe(float64(size))
// metrics.responseSize.WithLabelValues(ssCode, r.Method, hname).Observe(float64(bodySize))
}
}
})(next)
handler = hlog.NewHandler(log.Logger)(handler)
return handler
}

// ---------------------------

func ProxySecretMiddleware(secret string) gin.HandlerFunc {
func ProxySecretMiddleware(secret string, next http.Handler) http.Handler {
log.Debug().Str("proxySecret", secret).Msg("ProxySecretMiddleware")
return func(c *gin.Context) {
if c.GetHeader("X-Proxy-Secret") != secret {
c.AbortWithStatusJSON(http.StatusProxyAuthRequired, gin.H{"error": "forbidden"})
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("X-Proxy-Secret") != secret {
utils.Encode(w, http.StatusProxyAuthRequired, map[string]string{"error": "forbidden"})
return
}
}
next.ServeHTTP(w, r)
})
}

func WhiteListIPMiddleware(whitelist []string) gin.HandlerFunc {
func WhiteListIPMiddleware(whitelist []string, next http.Handler) http.Handler {
slices.Sort(whitelist)
return func(c *gin.Context) {
remoteIP := c.RemoteIP()
_, found := slices.BinarySearch(whitelist, remoteIP)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, found := slices.BinarySearch(whitelist, r.RemoteAddr)
if !found {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "forbidden"})
utils.Encode(w, http.StatusForbidden, map[string]string{"error": "forbidden"})
return
}
}
next.ServeHTTP(w, r)
})
}

func RecoverMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
log.Error().Interface("error", err).Msg("panic recovered")
log.Error().Str("stack", string(debug.Stack())).Msg("stack trace")
w.WriteHeader(http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
40 changes: 29 additions & 11 deletions httpapi/middleware/appheaders.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package middleware

import (
"context"
"fmt"
"net/http"

"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"github.com/semafind/semadb/httpapi/utils"
"github.com/semafind/semadb/models"
)

Expand All @@ -14,23 +15,40 @@ type AppHeaders struct {
PlanId string `header:"X-Plan-Id" binding:"required"`
}

func AppHeaderMiddleware(userPlans map[string]models.UserPlan) gin.HandlerFunc {
return func(c *gin.Context) {
var appHeaders AppHeaders
if err := c.ShouldBindHeader(&appHeaders); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
type contextKey string

const appHeadersKey contextKey = "appHeaders"
const userPlanKey contextKey = "userPlan"

func AppHeaderMiddleware(userPlans map[string]models.UserPlan, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
appHeaders := AppHeaders{
UserId: r.Header.Get("X-User-Id"),
PlanId: r.Header.Get("X-Plan-Id"),
}
if appHeaders.UserId == "" || appHeaders.PlanId == "" {
utils.Encode(w, http.StatusBadRequest, map[string]string{"error": "missing X-User-ID or X-Plan-Id headers"})
return
}
c.Set("appHeaders", appHeaders)
log.Debug().Interface("appHeaders", appHeaders).Msg("AppHeaderMiddleware")
// ---------------------------
newCtx := context.WithValue(r.Context(), appHeadersKey, appHeaders)
// Extract user plan
userPlan, ok := userPlans[appHeaders.PlanId]
if !ok {
errmsg := fmt.Sprintf("unknown user plan %s", appHeaders.PlanId)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errmsg})
utils.Encode(w, http.StatusBadRequest, map[string]string{"error": errmsg})
return
}
c.Set("userPlan", userPlan)
c.Next()
}
newCtx = context.WithValue(newCtx, userPlanKey, userPlan)
next.ServeHTTP(w, r.WithContext(newCtx))
})
}

func GetAppHeaders(ctx context.Context) AppHeaders {
return ctx.Value(appHeadersKey).(AppHeaders)
}

func GetUserPlan(ctx context.Context) models.UserPlan {
return ctx.Value(userPlanKey).(models.UserPlan)
}
50 changes: 50 additions & 0 deletions httpapi/utils/encdec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package utils

import (
"bytes"
"encoding/json"
"fmt"
"net/http"

"github.com/rs/zerolog/log"
)

// Validator is an object that can be validated.
type Validator interface {
Validate() error
}

// Encode writes the object to the response writer. It is usually used as the
// last step in a handler.
func Encode[T any](w http.ResponseWriter, status int, v T) {
w.Header().Set("Content-Type", "application/json")
// Write to buffer first to ensure the object is json encodable
// before writing to the response writer.
var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(v); err != nil {
w.WriteHeader(http.StatusInternalServerError)
log.Error().Err(err).Msg("could not encode response")
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
return
}
w.WriteHeader(status)
w.Write(buf.Bytes())
}

// DecodeValid decodes the request body into the object and then validates it.
// Look at problems to see if there are any issues.
func DecodeValid[T Validator](r *http.Request) (T, error) {
var v T
if r.Header.Get("Content-Type") != "application/json" {
return v, fmt.Errorf("expected content type application/json, got %s", r.Header.Get("Content-Type"))
}
// ---------------------------
if err := json.NewDecoder(r.Body).Decode(&v); err != nil {
return v, fmt.Errorf("decode json: %w", err)
}
if err := v.Validate(); err != nil {
return v, fmt.Errorf("validation error: %w", err)
}
// ---------------------------
return v, nil
}
Loading

0 comments on commit 78e609b

Please sign in to comment.