Skip to content

Commit

Permalink
Add Next cleanup time header (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
aasifkhan7 authored Oct 26, 2024
1 parent 572a8d0 commit 7aa2e55
Showing 1 changed file with 31 additions and 16 deletions.
47 changes: 31 additions & 16 deletions internal/middleware/ratelimiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"net/http"
"server/config"
"server/internal/db"
"server/internal/server/utils"
mock "server/internal/tests/dbmocks"
Expand All @@ -18,6 +19,8 @@ import (

// RateLimiter middleware to limit requests based on a specified limit and duration
func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float64) http.Handler {
configValue := config.LoadConfig()
cronFrequencyInterval := configValue.Server.CronCleanupFrequency
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handleCors(w, r) {
return
Expand Down Expand Up @@ -78,28 +81,40 @@ func RateLimiter(client *db.DiceDB, next http.Handler, limit int64, window float
}
}

// Get the cron last cleanup run time
var lastCronCleanupTime int64
resp := client.Client.Get(ctx, utils.LastCronCleanupTimeUnixMs)
if resp.Err() != nil && !errors.Is(resp.Err(), dicedb.Nil) {
slog.Error("Failed to get last cron cleanup time for headers", slog.Any("err", resp.Err().Error()))
}

if resp.Val() != "" {
lastCronCleanupTime, err = strconv.ParseInt(resp.Val(), 10, 64)
if err != nil {
slog.Error("Error converting last cron cleanup time", "error", err)
}
secondsDifference, err := calculateNextCleanupTime(ctx, client, cronFrequencyInterval)
if err != nil {
slog.Error("Error calculating next cleanup time", "error", err)
}

addRateLimitHeaders(w, limit, limit-(requestCount+1), requestCount+1, currentWindow+int64(window),
lastCronCleanupTime)
secondsDifference)

slog.Info("Request processed", "count", requestCount+1)
next.ServeHTTP(w, r)
})
}

func calculateNextCleanupTime(ctx context.Context, client *db.DiceDB, cronFrequencyInterval time.Duration) (int64, error) {
var lastCronCleanupTime int64
resp := client.Client.Get(ctx, utils.LastCronCleanupTimeUnixMs)
if resp.Err() != nil && !errors.Is(resp.Err(), dicedb.Nil) {
return -1, resp.Err()
}

if resp.Val() != "" {
var err error
lastCronCleanupTime, err = strconv.ParseInt(resp.Val(), 10, 64) // directly assign here
if err != nil {
return -1, err
}
}

lastCleanupTime := time.UnixMilli(lastCronCleanupTime)
nextCleanupTime := lastCleanupTime.Add(cronFrequencyInterval)
timeDifference := nextCleanupTime.Sub(time.Now())
return int64(timeDifference.Seconds()), nil
}

func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, window float64) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handleCors(w, r) {
Expand Down Expand Up @@ -170,14 +185,14 @@ func MockRateLimiter(client *mock.DiceDBMock, next http.Handler, limit int64, wi
})
}

func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, cronLastCleanupTime int64) {
func addRateLimitHeaders(w http.ResponseWriter, limit, remaining, used, resetTime, secondsLeftForCleanup int64) {
w.Header().Set("x-ratelimit-limit", strconv.FormatInt(limit, 10))
w.Header().Set("x-ratelimit-remaining", strconv.FormatInt(remaining, 10))
w.Header().Set("x-ratelimit-used", strconv.FormatInt(used, 10))
w.Header().Set("x-ratelimit-reset", strconv.FormatInt(resetTime, 10))
w.Header().Set("x-last-cleanup-time", strconv.FormatInt(cronLastCleanupTime, 10))
w.Header().Set("x-next-cleanup-time", strconv.FormatInt(secondsLeftForCleanup, 10))

// Expose the rate limit headers to the client
w.Header().Set("Access-Control-Expose-Headers", "x-ratelimit-limit, x-ratelimit-remaining,"+
"x-ratelimit-used, x-ratelimit-reset, x-last-cleanup-time")
"x-ratelimit-used, x-ratelimit-reset, x-next-cleanup-time")
}

0 comments on commit 7aa2e55

Please sign in to comment.