Skip to content

Commit

Permalink
Adding rate limiting headers to responses (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushsatyam146 authored Oct 9, 2024
1 parent 0005207 commit 2b4aa11
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
32 changes: 23 additions & 9 deletions internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,27 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

// Only apply rate limiting for specific paths (e.g., "/cli/")
if !strings.Contains(r.URL.Path, "/shell/exec/") {
next.ServeHTTP(w, r)
return
}

// Get the current time window as a unique key
// Generate the rate limiting key based on the current window
currentWindow := time.Now().Unix() / int64(window)
key := fmt.Sprintf("request_count:%d", currentWindow)
slog.Info("Created rate limiter key", slog.Any("key", key))
slog.Debug("Created rate limiter key", slog.Any("key", key))

// Fetch the current request count
// Get the current request count for this window
val, err := client.Client.Get(ctx, key).Result()
if err != nil && !errors.Is(err, dicedb.Nil) {
slog.Error("Error fetching request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

// Initialize request count
requestCount := int64(0)
// Parse the current request count or initialize to 0
var requestCount int64 = 0
if val != "" {
requestCount, err = strconv.ParseInt(val, 10, 64)
if err != nil {
Expand All @@ -55,26 +56,29 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float
}

// Check if the request count exceeds the limit
if requestCount > limit {
if requestCount >= limit {
slog.Warn("Request limit exceeded", "count", requestCount)
addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window))
http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests)
return
}

// Increment the request count
if _, err := client.Client.Incr(ctx, key).Result(); err != nil {
if requestCount, err = client.Client.Incr(ctx, key).Result(); err != nil {
slog.Error("Error incrementing request count", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}

// Set the key expiry if it's newly created
if requestCount == 0 {
if requestCount == 1 {
if err := client.Client.Expire(ctx, key, time.Duration(window)*time.Second).Err(); err != nil {
slog.Error("Error setting expiry for request count", "error", err)
}
}

addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window))

slog.Info("Request processed", "count", requestCount+1)
next.ServeHTTP(w, r)
})
Expand All @@ -98,7 +102,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
// Generate the rate limiting key based on the current window
currentWindow := time.Now().Unix() / int64(window)
key := fmt.Sprintf("request_count:%d", currentWindow)
slog.Info("Created rate limiter key", slog.Any("key", key))
slog.Debug("Created rate limiter key", slog.Any("key", key))

// Get the current request count for this window from the mock DB
val, err := client.Get(ctx, key)
Expand All @@ -122,6 +126,7 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
// Check if the request limit has been exceeded
if requestCount >= limit {
slog.Warn("Request limit exceeded", "count", requestCount)
addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window))
http.Error(w, "429 - Too Many Requests", http.StatusTooManyRequests)
return
}
Expand All @@ -142,7 +147,16 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
}
}

addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window))

slog.Info("Request processed", "count", requestCount)
next.ServeHTTP(w, r)
})
}

func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime int64) {
w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10))
w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10))
w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10))
w.Header().Set("x-ratelimit-reset", strconv.FormatInt(resetTime, 10))
}
15 changes: 15 additions & 0 deletions internal/tests/integration/ratelimiter_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,18 @@ func TestRateLimiterExceedsLimit(t *testing.T) {
require.Equal(t, http.StatusTooManyRequests, w.Code)
require.Contains(t, w.Body.String(), "429 - Too Many Requests")
}

func TestRateLimitHeadersSet(t *testing.T) {
configValue := config.LoadConfig()
limit := configValue.RequestLimitPerMin
window := configValue.RequestWindowSec

w, r, rateLimiter := util.SetupRateLimiter(limit, window)

rateLimiter.ServeHTTP(w, r)

require.NotEmpty(t, w.Header().Get("x-ratelimit-limit"), "x-ratelimit-limit should be set")
require.NotEmpty(t, w.Header().Get("x-ratelimit-remaining"), "x-ratelimit-remaining should be set")
require.NotEmpty(t, w.Header().Get("x-ratelimit-used"), "x-ratelimit-used should be set")
require.NotEmpty(t, w.Header().Get("x-ratelimit-reset"), "x-ratelimit-reset should be set")
}

0 comments on commit 2b4aa11

Please sign in to comment.