diff --git a/config/config.go b/config/config.go index f36fc11..a056938 100644 --- a/config/config.go +++ b/config/config.go @@ -117,15 +117,6 @@ func getEnvArray(key string, fallback []string) []string { return fallback } -func getEnvBool(key string, fallback bool) bool { - if value, exists := os.LookupEnv(key); exists { - if boolValue, err := strconv.ParseBool(value); err == nil { - return boolValue - } - } - return fallback -} - // splitString splits a string by comma and returns a slice of strings func splitString(s string) []string { var array []string diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index f8d3c52..4d8a01b 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -15,83 +15,99 @@ import ( "time" "github.com/dicedb/dicedb-go" + "github.com/gin-gonic/gin" ) -// RateLimiter middleware to limit requests based on a specified limit and duration -func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler { - configValue := config.LoadConfig() - cronFrequencyInterval := configValue.Server.CronCleanupFrequency - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if handleCors(w, r) { - return - } +type ( + RateLimiterMiddleware struct { + client *db.DiceDB + limit int64 + window float64 + cronFrequencyInterval time.Duration + } +) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() +func NewRateLimiterMiddleware(client *db.DiceDB, limit int64, window float64) (rl *RateLimiterMiddleware) { + rl = &RateLimiterMiddleware{ + client: client, + limit: limit, + window: window, + cronFrequencyInterval: config.LoadConfig().Server.CronCleanupFrequency, + } + return +} - // Only apply rate limiting for specific paths (e.g., "/cli/") - if !strings.Contains(r.URL.Path, "/shell/exec/") { - next.ServeHTTP(w, r) - return - } +// RateLimiter middleware to limit requests based on a specified limit and duration +func (rl *RateLimiterMiddleware) Exec(c *gin.Context) { + if handleCors(c.Writer, c.Request) { + return + } - // Generate the rate limiting key based on the current window - currentWindow := time.Now().Unix() / int64(window) - key := fmt.Sprintf("request_count:%d", currentWindow) - slog.Debug("Created rate limiter key", slog.Any("key", key)) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - // Get the current request count for this window - val, err := client.Client.Get(ctx, key).Result() - if err != nil && !errors.Is(err, dicedb.Nil) { - slog.Error("Error fetching request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } + // Only apply rate limiting for specific paths (e.g., "/cli/") + if !strings.Contains(c.Request.URL.Path, "/shell/exec/") { + c.Next() + return + } - // Parse the current request count or initialize to 0 - var requestCount int64 = 0 - if val != "" { - requestCount, err = strconv.ParseInt(val, 10, 64) - if err != nil { - slog.Error("Error converting request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - } + // Generate the rate limiting key based on the current window + currentWindow := time.Now().Unix() / int64(rl.window) + key := fmt.Sprintf("request_count:%d", currentWindow) + slog.Debug("Created rate limiter key", slog.Any("key", key)) + + // Get the current request count for this window + val, err := rl.client.Client.Get(ctx, key).Result() + if err != nil && !errors.Is(err, dicedb.Nil) { + slog.Error("Error fetching request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) + return + } - // Check if the request count exceeds the limit - if requestCount >= limit { - slog.Warn("Request limit exceeded", "count", requestCount) - addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), 0) - http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) + // Parse the current request count or initialize to 0 + var requestCount int64 = 0 + if val != "" { + requestCount, err = strconv.ParseInt(val, 10, 64) + if err != nil { + slog.Error("Error converting request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) return } + } - // Increment the request count - if requestCount, err = client.Client.Incr(ctx, key).Result(); err != nil { - slog.Error("Error incrementing request count", "error", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } + // Check if the request count exceeds the limit + if requestCount >= rl.limit { + slog.Warn("Request limit exceeded", "count", requestCount) + addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), 0) + http.Error(c.Writer, "429 - Too Many Requests", http.StatusTooManyRequests) + return + } - // Set the key expiry if it's newly created - if requestCount == 1 { - if err := client.Client.Expire(ctx, key, time.Duration(window)*time.Second).Err(); err != nil { - slog.Error("Error setting expiry for request count", "error", err) - } - } + // Increment the request count + if requestCount, err = rl.client.Client.Incr(ctx, key).Result(); err != nil { + slog.Error("Error incrementing request count", "error", err) + http.Error(c.Writer, "Internal Server Error", http.StatusInternalServerError) + return + } - secondsDifference, err := calculateNextCleanupTime(ctx, client, cronFrequencyInterval) - if err != nil { - slog.Error("Error calculating next cleanup time", "error", err) + // Set the key expiry if it's newly created + if requestCount == 1 { + if err := rl.client.Client.Expire(ctx, key, time.Duration(rl.window)*time.Second).Err(); err != nil { + slog.Error("Error setting expiry for request count", "error", err) } + } - addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window), - secondsDifference) + secondsDifference, err := calculateNextCleanupTime(ctx, rl.client, rl.cronFrequencyInterval) + if err != nil { + slog.Error("Error calculating next cleanup time", "error", err) + } - slog.Info("Request processed", "count", requestCount+1) - next.ServeHTTP(w, r) - }) + addRateLimitHeaders(c.Writer, rl.limit, rl.limit-(requestCount+1), requestCount+1, currentWindow+int64(rl.window), + secondsDifference) + + slog.Info("Request processed", "count", requestCount+1) + c.Next() } func calculateNextCleanupTime(ctx context.Context, client *db.DiceDB, cronFrequencyInterval time.Duration) (int64, error) { diff --git a/internal/middleware/trailingslash.go b/internal/middleware/trailingslash.go index 80871ce..a4f93a0 100644 --- a/internal/middleware/trailingslash.go +++ b/internal/middleware/trailingslash.go @@ -3,19 +3,19 @@ package middleware import ( "net/http" "strings" + + "github.com/gin-gonic/gin" ) -func TrailingSlashMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/" && strings.HasSuffix(r.URL.Path, "/") { - newPath := strings.TrimSuffix(r.URL.Path, "/") - newURL := newPath - if r.URL.RawQuery != "" { - newURL += "?" + r.URL.RawQuery - } - http.Redirect(w, r, newURL, http.StatusMovedPermanently) - return +func TrailingSlashMiddleware(c *gin.Context) { + if c.Request.URL.Path != "/" && strings.HasSuffix(c.Request.URL.Path, "/") { + newPath := strings.TrimSuffix(c.Request.URL.Path, "/") + newURL := newPath + if c.Request.URL.RawQuery != "" { + newURL += "?" + c.Request.URL.RawQuery } - next.ServeHTTP(w, r) - }) + http.Redirect(c.Writer, c.Request, newURL, http.StatusMovedPermanently) + return + } + c.Next() } diff --git a/internal/server/http.go b/internal/server/http.go index b5209c2..b9c6d03 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -6,12 +6,12 @@ import ( "errors" "log/slog" "net/http" - "strings" "time" "server/internal/db" - "server/internal/middleware" util "server/util" + + "github.com/gin-gonic/gin" ) type HTTPServer struct { @@ -19,11 +19,6 @@ type HTTPServer struct { DiceClient *db.DiceDB } -type HandlerMux struct { - mux *http.ServeMux - rateLimiter func(http.ResponseWriter, *http.Request, http.Handler) -} - type HTTPResponse struct { Data interface{} `json:"data"` } @@ -43,26 +38,12 @@ func errorResponse(response string) string { return string(jsonResponse) } -func (cim *HandlerMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - middleware.TrailingSlashMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Path = strings.ToLower(r.URL.Path) - cim.rateLimiter(w, r, cim.mux) - })).ServeHTTP(w, r) -} - -func NewHTTPServer(addr string, mux *http.ServeMux, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, +func NewHTTPServer(router *gin.Engine, diceDBAdminClient *db.DiceDB, diceClient *db.DiceDB, limit int64, window float64) *HTTPServer { - handlerMux := &HandlerMux{ - mux: mux, - rateLimiter: func(w http.ResponseWriter, r *http.Request, next http.Handler) { - middleware.RateLimiter(diceDBAdminClient, next, limit, window).ServeHTTP(w, r) - }, - } - return &HTTPServer{ httpServer: &http.Server{ - Addr: addr, - Handler: handlerMux, + Addr: ":8080", + Handler: router, ReadHeaderTimeout: 5 * time.Second, }, DiceClient: diceClient, diff --git a/main.go b/main.go index 9d60133..1106be4 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "os" "server/config" "server/internal/db" + "server/internal/middleware" "server/internal/server" "sync" @@ -56,8 +57,19 @@ func main() { c.Next() }) - httpServer := server.NewHTTPServer(":8080", nil, diceDBAdminClient, diceDBClient, configValue.Server.RequestLimitPerMin, - configValue.Server.RequestWindowSec) + router.Use(middleware.TrailingSlashMiddleware) + router.Use((middleware.NewRateLimiterMiddleware(diceDBAdminClient, + configValue.Server.RequestLimitPerMin, + configValue.Server.RequestWindowSec, + ).Exec)) + + httpServer := server.NewHTTPServer( + router, + diceDBAdminClient, + diceDBClient, + configValue.Server.RequestLimitPerMin, + configValue.Server.RequestWindowSec, + ) // Register routes router.GET("/health", gin.WrapF(httpServer.HealthCheck)) @@ -68,7 +80,7 @@ func main() { go func() { defer wg.Done() // Run the HTTP Server - if err := router.Run(":8080"); err != nil { + if err := httpServer.Run(context.Background()); err != nil { slog.Error("server failed: %v\n", slog.Any("err", err)) diceDBAdminClient.CloseDiceDB() cancel()