diff --git a/README.md b/README.md index e69de29..3719906 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,9 @@ +# Ratelimiter +This is Go implement algorithms about ratelimiter. + +### Installation: +The package can be installed as a Go module. + +``` +go get github.com/augustus281/ratelimiter +``` diff --git a/tokenbucket.go b/tokenbucket.go new file mode 100644 index 0000000..feea379 --- /dev/null +++ b/tokenbucket.go @@ -0,0 +1,51 @@ +package tokenbucket + +import ( + "sync" + "time" +) + +type TokenBucket struct { + sync.Mutex + tokens int // Current token count, start with a full bucket + maxTokens int // Maximum number of tokens the bucket hold + refillRate int // Rate at which tokens are added (tokens/second) + lastRefillTime time.Time // Last time we checked the token count +} + +func NewTokenBucket(maxTokens, refillRate int) *TokenBucket { + return &TokenBucket{ + tokens: maxTokens, + maxTokens: maxTokens, + refillRate: refillRate, + lastRefillTime: time.Now(), + } +} + +func (tb *TokenBucket) AddToken(tokens int) bool { + tb.Lock() + defer tb.Unlock() + + tb.refill() + if tokens < tb.tokens { + tb.tokens -= tokens + return true + } + + return false +} + +func (tb *TokenBucket) refill() { + now := time.Now() + duration := time.Since(tb.lastRefillTime) + tokenAdd := tb.tokens * int(duration.Seconds()) + tb.tokens = tb.min(tb.maxTokens, tb.tokens+tokenAdd) + tb.lastRefillTime = now +} + +func (tb *TokenBucket) min(a, b int) int { + if a <= b { + return a + } + return b +} diff --git a/tokenbucket_test.go b/tokenbucket_test.go new file mode 100644 index 0000000..265c636 --- /dev/null +++ b/tokenbucket_test.go @@ -0,0 +1,110 @@ +package tokenbucket + +import ( + "sync" + "testing" + "time" +) + +func TestTokenBucket_AddToken(t *testing.T) { + type fiels struct { + token int + maxToken int + refillRate int + lastRefillTime time.Time + Mutex sync.Mutex + } + + tests := []struct { + name string + fiels fiels + want bool + takedToken int + expectedToken int + }{ + { + "token available now", + fiels{ + token: 10, + maxToken: 10, + refillRate: 1, + lastRefillTime: time.Now(), + }, + true, + 9, + 1, + }, + { + "no token available now", + fiels{ + token: 0, + maxToken: 10, + refillRate: 1, + lastRefillTime: time.Now(), + }, + false, + 10, + 0, + }, + { + "tokens available after adjustment", + fiels{ + token: 1, + maxToken: 10, + refillRate: 1, + lastRefillTime: time.Now().Add(-1 * time.Second), + }, + true, + 1, + 1, + }, + { + "tokens do not refresh above capacity", + fiels{ + token: 10, + maxToken: 10, + refillRate: 1, + lastRefillTime: time.Now().Add(-2 * time.Minute), + }, + true, + 1, + 9, + }, + { + "refreshs 4 tokens", + fiels{ + token: 4, + maxToken: 10, + refillRate: 1, + lastRefillTime: time.Now().Add(-5 * time.Second), + }, + true, + 4, + 6, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := &TokenBucket{ + tokens: tt.fiels.token, + maxTokens: tt.fiels.maxToken, + refillRate: tt.fiels.refillRate, + lastRefillTime: tt.fiels.lastRefillTime, + Mutex: tt.fiels.Mutex, + } + + if got := b.AddToken(tt.takedToken); got != tt.want { + t.Errorf("TokenBucket.Add(%v) = %v, want %v", tt.takedToken, got, tt.want) + } + + if count := b.tokens; count != tt.expectedToken { + t.Errorf("Token count incorrect. Got %v, want %v", count, tt.expectedToken) + } + + if b.tokens > b.maxTokens { + t.Errorf("Max token is %v but current count is %v", b.maxTokens, b.tokens) + } + }) + } +}