diff --git a/limiter.go b/limiter.go index da252e6..f135c09 100644 --- a/limiter.go +++ b/limiter.go @@ -34,3 +34,8 @@ func NewLimiter(store Store, rate Rate) *Limiter { func (l *Limiter) Get(key string) (Context, error) { return l.Store.Get(key, l.Rate) } + +// Peek returns the limit for identifier without impacting accounting +func (l *Limiter) Peek(key string) (Context, error) { + return l.Store.Peek(key, l.Rate) +} diff --git a/limiter_test.go b/limiter_test.go index 6780105..83cfb2e 100644 --- a/limiter_test.go +++ b/limiter_test.go @@ -2,6 +2,7 @@ package limiter import ( "math" + "math/rand" "testing" "time" @@ -9,6 +10,20 @@ import ( "github.com/stretchr/testify/assert" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + +var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func RandStringRunes(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} + // TestLimiterMemory tests Limiter with memory store. func TestLimiterMemory(t *testing.T) { rate, err := NewRateFromFormatted("3-M") @@ -19,25 +34,7 @@ func TestLimiterMemory(t *testing.T) { CleanUpInterval: 30 * time.Second, }) - limiter := NewLimiter(store, rate) - - i := 1 - for i <= 5 { - ctx, err := limiter.Get("boo") - assert.Nil(t, err) - - if i <= 3 { - assert.Equal(t, int64(3), ctx.Limit) - assert.Equal(t, int64(3-i), ctx.Remaining) - assert.True(t, math.Ceil(time.Since(time.Unix(ctx.Reset, 0)).Seconds()) <= 60) - } else { - assert.Equal(t, int64(3), ctx.Limit) - assert.True(t, ctx.Remaining == 0) - assert.True(t, math.Ceil(time.Since(time.Unix(ctx.Reset, 0)).Seconds()) <= 60) - } - - i++ - } + testLimiter(t, store, rate) } // TestLimiterRedis tests Limiter with Redis store. @@ -45,24 +42,37 @@ func TestLimiterRedis(t *testing.T) { rate, err := NewRateFromFormatted("3-M") assert.Nil(t, err) + randPrefix := RandStringRunes(10) store, err := NewRedisStoreWithOptions( newRedisPool(), - StoreOptions{Prefix: "limitertests:redis", MaxRetry: 3}) + StoreOptions{Prefix: "limitertests:redis_" + randPrefix, MaxRetry: 3}) assert.Nil(t, err) + testLimiter(t, store, rate) +} + +func testLimiter(t *testing.T, store Store, rate Rate) { limiter := NewLimiter(store, rate) i := 1 for i <= 5 { + if i <= 3 { + ctx, err := limiter.Peek("boo") + assert.NoError(t, err) + assert.Equal(t, int64(3-(i-1)), ctx.Remaining) + } + ctx, err := limiter.Get("boo") - assert.Nil(t, err) + assert.NoError(t, err) if i <= 3 { assert.Equal(t, int64(3), ctx.Limit) assert.Equal(t, int64(3-i), ctx.Remaining) assert.True(t, math.Ceil(time.Since(time.Unix(ctx.Reset, 0)).Seconds()) <= 60) - + ctx, err := limiter.Peek("boo") + assert.NoError(t, err) + assert.Equal(t, int64(3-i), ctx.Remaining) } else { assert.Equal(t, int64(3), ctx.Limit) assert.True(t, ctx.Remaining == 0) @@ -71,6 +81,7 @@ func TestLimiterRedis(t *testing.T) { i++ } + } // ----------------------------------------------------------------------------- diff --git a/store.go b/store.go index 0075818..9ad6a74 100644 --- a/store.go +++ b/store.go @@ -5,6 +5,7 @@ import "time" // Store is the common interface for limiter stores. type Store interface { Get(key string, rate Rate) (Context, error) + Peek(key string, rate Rate) (Context, error) } // StoreOptions are options for store. diff --git a/store_memory.go b/store_memory.go index 47a30f3..15ea180 100644 --- a/store_memory.go +++ b/store_memory.go @@ -53,17 +53,47 @@ func (s *MemoryStore) Get(key string, rate Rate) (Context, error) { return ctx, err } + return s.getContextFromState(now, rate, item.Expiration, count), nil +} + +// Peek implement Store.Peek() method. +func (s *MemoryStore) Peek(key string, rate Rate) (Context, error) { + ctx := Context{} + key = fmt.Sprintf("%s:%s", s.Prefix, key) + item, found := s.Cache.Items()[key] + ms := int64(time.Millisecond) + now := time.Now() + + if !found || item.Expired() { + // new or expired should show what the values "would" be but not set cache state + return Context{ + Limit: rate.Limit, + Remaining: rate.Limit, + Reset: (now.UnixNano()/ms + int64(rate.Period)/ms) / 1000, + Reached: false, + }, nil + } + + count, ok := item.Object.(int64) + if !ok { + return ctx, fmt.Errorf("key=%s count not int64", key) + } + + return s.getContextFromState(now, rate, item.Expiration, count), nil +} + +func (s *MemoryStore) getContextFromState(now time.Time, rate Rate, expiration, count int64) Context { remaining := int64(0) if count < rate.Limit { remaining = rate.Limit - count } - expire := time.Unix(0, item.Expiration) + expire := time.Unix(0, expiration) return Context{ Limit: rate.Limit, Remaining: remaining, Reset: expire.Add(time.Duration(expire.Sub(now).Seconds()) * time.Second).Unix(), Reached: count > rate.Limit, - }, nil + } } diff --git a/store_redis.go b/store_redis.go index 317db95..038dc37 100644 --- a/store_redis.go +++ b/store_redis.go @@ -81,6 +81,13 @@ func (s RedisStore) updateRate(c redis.Conn, key string, rate Rate) ([]int, erro return redis.Ints(c.Do("EXEC")) } +func (s RedisStore) getRate(c redis.Conn, key string, rate Rate) ([]int, error) { + c.Send("MULTI") + c.Send("GET", key) + c.Send("TTL", key) + return redis.Ints(c.Do("EXEC")) +} + // Get returns the limit for the identifier. func (s RedisStore) Get(key string, rate Rate) (Context, error) { var ( @@ -138,3 +145,55 @@ func (s RedisStore) Get(key string, rate Rate) (Context, error) { Reached: count > rate.Limit, }, nil } + +// Peek returns the limit for the identifier. +func (s RedisStore) Peek(key string, rate Rate) (Context, error) { + var ( + err error + values []int + ) + + ctx := Context{} + key = fmt.Sprintf("%s:%s", s.Prefix, key) + + c := s.Pool.Get() + defer c.Close() + if err := c.Err(); err != nil { + return Context{}, err + } + + c.Do("WATCH", key) + defer c.Do("UNWATCH", key) + + values, err = s.do(s.getRate, c, key, rate) + if err != nil { + return ctx, err + } + + created := (values[0] == 0) + ms := int64(time.Millisecond) + + if created { + return Context{ + Limit: rate.Limit, + Remaining: rate.Limit, + Reset: (time.Now().UnixNano()/ms + int64(rate.Period)/ms) / 1000, + Reached: false, + }, nil + } + + count := int64(values[0]) + ttl := int64(values[1]) + remaining := int64(0) + + if count < rate.Limit { + remaining = rate.Limit - count + } + + return Context{ + Limit: rate.Limit, + Remaining: remaining, + Reset: time.Now().Add(time.Duration(ttl) * time.Second).Unix(), + Reached: count > rate.Limit, + }, nil +}