diff --git a/internal/middleware/ratelimiter.go b/internal/middleware/ratelimiter.go index c852a26..9cfb4da 100644 --- a/internal/middleware/ratelimiter.go +++ b/internal/middleware/ratelimiter.go @@ -22,17 +22,18 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() + // Only apply rate limiting for specific paths (e.g., "/cli/") if !strings.Contains(r.URL.Path, "/shell/exec/") { next.ServeHTTP(w, r) return } - // Get the current time window as a unique key + // Generate the rate limiting key based on the current window currentWindow := time.Now().Unix() / int64(window) key := fmt.Sprintf("request_count:%d", currentWindow) - slog.Info("Created rate limiter key", slog.Any("key", key)) + slog.Debug("Created rate limiter key", slog.Any("key", key)) - // Fetch the current request count + // Get the current request count for this window val, err := client.Client.Get(ctx, key).Result() if err != nil && !errors.Is(err, dicedb.Nil) { slog.Error("Error fetching request count", "error", err) @@ -40,8 +41,8 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float return } - // Initialize request count - requestCount := int64(0) + // Parse the current request count or initialize to 0 + var requestCount int64 = 0 if val != "" { requestCount, err = strconv.ParseInt(val, 10, 64) if err != nil { @@ -52,26 +53,29 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float } // Check if the request count exceeds the limit - if requestCount > limit { + if requestCount >= limit { slog.Warn("Request limit exceeded", "count", requestCount) + addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window)) http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) return } // Increment the request count - if _, err := client.Client.Incr(ctx, key).Result(); err != nil { + if requestCount, err = client.Client.Incr(ctx, key).Result(); err != nil { slog.Error("Error incrementing request count", "error", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } // Set the key expiry if it's newly created - if requestCount == 0 { + if requestCount == 1 { if err := client.Client.Expire(ctx, key, time.Duration(window)*time.Second).Err(); err != nil { slog.Error("Error setting expiry for request count", "error", err) } } + addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window)) + slog.Info("Request processed", "count", requestCount+1) next.ServeHTTP(w, r) }) @@ -92,7 +96,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi // Generate the rate limiting key based on the current window currentWindow := time.Now().Unix() / int64(window) key := fmt.Sprintf("request_count:%d", currentWindow) - slog.Info("Created rate limiter key", slog.Any("key", key)) + slog.Debug("Created rate limiter key", slog.Any("key", key)) // Get the current request count for this window from the mock DB val, err := client.Get(ctx, key) @@ -116,6 +120,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi // Check if the request limit has been exceeded if requestCount >= limit { slog.Warn("Request limit exceeded", "count", requestCount) + addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window)) http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests) return } @@ -136,7 +141,16 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi } } + addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window)) + slog.Info("Request processed", "count", requestCount) next.ServeHTTP(w, r) }) } + +func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime int64) { + w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10)) + w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10)) + w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10)) + w.Header().Set("x-ratelimit-reset", strconv.FormatInt(resetTime, 10)) +} diff --git a/internal/tests/integration/ratelimiter_integration_test.go b/internal/tests/integration/ratelimiter_integration_test.go index 7faca79..9d82a41 100644 --- a/internal/tests/integration/ratelimiter_integration_test.go +++ b/internal/tests/integration/ratelimiter_integration_test.go @@ -40,3 +40,18 @@ func TestRateLimiterExceedsLimit(t *testing.T) { require.Equal(t, http.StatusTooManyRequests, w.Code) require.Contains(t, w.Body.String(), "429 - Too Many Requests") } + +func TestRateLimitHeadersSet(t *testing.T) { + configValue := config.LoadConfig() + limit := configValue.RequestLimitPerMin + window := configValue.RequestWindowSec + + w, r, rateLimiter := util.SetupRateLimiter(limit, window) + + rateLimiter.ServeHTTP(w, r) + + require.NotEmpty(t, w.Header().Get("x-ratelimit-limit"), "x-ratelimit-limit should be set") + require.NotEmpty(t, w.Header().Get("x-ratelimit-remaining"), "x-ratelimit-remaining should be set") + require.NotEmpty(t, w.Header().Get("x-ratelimit-used"), "x-ratelimit-used should be set") + require.NotEmpty(t, w.Header().Get("x-ratelimit-reset"), "x-ratelimit-reset should be set") +}