Skip to content

Commit

Permalink
Merge pull request #7 from augustus281/DEV
Browse files Browse the repository at this point in the history
[algo]: sliding window counter
  • Loading branch information
augustus281 authored Aug 23, 2024
2 parents 8a33247 + 41b2a14 commit 1e7e136
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 0 deletions.
52 changes: 52 additions & 0 deletions sliding_window_counter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package ratelimiter

import (
"sync"
"time"
)

type SlidingWindowCounter struct {
windowSize time.Duration
maxRequests int
currentWindow int64
requestCount int
previousCount int
mu sync.Mutex
}

func NewSlidingWindowCounter(windowSize time.Duration, maxRequests int) *SlidingWindowCounter {
return &SlidingWindowCounter{
windowSize: windowSize,
maxRequests: maxRequests,
currentWindow: time.Now().Unix() / int64(windowSize.Seconds()),
requestCount: 0,
previousCount: 0,
}
}

func (swc *SlidingWindowCounter) AllowRequest() bool {
swc.mu.Lock()
defer swc.mu.Unlock()

now := time.Now().Unix()
window := now / int64(swc.windowSize.Seconds())

// If we've moved to a new window, update the counts
if window != swc.currentWindow {
swc.previousCount = swc.requestCount
swc.requestCount = 0
swc.currentWindow = window
}

// Calculate the weighted request count
windowElapsed := float64(now%int64(swc.windowSize.Seconds())) / float64(swc.windowSize.Seconds())
threshold := float64(swc.previousCount)*(1-windowElapsed) + float64(swc.requestCount)

// Check if we're within the limit
if threshold < float64(swc.maxRequests) {
swc.requestCount++
return true
}

return false
}
62 changes: 62 additions & 0 deletions sliding_window_counter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package ratelimiter

import (
"testing"
"time"
)

func newSlidingWindowCounter(windowSize time.Duration, maxRequests int) *SlidingWindowCounter {
return NewSlidingWindowCounter(windowSize, maxRequests)
}

func TestSlidingWindowCounter_AllowRequest(t *testing.T) {
tests := []struct {
windowSize time.Duration
maxRequests int
requests int
expectAllowed bool
}{
{time.Second * 10, 5, 5, true}, // within limit
{time.Second * 10, 5, 6, false}, // exceeding limit
{time.Second * 10, 5, 10, false}, // far exceeding limit
{time.Second * 5, 2, 2, true}, // within smaller window
{time.Second * 5, 2, 3, false}, // exceeding limit in smaller window
}

for _, tt := range tests {
t.Run("", func(t *testing.T) {
swc := newSlidingWindowCounter(tt.windowSize, tt.maxRequests)

for i := 0; i < tt.requests; i++ {
allowed := swc.AllowRequest()
if i < tt.maxRequests && !allowed {
t.Errorf("Request %d was not allowed, but it should be", i)
}
if i >= tt.maxRequests && allowed {
t.Errorf("Request %d was allowed, but it should not be", i)
}
}
})
}
}

func TestSlidingWindowCounter_WindowExpiration(t *testing.T) {
windowSize := time.Second * 2
maxRequests := 2

swc := newSlidingWindowCounter(windowSize, maxRequests)

if !swc.AllowRequest() {
t.Errorf("First request should be allowed")
}

if !swc.AllowRequest() {
t.Errorf("Second request should be allowed")
}

time.Sleep(windowSize + time.Second)

if swc.AllowRequest() {
t.Errorf("Request after window expiration should not be allowed")
}
}

0 comments on commit 1e7e136

Please sign in to comment.