From 94d75e3ee85dfbea0741cb2e4d082831f6b4882b Mon Sep 17 00:00:00 2001 From: Gaurav Sarma Date: Wed, 11 Dec 2024 19:32:11 +0800 Subject: [PATCH 1/3] Integrate Gin router with HTTP server --- internal/middleware/ratelimiter.go | 140 ++++++++++++++++------------- internal/server/http.go | 14 ++- main.go | 10 ++- 3 files changed, 92 insertions(+), 72 deletions(-) diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index f8d3c52..bda9bd1 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -15,83 +15,100 @@ import ( "time" "github.com/dicedb/dicedb-go" + "github.com/gin-gonic/gin" ) -// RateLimiter middleware to limit requests based on a specified limit and duration -func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { - configValue := config.LoadConfig() - cronFrequencyInterval := configValue.Server.CronCleanupFrequency - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if handleCors(w, r) { - return - } +type ( + RateLimiterMiddleware struct { + client *db.DiceDB + limit int64 + window float64 + conf *config.Config + cronFrequencyInterval time.Duration + } +) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() +func NewRateLimiterMiddleware(client *db.DiceDB, limit int64, window float64) (rl *RateLimiterMiddleware) { + rl = &RateLimiterMiddleware{ + client: client, + limit: limit, + window: window, + cronFrequencyInterval: config.LoadConfig().Server.CronCleanupFrequency, + } + return +} - // Only apply rate limiting for specific paths (e.g., "/cli/") - if !strings.Contains(r.URL.Path, "/shell/exec/") { - next.ServeHTTP(w, r) - return - } +// RateLimiter middleware to limit requests based on a specified limit and duration +func (rl *RateLimiterMiddleware) Exec(c *gin.Context) { + if handleCors(c.Writer, c.Request) { + return + } - // Generate the rate limiting key based on the current window - currentWindow := time.Now().Unix() / int64(window) - key := fmt.Sprintf("request_count:%d", currentWindow) - slog.Debug("Created rate limiter key", slog.Any("key", key)) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - // Get the current request count for this window - val, err := client.Client.Get(ctx, key).Result() - if err != nil && !errors.Is(err, dicedb.Nil) { - slog.Error("Error fetching request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } + // Only apply rate limiting for specific paths (e.g., "/cli/") + if !strings.Contains(c.Request.URL.Path, "/shell/exec/") { + c.Next() + return + } - // Parse the current request count or initialize to 0 - var requestCount int64 = 0 - if val != "" { - requestCount, err = strconv.ParseInt(val, 10, 64) - if err != nil { - slog.Error("Error converting request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - } + // Generate the rate limiting key based on the current window + currentWindow := time.Now().Unix() / int64(rl.window) + key := fmt.Sprintf("request_count:%d", currentWindow) + slog.Debug("Created rate limiter key", slog.Any("key", key)) + + // Get the current request count for this window + val, err := rl.client.Client.Get(ctx, key).Result() + if err != nil && !errors.Is(err, dicedb.Nil) { + slog.Error("Error fetching request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) + return + } - // Check if the request count exceeds the limit - if requestCount >= limit { - slog.Warn("Request limit exceeded", "count", requestCount) - addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0) - http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) + // Parse the current request count or initialize to 0 + var requestCount int64 = 0 + if val != "" { + requestCount, err = strconv.ParseInt(val, 10, 64) + if err != nil { + slog.Error("Error converting request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) return } + } - // Increment the request count - if requestCount, err = client.Client.Incr(ctx, key).Result(); err != nil { - slog.Error("Error incrementing request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } + // Check if the request count exceeds the limit + if requestCount >= rl.limit { + slog.Warn("Request limit exceeded", "count", requestCount) + addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), 0) + http.Error(c.Writer, "429 - Too Many Requests", http.StatusTooManyRequests) + return + } - // Set the key expiry if it's newly created - if requestCount == 1 { - if err := client.Client.Expire(ctx, key, time.Duration(window)*time.Second).Err(); err != nil { - slog.Error("Error setting expiry for request count", "error", err) - } - } + // Increment the request count + if requestCount, err = rl.client.Client.Incr(ctx, key).Result(); err != nil { + slog.Error("Error incrementing request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) + return + } - secondsDifference, err := calculateNextCleanupTime(ctx, client, cronFrequencyInterval) - if err != nil { - slog.Error("Error calculating next cleanup time", "error", err) + // Set the key expiry if it's newly created + if requestCount == 1 { + if err := rl.client.Client.Expire(ctx, key, time.Duration(rl.window)*time.Second).Err(); err != nil { + slog.Error("Error setting expiry for request count", "error", err) } + } - addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), - secondsDifference) + secondsDifference, err := calculateNextCleanupTime(ctx, rl.client, rl.cronFrequencyInterval) + if err != nil { + slog.Error("Error calculating next cleanup time", "error", err) + } - slog.Info("Request processed", "count", requestCount+1) - next.ServeHTTP(w, r) - }) + addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), + secondsDifference) + + slog.Info("Request processed", "count", requestCount+1) + c.Next() } func calculateNextCleanupTime(ctx context.Context, client *db.DiceDB, cronFrequencyInterval time.Duration) (int64, error) { @@ -186,6 +203,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) { + w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10)) w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10)) w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10)) diff --git a/internal/server/http.go b/internal/server/http.go index b5209c2..2553a97 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -12,6 +12,8 @@ import ( "server/internal/db" "server/internal/middleware" util "server/util" + + "github.com/gin-gonic/gin" ) type HTTPServer struct { @@ -50,19 +52,13 @@ func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { })).ServeHTTP(w, r) } -func NewHTTPServer(addr string, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, +func NewHTTPServer(router *gin.Engine, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, limit int64, window float64) *HTTPServer { - handlerMux := &HandlerMux{ - mux: mux, - rateLimiter: func(w http.ResponseWriter, r *http.Request, next http.Handler) { - middleware.RateLimiter(diceDBAdminClient, next, limit, window).ServeHTTP(w, r) - }, - } return &HTTPServer{ httpServer: &http.Server{ - Addr: addr, - Handler: handlerMux, + Addr: ":8080", + Handler: router, ReadHeaderTimeout: 5 * time.Second, }, DiceClient: diceClient, diff --git a/main.go b/main.go index 9d60133..f36bde0 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "os" "server/config" "server/internal/db" + "server/internal/middleware" "server/internal/server" "sync" @@ -56,7 +57,12 @@ func main() { c.Next() }) - httpServer := server.NewHTTPServer(":8080", nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin, + router.Use((middleware.NewRateLimiterMiddleware(diceDBAdminClient, + configValue.Server.RequestLimitPerMin, + configValue.Server.RequestWindowSec, + ).Exec)) + + httpServer := server.NewHTTPServer(router, nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin, configValue.Server.RequestWindowSec) // Register routes @@ -68,7 +74,7 @@ func main() { go func() { defer wg.Done() // Run the HTTP Server - if err := router.Run(":8080"); err != nil { + if err := httpServer.Run(context.Background()); err != nil { slog.Error("server failed: %v\n", slog.Any("err", err)) diceDBAdminClient.CloseDiceDB() cancel() From 2e45060357a0e7f302b5aa74a4e886488507216b Mon Sep 17 00:00:00 2001 From: Gaurav Sarma Date: Wed, 11 Dec 2024 19:48:29 +0800 Subject: [PATCH 2/3] Modified the TrailingSlashMiddleware --- internal/middleware/trailingslash.go | 24 ++++++++++++------------ internal/server/http.go | 16 +--------------- main.go | 10 ++++++++-- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 80871ce..a4f93a0 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -3,19 +3,19 @@ package middleware import ( "net/http" "strings" + + "github.com/gin-gonic/gin" ) -func TrailingSlashMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - newPath := strings.TrimSuffix(r.URL.Path, "/") - newURL := newPath - if r.URL.RawQuery != "" { - newURL += "?" + r.URL.RawQuery - } - http.Redirect(w, r, newURL, http.StatusMovedPermanently) - return +func TrailingSlashMiddleware(c *gin.Context) { + if c.Request.URL.Path != "/" && strings.HasSuffix(c.Request.URL.Path, "/") { + newPath := strings.TrimSuffix(c.Request.URL.Path, "/") + newURL := newPath + if c.Request.URL.RawQuery != "" { + newURL += "?" + c.Request.URL.RawQuery } - next.ServeHTTP(w, r) - }) + http.Redirect(c.Writer, c.Request, newURL, http.StatusMovedPermanently) + return + } + c.Next() } diff --git a/internal/server/http.go b/internal/server/http.go index 2553a97..f37a8f0 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -6,11 +6,9 @@ import ( "errors" "log/slog" "net/http" - "strings" "time" "server/internal/db" - "server/internal/middleware" util "server/util" "github.com/gin-gonic/gin" @@ -21,11 +19,6 @@ type HTTPServer struct { DiceClient *db.DiceDB } -type HandlerMux struct { - mux *http.ServeMux - rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) -} - type HTTPResponse struct { Data interface{} `json:"data"` } @@ -45,14 +38,7 @@ func errorResponse(response string) string { return string(jsonResponse) } -func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Path = strings.ToLower(r.URL.Path) - cim.rateLimiter(w, r, cim.mux) - })).ServeHTTP(w, r) -} - -func NewHTTPServer(router *gin.Engine, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, +func NewHTTPServer(router *gin.Engine, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, limit int64, window float64) *HTTPServer { return &HTTPServer{ diff --git a/main.go b/main.go index f36bde0..1106be4 100644 --- a/main.go +++ b/main.go @@ -57,13 +57,19 @@ func main() { c.Next() }) + router.Use(middleware.TrailingSlashMiddleware) router.Use((middleware.NewRateLimiterMiddleware(diceDBAdminClient, configValue.Server.RequestLimitPerMin, configValue.Server.RequestWindowSec, ).Exec)) - httpServer := server.NewHTTPServer(router, nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin, - configValue.Server.RequestWindowSec) + httpServer := server.NewHTTPServer( + router, + diceDBAdminClient, + diceDBClient, + configValue.Server.RequestLimitPerMin, + configValue.Server.RequestWindowSec, + ) // Register routes router.GET("/health", gin.WrapF(httpServer.HealthCheck)) From 7a99484a4ab239ee4e309603e7137246a6da8c22 Mon Sep 17 00:00:00 2001 From: pshubham Date: Wed, 11 Dec 2024 18:03:50 +0530 Subject: [PATCH 3/3] Fix lint errors --- config/config.go | 9 --------- internal/middleware/ratelimiter.go | 2 -- internal/server/http.go | 1 - 3 files changed, 12 deletions(-) diff --git a/config/config.go b/config/config.go index f36fc11..a056938 100644 --- a/config/config.go +++ b/config/config.go @@ -117,15 +117,6 @@ func getEnvArray(key string, fallback []string) []string { return fallback } -func getEnvBool(key string, fallback bool) bool { - if value, exists := os.LookupEnv(key); exists { - if boolValue, err := strconv.ParseBool(value); err == nil { - return boolValue - } - } - return fallback -} - // splitString splits a string by comma and returns a slice of strings func splitString(s string) []string { var array []string diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index bda9bd1..4d8a01b 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -23,7 +23,6 @@ type ( client *db.DiceDB limit int64 window float64 - conf *config.Config cronFrequencyInterval time.Duration } ) @@ -203,7 +202,6 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) { - w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10)) w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10)) w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10)) diff --git a/internal/server/http.go b/internal/server/http.go index f37a8f0..b9c6d03 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -40,7 +40,6 @@ func errorResponse(response string) string { func NewHTTPServer(router *gin.Engine, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, limit int64, window float64) *HTTPServer { - return &HTTPServer{ httpServer: &http.Server{ Addr: ":8080",