Skip to content

Commit

Permalink
Add Peek() method to inspect limiter Context without modifying it
Browse files Browse the repository at this point in the history
  • Loading branch information
dougnukem committed Feb 10, 2016
1 parent ec1e1c3 commit f8bea7e
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 24 deletions.
5 changes: 5 additions & 0 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,8 @@ func NewLimiter(store Store, rate Rate) *Limiter {
func (l *Limiter) Get(key string) (Context, error) {
return l.Store.Get(key, l.Rate)
}

// Peek returns the limit for identifier without impacting accounting
func (l *Limiter) Peek(key string) (Context, error) {
return l.Store.Peek(key, l.Rate)
}
55 changes: 33 additions & 22 deletions limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@ package limiter

import (
"math"
"math/rand"
"testing"
"time"

"github.com/garyburd/redigo/redis"
"github.com/stretchr/testify/assert"
)

func init() {
rand.Seed(time.Now().UnixNano())
}

var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

func RandStringRunes(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letterRunes[rand.Intn(len(letterRunes))]
}
return string(b)
}

// TestLimiterMemory tests Limiter with memory store.
func TestLimiterMemory(t *testing.T) {
rate, err := NewRateFromFormatted("3-M")
Expand All @@ -19,50 +34,45 @@ func TestLimiterMemory(t *testing.T) {
CleanUpInterval: 30 * time.Second,
})

limiter := NewLimiter(store, rate)

i := 1
for i <= 5 {
ctx, err := limiter.Get("boo")
assert.Nil(t, err)

if i <= 3 {
assert.Equal(t, int64(3), ctx.Limit)
assert.Equal(t, int64(3-i), ctx.Remaining)
assert.True(t, math.Ceil(time.Since(time.Unix(ctx.Reset, 0)).Seconds()) <= 60)
} else {
assert.Equal(t, int64(3), ctx.Limit)
assert.True(t, ctx.Remaining == 0)
assert.True(t, math.Ceil(time.Since(time.Unix(ctx.Reset, 0)).Seconds()) <= 60)
}

i++
}
testLimiter(t, store, rate)
}

// TestLimiterRedis tests Limiter with Redis store.
func TestLimiterRedis(t *testing.T) {
rate, err := NewRateFromFormatted("3-M")
assert.Nil(t, err)

randPrefix := RandStringRunes(10)
store, err := NewRedisStoreWithOptions(
newRedisPool(),
StoreOptions{Prefix: "limitertests:redis", MaxRetry: 3})
StoreOptions{Prefix: "limitertests:redis_" + randPrefix, MaxRetry: 3})

assert.Nil(t, err)

testLimiter(t, store, rate)
}

func testLimiter(t *testing.T, store Store, rate Rate) {
limiter := NewLimiter(store, rate)

i := 1
for i <= 5 {
if i <= 3 {
ctx, err := limiter.Peek("boo")
assert.NoError(t, err)
assert.Equal(t, int64(3-(i-1)), ctx.Remaining)
}

ctx, err := limiter.Get("boo")
assert.Nil(t, err)
assert.NoError(t, err)

if i <= 3 {
assert.Equal(t, int64(3), ctx.Limit)
assert.Equal(t, int64(3-i), ctx.Remaining)
assert.True(t, math.Ceil(time.Since(time.Unix(ctx.Reset, 0)).Seconds()) <= 60)

ctx, err := limiter.Peek("boo")
assert.NoError(t, err)
assert.Equal(t, int64(3-i), ctx.Remaining)
} else {
assert.Equal(t, int64(3), ctx.Limit)
assert.True(t, ctx.Remaining == 0)
Expand All @@ -71,6 +81,7 @@ func TestLimiterRedis(t *testing.T) {

i++
}

}

// -----------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "time"
// Store is the common interface for limiter stores.
type Store interface {
Get(key string, rate Rate) (Context, error)
Peek(key string, rate Rate) (Context, error)
}

// StoreOptions are options for store.
Expand Down
34 changes: 32 additions & 2 deletions store_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,47 @@ func (s *MemoryStore) Get(key string, rate Rate) (Context, error) {
return ctx, err
}

return s.getContextFromState(now, rate, item.Expiration, count), nil
}

// Peek implement Store.Peek() method.
func (s *MemoryStore) Peek(key string, rate Rate) (Context, error) {
ctx := Context{}
key = fmt.Sprintf("%s:%s", s.Prefix, key)
item, found := s.Cache.Items()[key]
ms := int64(time.Millisecond)
now := time.Now()

if !found || item.Expired() {
// new or expired should show what the values "would" be but not set cache state
return Context{
Limit: rate.Limit,
Remaining: rate.Limit,
Reset: (now.UnixNano()/ms + int64(rate.Period)/ms) / 1000,
Reached: false,
}, nil
}

count, ok := item.Object.(int64)
if !ok {
return ctx, fmt.Errorf("key=%s count not int64", key)
}

return s.getContextFromState(now, rate, item.Expiration, count), nil
}

func (s *MemoryStore) getContextFromState(now time.Time, rate Rate, expiration, count int64) Context {
remaining := int64(0)
if count < rate.Limit {
remaining = rate.Limit - count
}

expire := time.Unix(0, item.Expiration)
expire := time.Unix(0, expiration)

return Context{
Limit: rate.Limit,
Remaining: remaining,
Reset: expire.Add(time.Duration(expire.Sub(now).Seconds()) * time.Second).Unix(),
Reached: count > rate.Limit,
}, nil
}
}
59 changes: 59 additions & 0 deletions store_redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ func (s RedisStore) updateRate(c redis.Conn, key string, rate Rate) ([]int, erro
return redis.Ints(c.Do("EXEC"))
}

func (s RedisStore) getRate(c redis.Conn, key string, rate Rate) ([]int, error) {
c.Send("MULTI")
c.Send("GET", key)
c.Send("TTL", key)
return redis.Ints(c.Do("EXEC"))
}

// Get returns the limit for the identifier.
func (s RedisStore) Get(key string, rate Rate) (Context, error) {
var (
Expand Down Expand Up @@ -138,3 +145,55 @@ func (s RedisStore) Get(key string, rate Rate) (Context, error) {
Reached: count > rate.Limit,
}, nil
}

// Peek returns the limit for the identifier.
func (s RedisStore) Peek(key string, rate Rate) (Context, error) {
var (
err error
values []int
)

ctx := Context{}
key = fmt.Sprintf("%s:%s", s.Prefix, key)

c := s.Pool.Get()
defer c.Close()
if err := c.Err(); err != nil {
return Context{}, err
}

c.Do("WATCH", key)
defer c.Do("UNWATCH", key)

values, err = s.do(s.getRate, c, key, rate)
if err != nil {
return ctx, err
}

created := (values[0] == 0)
ms := int64(time.Millisecond)

if created {
return Context{
Limit: rate.Limit,
Remaining: rate.Limit,
Reset: (time.Now().UnixNano()/ms + int64(rate.Period)/ms) / 1000,
Reached: false,
}, nil
}

count := int64(values[0])
ttl := int64(values[1])
remaining := int64(0)

if count < rate.Limit {
remaining = rate.Limit - count
}

return Context{
Limit: rate.Limit,
Remaining: remaining,
Reset: time.Now().Add(time.Duration(ttl) * time.Second).Unix(),
Reached: count > rate.Limit,
}, nil
}

0 comments on commit f8bea7e

Please sign in to comment.