diff --git a/sliding_window_counter.go b/sliding_window_counter.go new file mode 100644 index 0000000..d623c87 --- /dev/null +++ b/sliding_window_counter.go @@ -0,0 +1,52 @@ +package ratelimiter + +import ( + "sync" + "time" +) + +type SlidingWindowCounter struct { + windowSize time.Duration + maxRequests int + currentWindow int64 + requestCount int + previousCount int + mu sync.Mutex +} + +func NewSlidingWindowCounter(windowSize time.Duration, maxRequests int) *SlidingWindowCounter { + return &SlidingWindowCounter{ + windowSize: windowSize, + maxRequests: maxRequests, + currentWindow: time.Now().Unix() / int64(windowSize.Seconds()), + requestCount: 0, + previousCount: 0, + } +} + +func (swc *SlidingWindowCounter) AllowRequest() bool { + swc.mu.Lock() + defer swc.mu.Unlock() + + now := time.Now().Unix() + window := now / int64(swc.windowSize.Seconds()) + + // If we've moved to a new window, update the counts + if window != swc.currentWindow { + swc.previousCount = swc.requestCount + swc.requestCount = 0 + swc.currentWindow = window + } + + // Calculate the weighted request count + windowElapsed := float64(now%int64(swc.windowSize.Seconds())) / float64(swc.windowSize.Seconds()) + threshold := float64(swc.previousCount)*(1-windowElapsed) + float64(swc.requestCount) + + // Check if we're within the limit + if threshold < float64(swc.maxRequests) { + swc.requestCount++ + return true + } + + return false +} diff --git a/sliding_window_counter_test.go b/sliding_window_counter_test.go new file mode 100644 index 0000000..8f69f21 --- /dev/null +++ b/sliding_window_counter_test.go @@ -0,0 +1,62 @@ +package ratelimiter + +import ( + "testing" + "time" +) + +func newSlidingWindowCounter(windowSize time.Duration, maxRequests int) *SlidingWindowCounter { + return NewSlidingWindowCounter(windowSize, maxRequests) +} + +func TestSlidingWindowCounter_AllowRequest(t *testing.T) { + tests := []struct { + windowSize time.Duration + maxRequests int + requests int + expectAllowed bool + }{ + {time.Second * 10, 5, 5, true}, // within limit + {time.Second * 10, 5, 6, false}, // exceeding limit + {time.Second * 10, 5, 10, false}, // far exceeding limit + {time.Second * 5, 2, 2, true}, // within smaller window + {time.Second * 5, 2, 3, false}, // exceeding limit in smaller window + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + swc := newSlidingWindowCounter(tt.windowSize, tt.maxRequests) + + for i := 0; i < tt.requests; i++ { + allowed := swc.AllowRequest() + if i < tt.maxRequests && !allowed { + t.Errorf("Request %d was not allowed, but it should be", i) + } + if i >= tt.maxRequests && allowed { + t.Errorf("Request %d was allowed, but it should not be", i) + } + } + }) + } +} + +func TestSlidingWindowCounter_WindowExpiration(t *testing.T) { + windowSize := time.Second * 2 + maxRequests := 2 + + swc := newSlidingWindowCounter(windowSize, maxRequests) + + if !swc.AllowRequest() { + t.Errorf("First request should be allowed") + } + + if !swc.AllowRequest() { + t.Errorf("Second request should be allowed") + } + + time.Sleep(windowSize + time.Second) + + if swc.AllowRequest() { + t.Errorf("Request after window expiration should not be allowed") + } +}