diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..61ead86 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/vendor diff --git a/.travis.yml b/.travis.yml index 819a109..b29d46c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,15 +1,25 @@ language: go before_install: - - go get github.com/stretchr/testify - - go get github.com/garyburd/redigo/redis - - go get github.com/ant0ine/go-json-rest/rest - - go get github.com/patrickmn/go-cache + - go get -u github.com/golang/dep/cmd/dep + - dep ensure + - go get -u github.com/alecthomas/gometalinter + - gometalinter --install go: - - 1.4 - - 1.5 - - 1.6 - - 1.7 - 1.8 -script: make test + - 1.8.1 + - 1.8.2 + - 1.8.3 + - 1.9 + - 1.9.1 + - tip +script: + - make test + - make lint services: - redis-server +env: + - REDIS_DISABLE_BOOTSTRAP=true +matrix: + fast_finish: true + allow_failures: + - go: tip diff --git a/AUTHORS b/AUTHORS index 8cfdc14..c4be890 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,4 +1,5 @@ Primary contributors: - Gilles Fabio - Florent Messa + Gilles FABIO + Florent MESSA + Thomas LE ROUX diff --git a/Gopkg.lock b/Gopkg.lock new file mode 100644 index 0000000..183eb36 --- /dev/null +++ b/Gopkg.lock @@ -0,0 +1,87 @@ +# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. + + +[[projects]] + name = "github.com/davecgh/go-spew" + packages = ["spew"] + revision = "346938d642f2ec3594ed81d874461961cd0faa76" + version = "v1.1.0" + +[[projects]] + branch = "master" + name = "github.com/gin-contrib/sse" + packages = ["."] + revision = "22d885f9ecc78bf4ee5d72b937e4bbcdc58e8cae" + +[[projects]] + name = "github.com/gin-gonic/gin" + packages = [".","binding","render"] + revision = "d459835d2b077e44f7c9b453505ee29881d5d12d" + version = "v1.2" + +[[projects]] + name = "github.com/go-redis/redis" + packages = [".","internal","internal/consistenthash","internal/hashtag","internal/pool","internal/proto"] + revision = "19c1c2272e00c1aaa903cf574c746cd449f9cd3c" + version = "v6.5.7" + +[[projects]] + branch = "master" + name = "github.com/golang/protobuf" + packages = ["proto"] + revision = "ab9f9a6dab164b7d1246e0e688b0ab7b94d8553e" + +[[projects]] + name = "github.com/mattn/go-isatty" + packages = ["."] + revision = "fc9e8d8ef48496124e79ae0df75490096eccf6fe" + version = "v0.0.2" + +[[projects]] + name = "github.com/pkg/errors" + packages = ["."] + revision = "645ef00459ed84a119197bfb8d8205042c6df63d" + version = "v0.8.0" + +[[projects]] + name = "github.com/pmezard/go-difflib" + packages = ["difflib"] + revision = "792786c7400a136282c1664665ae0a8db921c6c2" + version = "v1.0.0" + +[[projects]] + name = "github.com/stretchr/testify" + packages = ["assert","require"] + revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" + version = "v1.1.4" + +[[projects]] + branch = "master" + name = "github.com/ugorji/go" + packages = ["codec"] + revision = "5efa3251c7f7d05e5d9704a69a984ec9f1386a40" + +[[projects]] + branch = "master" + name = "golang.org/x/sys" + packages = ["unix"] + revision = "43e60d72a8e2bd92ee98319ba9a384a0e9837c08" + +[[projects]] + name = "gopkg.in/go-playground/validator.v8" + packages = ["."] + revision = "5f1438d3fca68893a817e4a66806cea46a9e4ebf" + version = "v8.18.2" + +[[projects]] + branch = "v2" + name = "gopkg.in/yaml.v2" + packages = ["."] + revision = "eb3733d160e74a9c7e442f435eb3bea458e1d19f" + +[solve-meta] + analyzer-name = "dep" + analyzer-version = 1 + inputs-digest = "0a511933063b3e715bcad7b7d7da2e0590a20d3b8a23786189d955ad79a71a97" + solver-name = "gps-cdcl" + solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml new file mode 100644 index 0000000..fdedf37 --- /dev/null +++ b/Gopkg.toml @@ -0,0 +1,13 @@ +# Gopkg.toml for github.com/ulule/limiter + +[[constraint]] + name = "github.com/pkg/errors" + version = "0.8.0" + +[[constraint]] + name = "github.com/go-redis/redis" + version = "6.5.6" + +[[constraint]] + name = "github.com/gin-gonic/gin" + version = "v1.2" diff --git a/Makefile b/Makefile index aca32fb..374cc20 100644 --- a/Makefile +++ b/Makefile @@ -4,4 +4,7 @@ cleandb: @(redis-cli KEYS "limitertests:*" | xargs redis-cli DEL) test: cleandb - @(go test -v -run ^Test) + @(scripts/test) + +lint: + @(scripts/lint) diff --git a/README.md b/README.md index 058ff8b..942c524 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ * Simple API * "Store" approach for backend * Redis support (but not tied too) -* Middlewares: HTTP and [go-json-rest][2] +* Middlewares: HTTP and [Gin][4] ## Installation @@ -22,20 +22,22 @@ $ go get github.com/ulule/limiter In five steps: -* Create a `limiter.Rate` instance (the number of requests per period) -* Create a `limiter.Store` instance (see [store_redis](https://github.com/ulule/limiter/blob/master/store_redis.go) for Redis or [store_memory](https://github.com/ulule/limiter/blob/master/store_memory.go) for in-memory) +* Create a `limiter.Rate` instance _(the number of requests per period)_ +* Create a `limiter.Store` instance _(see [Redis](https://github.com/ulule/limiter/blob/master/drivers/store/redis/store.go) or [In-Memory](https://github.com/ulule/limiter/blob/master/drivers/store/memory/store.go))_ * Create a `limiter.Limiter` instance that takes store and rate instances as arguments * Create a middleware instance using the middleware of your choice * Give the limiter instance to your middleware initializer -Example: +**Example:** ```go // Create a rate with the given limit (number of requests) for the given // period (a time.Duration of your choice). +import "github.com/ulule/limiter" + rate := limiter.Rate{ Period: 1 * time.Hour, - Limit: int64(1000), + Limit: 1000, } // You can also use the simplified format "-"", with the given @@ -60,34 +62,39 @@ if err != nil { // compliant to limiter.Store interface will do the job. The defaults are // "limiter" as Redis key prefix and a maximum of 3 retries for the key under // race condition. -store, err := limiter.NewRedisStore(pool) +import "github.com/ulule/limiter/drivers/store/redis" + +store, err := redis.NewStore(client) if err != nil { panic(err) } // Alternatively, you can pass options to the store with the "WithOptions" // function. For example, for Redis store: -store, err := limiter.NewRedisStoreWithOptions(pool, limiter.StoreOptions{ +import "github.com/ulule/limiter/drivers/store/redis" + +store, err := redis.NewStoreWithOptions(pool, limiter.StoreOptions{ Prefix: "your_own_prefix", MaxRetry: 4, }) - if err != nil { panic(err) } -// Or use a in-memory store with a goroutine which clears expired keys every 30 seconds -store := limiter.NewMemoryStore("prefix_for_keys", 30*time.Second) +// Or use a in-memory store with a goroutine which clears expired keys. +import "github.com/ulule/limiter/drivers/store/memory" + +store := memory.NewStore() // Then, create the limiter instance which takes the store and the rate as arguments. // Now, you can give this instance to any supported middleware. -limiterInstance := limiter.NewLimiter(store, rate) +instance := limiter.New(store, rate) ``` See middleware examples: -* [HTTP](https://github.com/ulule/limiter/tree/master/examples/http) -* [go-json-rest](https://github.com/ulule/limiter/tree/master/examples/gjr) +* [HTTP](https://github.com/ulule/limiter/tree/master/examples/http/main.go) +* [Gin](https://github.com/ulule/limiter/tree/master/examples/gin/main.go) ## How it works @@ -98,10 +105,10 @@ value with an expiration period. You will find two stores: -* RedisStore: rely on [TTL](http://redis.io/commands/ttl) and incrementing the rate limit on each request -* MemoryStore: rely on [go-cache](https://github.com/patrickmn/go-cache) with a goroutine to clear expired keys using a default interval +* Redis: rely on [TTL](http://redis.io/commands/ttl) and incrementing the rate limit on each request. +* In-Memory: rely on a fork of [go-cache](https://github.com/patrickmn/go-cache) with a goroutine to clear expired keys using a default interval. -When the limit is reached, a ``429`` HTTP code is sent. +When the limit is reached, a `429` HTTP status code is sent. ## Why Yet Another Package @@ -118,7 +125,7 @@ number of bytes uploaded"*. It is brillant in term of algorithm but documentation is quite unclear at the moment, we don't need *burst* feature for now, impossible to get a correct `After-Retry` (when limit exceeds, we can still make a few requests, because of the max burst) and it only supports ``http.Handler`` -middleware (we use [go-json-rest][2]). Currently, we only need to return `429` +middleware (we use [Gin][4]). Currently, we only need to return `429` and `X-Ratelimit-*` headers for `n reqs/duration`. 2. [Speedbump][3]. Good package but maybe too lightweight. No `Reset` support, @@ -131,7 +138,7 @@ provide any Redis support (only *in-memory*) and a ready-to-go middleware that s `X-Ratelimit-*` headers. `tollbooth.LimitByRequest(limiter, r)` only returns an HTTP code. -4. [ratelimit][6]. Probably the closer to our needs but, once again, too +4. [ratelimit][2]. Probably the closer to our needs but, once again, too lightweight, no middleware available and not active (last commit was in August 2014). Some parts of code (Redis) comes from this project. It should deserve much more love. @@ -142,18 +149,20 @@ create yet another one. ## Contributing -* Ping us on twitter [@oibafsellig](https://twitter.com/oibafsellig), [@thoas](https://twitter.com/thoas) +* Ping us on twitter: + * [@oibafsellig](https://twitter.com/oibafsellig) + * [@thoas](https://twitter.com/thoas) + * [@novln_](https://twitter.com/novln_) * Fork the [project](https://github.com/ulule/limiter) * Fix [bugs](https://github.com/ulule/limiter/issues) Don't hesitate ;) [1]: https://github.com/throttled/throttled -[2]: https://github.com/ant0ine/go-json-rest +[2]: https://github.com/r8k/ratelimit [3]: https://github.com/etcinit/speedbump [4]: https://github.com/gin-gonic/gin [5]: https://github.com/didip/tollbooth -[6]: https://github.com/r8k/ratelimit [godoc-url]: https://godoc.org/github.com/ulule/limiter [godoc-img]: https://godoc.org/github.com/ulule/limiter?status.svg diff --git a/drivers/middleware/gin/middleware.go b/drivers/middleware/gin/middleware.go new file mode 100644 index 0000000..568088a --- /dev/null +++ b/drivers/middleware/gin/middleware.go @@ -0,0 +1,55 @@ +package gin + +import ( + "strconv" + + "github.com/gin-gonic/gin" + + "github.com/ulule/limiter" +) + +// Middleware is the middleware for basic http.Handler. +type Middleware struct { + Limiter *limiter.Limiter + OnError ErrorHandler + OnLimitReached LimitReachedHandler +} + +// NewMiddleware return a new instance of a basic HTTP middleware. +func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc { + middleware := &Middleware{ + Limiter: limiter, + OnError: DefaultErrorHandler, + OnLimitReached: DefaultLimitReachedHandler, + } + + for _, option := range options { + option.apply(middleware) + } + + return func(ctx *gin.Context) { + middleware.Handle(ctx) + } +} + +// Handle gin request. +func (middleware *Middleware) Handle(c *gin.Context) { + context, err := middleware.Limiter.Get(c, limiter.GetIPKey(c.Request)) + if err != nil { + middleware.OnError(c, err) + c.Abort() + return + } + + c.Header("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10)) + c.Header("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10)) + c.Header("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10)) + + if context.Reached { + middleware.OnLimitReached(c) + c.Abort() + return + } + + c.Next() +} diff --git a/drivers/middleware/gin/middleware_test.go b/drivers/middleware/gin/middleware_test.go new file mode 100644 index 0000000..cb71441 --- /dev/null +++ b/drivers/middleware/gin/middleware_test.go @@ -0,0 +1,98 @@ +package gin_test + +import ( + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + + libgin "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "github.com/ulule/limiter" + "github.com/ulule/limiter/drivers/middleware/gin" + "github.com/ulule/limiter/drivers/store/memory" +) + +func TestHTTPMiddleware(t *testing.T) { + is := require.New(t) + libgin.SetMode(libgin.TestMode) + + request, err := http.NewRequest("GET", "/", nil) + is.NoError(err) + is.NotNil(request) + + store := memory.NewStore() + is.NotZero(store) + + rate, err := limiter.NewRateFromFormatted("10-M") + is.NoError(err) + is.NotZero(rate) + + middleware := gin.NewMiddleware(limiter.New(store, rate)) + is.NotZero(middleware) + + router := libgin.New() + router.Use(middleware) + router.GET("/", func(c *libgin.Context) { + c.String(http.StatusOK, "hello") + }) + + success := int64(10) + clients := int64(100) + + // + // Sequential + // + + for i := int64(1); i <= clients; i++ { + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, request) + + if i <= success { + is.Equal(resp.Code, http.StatusOK) + } else { + is.Equal(resp.Code, http.StatusTooManyRequests) + } + } + + // + // Concurrent + // + + store = memory.NewStore() + is.NotZero(store) + + middleware = gin.NewMiddleware(limiter.New(store, rate)) + is.NotZero(middleware) + + router = libgin.New() + router.Use(middleware) + router.GET("/", func(c *libgin.Context) { + c.String(http.StatusOK, "hello") + }) + + wg := &sync.WaitGroup{} + counter := int64(0) + + for i := int64(1); i <= clients; i++ { + wg.Add(1) + go func() { + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, request) + + if resp.Code == http.StatusOK { + atomic.AddInt64(&counter, 1) + } + + wg.Done() + }() + } + + wg.Wait() + is.Equal(success, atomic.LoadInt64(&counter)) + +} diff --git a/drivers/middleware/gin/options.go b/drivers/middleware/gin/options.go new file mode 100644 index 0000000..ff63dcd --- /dev/null +++ b/drivers/middleware/gin/options.go @@ -0,0 +1,48 @@ +package gin + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +// Option is used to define Middleware configuration. +type Option interface { + apply(*Middleware) +} + +type option func(*Middleware) + +func (o option) apply(middleware *Middleware) { + o(middleware) +} + +// ErrorHandler is an handler used to inform when an error has occurred. +type ErrorHandler func(c *gin.Context, err error) + +// WithErrorHandler will configure the Middleware to use the given ErrorHandler. +func WithErrorHandler(handler ErrorHandler) Option { + return option(func(middleware *Middleware) { + middleware.OnError = handler + }) +} + +// DefaultErrorHandler is the default ErrorHandler used by a new Middleware. +func DefaultErrorHandler(c *gin.Context, err error) { + panic(err) +} + +// LimitReachedHandler is an handler used to inform when the limit has exceeded. +type LimitReachedHandler func(c *gin.Context) + +// WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler. +func WithLimitReachedHandler(handler LimitReachedHandler) Option { + return option(func(middleware *Middleware) { + middleware.OnLimitReached = handler + }) +} + +// DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware. +func DefaultLimitReachedHandler(c *gin.Context) { + c.String(http.StatusTooManyRequests, "Limit exceeded") +} diff --git a/drivers/middleware/stdlib/middleware.go b/drivers/middleware/stdlib/middleware.go new file mode 100644 index 0000000..b0cd79b --- /dev/null +++ b/drivers/middleware/stdlib/middleware.go @@ -0,0 +1,52 @@ +package stdlib + +import ( + "net/http" + "strconv" + + "github.com/ulule/limiter" +) + +// Middleware is the middleware for basic http.Handler. +type Middleware struct { + Limiter *limiter.Limiter + OnError ErrorHandler + OnLimitReached LimitReachedHandler +} + +// NewMiddleware return a new instance of a basic HTTP middleware. +func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware { + middleware := &Middleware{ + Limiter: limiter, + OnError: DefaultErrorHandler, + OnLimitReached: DefaultLimitReachedHandler, + } + + for _, option := range options { + option.apply(middleware) + } + + return middleware +} + +// Handler the middleware handler. +func (middleware *Middleware) Handler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + context, err := middleware.Limiter.Get(r.Context(), limiter.GetIPKey(r)) + if err != nil { + middleware.OnError(w, r, err) + return + } + + w.Header().Add("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10)) + w.Header().Add("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10)) + w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10)) + + if context.Reached { + middleware.OnLimitReached(w, r) + return + } + + h.ServeHTTP(w, r) + }) +} diff --git a/drivers/middleware/stdlib/middleware_test.go b/drivers/middleware/stdlib/middleware_test.go new file mode 100644 index 0000000..40f79e1 --- /dev/null +++ b/drivers/middleware/stdlib/middleware_test.go @@ -0,0 +1,91 @@ +package stdlib_test + +import ( + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ulule/limiter" + "github.com/ulule/limiter/drivers/middleware/stdlib" + "github.com/ulule/limiter/drivers/store/memory" +) + +func TestHTTPMiddleware(t *testing.T) { + is := require.New(t) + + request, err := http.NewRequest("GET", "/", nil) + is.NoError(err) + is.NotNil(request) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, thr := w.Write([]byte("hello")) + if thr != nil { + panic(thr) + } + }) + + store := memory.NewStore() + is.NotZero(store) + + rate, err := limiter.NewRateFromFormatted("10-M") + is.NoError(err) + is.NotZero(rate) + + middleware := stdlib.NewMiddleware(limiter.New(store, rate)).Handler(handler) + is.NotZero(middleware) + + success := int64(10) + clients := int64(100) + + // + // Sequential + // + + for i := int64(1); i <= clients; i++ { + + resp := httptest.NewRecorder() + middleware.ServeHTTP(resp, request) + + if i <= success { + is.Equal(resp.Code, http.StatusOK) + } else { + is.Equal(resp.Code, http.StatusTooManyRequests) + } + } + + // + // Concurrent + // + + store = memory.NewStore() + is.NotZero(store) + + middleware = stdlib.NewMiddleware(limiter.New(store, rate)).Handler(handler) + is.NotZero(middleware) + + wg := &sync.WaitGroup{} + counter := int64(0) + + for i := int64(1); i <= clients; i++ { + wg.Add(1) + go func() { + + resp := httptest.NewRecorder() + middleware.ServeHTTP(resp, request) + + if resp.Code == http.StatusOK { + atomic.AddInt64(&counter, 1) + } + + wg.Done() + }() + } + + wg.Wait() + is.Equal(success, atomic.LoadInt64(&counter)) + +} diff --git a/drivers/middleware/stdlib/options.go b/drivers/middleware/stdlib/options.go new file mode 100644 index 0000000..fd99eaa --- /dev/null +++ b/drivers/middleware/stdlib/options.go @@ -0,0 +1,46 @@ +package stdlib + +import ( + "net/http" +) + +// Option is used to define Middleware configuration. +type Option interface { + apply(*Middleware) +} + +type option func(*Middleware) + +func (o option) apply(middleware *Middleware) { + o(middleware) +} + +// ErrorHandler is an handler used to inform when an error has occurred. +type ErrorHandler func(w http.ResponseWriter, r *http.Request, err error) + +// WithErrorHandler will configure the Middleware to use the given ErrorHandler. +func WithErrorHandler(handler ErrorHandler) Option { + return option(func(middleware *Middleware) { + middleware.OnError = handler + }) +} + +// DefaultErrorHandler is the default ErrorHandler used by a new Middleware. +func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) { + panic(err) +} + +// LimitReachedHandler is an handler used to inform when the limit has exceeded. +type LimitReachedHandler func(w http.ResponseWriter, r *http.Request) + +// WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler. +func WithLimitReachedHandler(handler LimitReachedHandler) Option { + return option(func(middleware *Middleware) { + middleware.OnLimitReached = handler + }) +} + +// DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware. +func DefaultLimitReachedHandler(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Limit exceeded", http.StatusTooManyRequests) +} diff --git a/drivers/store/common/context.go b/drivers/store/common/context.go new file mode 100644 index 0000000..4c861f7 --- /dev/null +++ b/drivers/store/common/context.go @@ -0,0 +1,28 @@ +package common + +import ( + "time" + + "github.com/ulule/limiter" +) + +// GetContextFromState generate a new limiter.Context from given state. +func GetContextFromState(now time.Time, rate limiter.Rate, expiration time.Time, count int64) limiter.Context { + limit := rate.Limit + remaining := int64(0) + reached := true + + if count <= limit { + remaining = limit - count + reached = false + } + + reset := expiration.Unix() + + return limiter.Context{ + Limit: limit, + Remaining: remaining, + Reset: reset, + Reached: reached, + } +} diff --git a/drivers/store/common/tests.go b/drivers/store/common/tests.go new file mode 100644 index 0000000..4619d63 --- /dev/null +++ b/drivers/store/common/tests.go @@ -0,0 +1,86 @@ +package common + +import ( + "context" + "math" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ulule/limiter" +) + +// TestStoreSequentialAccess verify that store works as expected with a sequential access. +func TestStoreSequentialAccess(t *testing.T, store limiter.Store) { + is := require.New(t) + ctx := context.Background() + + limiter := limiter.New(store, limiter.Rate{ + Limit: 3, + Period: time.Minute, + }) + + for i := 1; i <= 6; i++ { + + if i <= 3 { + + lctx, err := limiter.Peek(ctx, "foo") + is.NoError(err) + is.NotZero(lctx) + is.Equal(int64(3-(i-1)), lctx.Remaining) + + } + + lctx, err := limiter.Get(ctx, "foo") + is.NoError(err) + is.NotZero(lctx) + + if i <= 3 { + + is.Equal(int64(3), lctx.Limit) + is.Equal(int64(3-i), lctx.Remaining) + is.True(math.Ceil(time.Since(time.Unix(lctx.Reset, 0)).Seconds()) <= 60) + + lctx, err = limiter.Peek(ctx, "foo") + is.NoError(err) + is.Equal(int64(3-i), lctx.Remaining) + + } else { + + is.Equal(int64(3), lctx.Limit) + is.True(lctx.Remaining == 0) + is.True(math.Ceil(time.Since(time.Unix(lctx.Reset, 0)).Seconds()) <= 60) + + } + } +} + +// TestStoreConcurrentAccess verify that store works as expected with a concurrent access. +func TestStoreConcurrentAccess(t *testing.T, store limiter.Store) { + is := require.New(t) + ctx := context.Background() + + limiter := limiter.New(store, limiter.Rate{ + Limit: 100000, + Period: 10 * time.Minute, + }) + + goroutines := 100 + ops := 200 + + wg := &sync.WaitGroup{} + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func(i int) { + for j := 0; j < ops; j++ { + lctx, err := limiter.Get(ctx, "foo") + is.NoError(err) + is.NotZero(lctx) + } + wg.Done() + }(i) + } + wg.Wait() +} diff --git a/drivers/store/memory/cache.go b/drivers/store/memory/cache.go new file mode 100644 index 0000000..89a8e18 --- /dev/null +++ b/drivers/store/memory/cache.go @@ -0,0 +1,149 @@ +package memory + +import ( + "runtime" + "sync" + "time" +) + +// Forked from https://github.com/patrickmn/go-cache + +// CacheWrapper is used to ensure that the underlying cleaner goroutine used to clean expired keys will not prevent +// Cache from being garbage collected. +type CacheWrapper struct { + *Cache +} + +// A cleaner will periodically delete expired keys from cache. +type cleaner struct { + interval time.Duration + stop chan bool +} + +// Run will periodically delete expired keys from given cache until GC notify that it should stop. +func (cleaner *cleaner) Run(cache *Cache) { + ticker := time.NewTicker(cleaner.interval) + for { + select { + case <-ticker.C: + cache.Clean() + case <-cleaner.stop: + ticker.Stop() + return + } + } +} + +// stopCleaner is a callback from GC used to stop cleaner goroutine. +func stopCleaner(wrapper *CacheWrapper) { + wrapper.cleaner.stop <- true +} + +// startCleaner will start a cleaner goroutine for given cache. +func startCleaner(cache *Cache, interval time.Duration) { + cleaner := &cleaner{ + interval: interval, + stop: make(chan bool), + } + + cache.cleaner = cleaner + go cleaner.Run(cache) +} + +// Counter is a simple counter with an optional expiration. +type Counter struct { + Value int64 + Expiration int64 +} + +// Expired returns true if the counter has expired. +func (counter Counter) Expired() bool { + if counter.Expiration == 0 { + return false + } + return time.Now().UnixNano() > counter.Expiration +} + +// Cache contains a collection of counters. +type Cache struct { + mutex sync.RWMutex + counters map[string]Counter + cleaner *cleaner +} + +// NewCache returns a new cache. +func NewCache(cleanInterval time.Duration) *CacheWrapper { + + cache := &Cache{ + counters: map[string]Counter{}, + } + + wrapper := &CacheWrapper{cache} + + if cleanInterval > 0 { + startCleaner(cache, cleanInterval) + runtime.SetFinalizer(wrapper, stopCleaner) + } + + return wrapper +} + +// Increment increments given value on key. +// If key is undefined or expired, it will create it. +func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) { + cache.mutex.Lock() + + counter, ok := cache.counters[key] + if !ok || counter.Expired() { + expiration := time.Now().Add(duration).UnixNano() + counter = Counter{ + Value: value, + Expiration: expiration, + } + + cache.counters[key] = counter + cache.mutex.Unlock() + + return value, time.Unix(0, expiration) + } + + value = counter.Value + value + counter.Value = value + expiration := counter.Expiration + + cache.counters[key] = counter + cache.mutex.Unlock() + + return value, time.Unix(0, expiration) +} + +// Get returns key's value and expiration. +func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) { + cache.mutex.RLock() + + counter, ok := cache.counters[key] + if !ok || counter.Expired() { + expiration := time.Now().Add(duration).UnixNano() + cache.mutex.RUnlock() + return 0, time.Unix(0, expiration) + } + + value := counter.Value + expiration := counter.Expiration + cache.mutex.RUnlock() + + return value, time.Unix(0, expiration) +} + +// Clean will deleted any expired keys. +func (cache *Cache) Clean() { + now := time.Now().UnixNano() + + cache.mutex.Lock() + for key, counter := range cache.counters { + if now > counter.Expiration { + delete(cache.counters, key) + } + } + cache.mutex.Unlock() +} diff --git a/drivers/store/memory/cache_test.go b/drivers/store/memory/cache_test.go new file mode 100644 index 0000000..8752dff --- /dev/null +++ b/drivers/store/memory/cache_test.go @@ -0,0 +1,96 @@ +package memory_test + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ulule/limiter/drivers/store/memory" +) + +func TestCacheIncrementSequential(t *testing.T) { + is := require.New(t) + + key := "foobar" + cache := memory.NewCache(10 * time.Nanosecond) + duration := 50 * time.Millisecond + deleted := time.Now().Add(duration).UnixNano() + epsilon := 0.001 + + x, expire := cache.Increment(key, 1, duration) + is.Equal(int64(1), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + + x, expire = cache.Increment(key, 2, duration) + is.Equal(int64(3), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + + time.Sleep(duration) + + deleted = time.Now().Add(duration).UnixNano() + x, expire = cache.Increment(key, 1, duration) + is.Equal(int64(1), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) +} + +func TestCacheIncrementConcurrent(t *testing.T) { + is := require.New(t) + + goroutines := 500 + ops := 500 + + expected := int64(0) + for i := 0; i < goroutines; i++ { + if (i % 3) == 0 { + for j := 0; j < ops; j++ { + expected += int64(i + j) + } + } + } + + key := "foobar" + cache := memory.NewCache(10 * time.Nanosecond) + + wg := &sync.WaitGroup{} + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(i int) { + if (i % 3) == 0 { + time.Sleep(2 * time.Second) + for j := 0; j < ops; j++ { + cache.Increment(key, int64(i+j), (1 * time.Second)) + } + } else { + time.Sleep(50 * time.Millisecond) + stopAt := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(stopAt) { + cache.Increment(key, int64(i), (75 * time.Millisecond)) + } + } + wg.Done() + }(i) + } + wg.Wait() + + value, expire := cache.Get(key, (100 * time.Millisecond)) + is.Equal(expected, value) + is.True(time.Now().Before(expire)) +} + +func TestCacheGet(t *testing.T) { + is := require.New(t) + + key := "foobar" + cache := memory.NewCache(10 * time.Nanosecond) + duration := 50 * time.Millisecond + deleted := time.Now().Add(duration).UnixNano() + epsilon := 0.001 + + x, expire := cache.Get(key, duration) + is.Equal(int64(0), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + +} diff --git a/drivers/store/memory/store.go b/drivers/store/memory/store.go new file mode 100644 index 0000000..1b5c0ad --- /dev/null +++ b/drivers/store/memory/store.go @@ -0,0 +1,56 @@ +package memory + +import ( + "context" + "fmt" + "time" + + "github.com/ulule/limiter" + "github.com/ulule/limiter/drivers/store/common" +) + +// Store is the in-memory store. +type Store struct { + // Prefix used for the key. + Prefix string + // cache used to store values in-memory. + cache *CacheWrapper +} + +// NewStore creates a new instance of memory store with defaults. +func NewStore() limiter.Store { + return NewStoreWithOptions(limiter.StoreOptions{ + Prefix: limiter.DefaultPrefix, + CleanUpInterval: limiter.DefaultCleanUpInterval, + }) +} + +// NewStoreWithOptions creates a new instance of memory store with options. +func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store { + return &Store{ + Prefix: options.Prefix, + cache: NewCache(options.CleanUpInterval), + } +} + +// Get returns the limit for given identifier. +func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + key = fmt.Sprintf("%s:%s", store.Prefix, key) + now := time.Now() + + count, expiration := store.cache.Increment(key, 1, rate.Period) + + lctx := common.GetContextFromState(now, rate, expiration, count) + return lctx, nil +} + +// Peek returns the limit for given identifier, without modification on current values. +func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + key = fmt.Sprintf("%s:%s", store.Prefix, key) + now := time.Now() + + count, expiration := store.cache.Get(key, rate.Period) + + lctx := common.GetContextFromState(now, rate, expiration, count) + return lctx, nil +} diff --git a/drivers/store/memory/store_test.go b/drivers/store/memory/store_test.go new file mode 100644 index 0000000..7328af1 --- /dev/null +++ b/drivers/store/memory/store_test.go @@ -0,0 +1,24 @@ +package memory_test + +import ( + "testing" + "time" + + "github.com/ulule/limiter" + "github.com/ulule/limiter/drivers/store/common" + "github.com/ulule/limiter/drivers/store/memory" +) + +func TestMemoryStoreSequentialAccess(t *testing.T) { + common.TestStoreSequentialAccess(t, memory.NewStoreWithOptions(limiter.StoreOptions{ + Prefix: "limiter:memory:sequential", + CleanUpInterval: 30 * time.Second, + })) +} + +func TestMemoryStoreConcurrentAccess(t *testing.T) { + common.TestStoreConcurrentAccess(t, memory.NewStoreWithOptions(limiter.StoreOptions{ + Prefix: "limiter:memory:concurrent", + CleanUpInterval: 1 * time.Nanosecond, + })) +} diff --git a/drivers/store/redis/store.go b/drivers/store/redis/store.go new file mode 100644 index 0000000..d55e767 --- /dev/null +++ b/drivers/store/redis/store.go @@ -0,0 +1,260 @@ +package redis + +import ( + "context" + "fmt" + "time" + + libredis "github.com/go-redis/redis" + "github.com/pkg/errors" + + "github.com/ulule/limiter" + "github.com/ulule/limiter/drivers/store/common" +) + +// Client is an interface thats allows to use a redis cluster or a redis single client seamlessly. +type Client interface { + Ping() *libredis.StatusCmd + Get(key string) *libredis.StringCmd + Set(key string, value interface{}, expiration time.Duration) *libredis.StatusCmd + Watch(handler func(*libredis.Tx) error, keys ...string) error + Del(keys ...string) *libredis.IntCmd + SetNX(key string, value interface{}, expiration time.Duration) *libredis.BoolCmd + Eval(script string, keys []string, args ...interface{}) *libredis.Cmd +} + +// Store is the redis store. +type Store struct { + // Prefix used for the key. + Prefix string + // MaxRetry is the maximum number of retry under race conditions. + MaxRetry int + // client used to communicate with redis server. + client Client +} + +// NewStore returns an instance of redis store with defaults. +func NewStore(client Client) (limiter.Store, error) { + return NewStoreWithOptions(client, limiter.StoreOptions{ + Prefix: limiter.DefaultPrefix, + CleanUpInterval: limiter.DefaultCleanUpInterval, + }) +} + +// NewStoreWithOptions returns an instance of redis store with options. +func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) { + store := &Store{ + client: client, + Prefix: options.Prefix, + MaxRetry: options.MaxRetry, + } + + if store.MaxRetry <= 0 { + store.MaxRetry = 1 + } + + _, err := store.ping() + if err != nil { + return nil, err + } + + return store, nil +} + +// Get returns the limit for given identifier. +func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + key = fmt.Sprintf("%s:%s", store.Prefix, key) + now := time.Now() + + lctx := limiter.Context{} + onWatch := func(rtx *libredis.Tx) error { + + created, err := store.doSetValue(rtx, key, rate.Period) + if err != nil { + return err + } + + if created { + expiration := now.Add(rate.Period) + lctx = common.GetContextFromState(now, rate, expiration, 1) + return nil + } + + count, ttl, err := store.doUpdateValue(rtx, key, rate.Period) + if err != nil { + return err + } + + expiration := now.Add(rate.Period) + if ttl > 0 { + expiration = now.Add(ttl) + } + + lctx = common.GetContextFromState(now, rate, expiration, count) + return nil + } + + err := store.client.Watch(onWatch, key) + if err != nil { + err = errors.Wrapf(err, "limiter: cannot get value for %s", key) + return limiter.Context{}, err + } + + return lctx, nil +} + +// Peek returns the limit for given identifier, without modification on current values. +func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + key = fmt.Sprintf("%s:%s", store.Prefix, key) + now := time.Now() + + lctx := limiter.Context{} + onWatch := func(rtx *libredis.Tx) error { + count, ttl, err := store.doPeekValue(rtx, key) + if err != nil { + return err + } + + expiration := now.Add(rate.Period) + if ttl > 0 { + expiration = now.Add(ttl) + } + + lctx = common.GetContextFromState(now, rate, expiration, count) + return nil + } + + err := store.client.Watch(onWatch, key) + if err != nil { + err = errors.Wrapf(err, "limiter: cannot peek value for %s", key) + return limiter.Context{}, err + } + + return lctx, nil +} + +// doPeekValue will execute peekValue with a retry mecanism (optimistic locking) until store.MaxRetry is reached. +func (store *Store) doPeekValue(rtx *libredis.Tx, key string) (int64, time.Duration, error) { + for i := 0; i < store.MaxRetry; i++ { + count, ttl, err := peekValue(rtx, key) + if err == nil { + return count, ttl, nil + } + } + return 0, 0, errors.New("retry limit exceeded") +} + +// peekValue will retrieve the counter and its expiration for given key. +func peekValue(rtx *libredis.Tx, key string) (int64, time.Duration, error) { + pipe := rtx.Pipeline() + value := pipe.Get(key) + expire := pipe.PTTL(key) + + _, err := pipe.Exec() + if err != nil && err != libredis.Nil { + return 0, 0, err + } + + count, err := value.Int64() + if err != nil && err != libredis.Nil { + return 0, 0, err + } + + ttl, err := expire.Result() + if err != nil { + return 0, 0, err + } + + return count, ttl, nil +} + +// doSetValue will execute setValue with a retry mecanism (optimistic locking) until store.MaxRetry is reached. +func (store *Store) doSetValue(rtx *libredis.Tx, key string, expiration time.Duration) (bool, error) { + for i := 0; i < store.MaxRetry; i++ { + created, err := setValue(rtx, key, expiration) + if err == nil { + return created, nil + } + } + return false, errors.New("retry limit exceeded") +} + +// setValue will try to initialize a new counter if given key doesn't exists. +func setValue(rtx *libredis.Tx, key string, expiration time.Duration) (bool, error) { + value := rtx.SetNX(key, 1, expiration) + + created, err := value.Result() + if err != nil { + return false, err + } + + return created, nil +} + +// doUpdateValue will execute setValue with a retry mecanism (optimistic locking) until store.MaxRetry is reached. +func (store *Store) doUpdateValue(rtx *libredis.Tx, key string, + expiration time.Duration) (int64, time.Duration, error) { + for i := 0; i < store.MaxRetry; i++ { + count, ttl, err := updateValue(rtx, key, expiration) + if err == nil { + return count, ttl, nil + } + + // If ttl is negative and there is an error, do not retry an update. + if ttl < 0 { + return 0, 0, err + } + } + return 0, 0, errors.New("retry limit exceeded") +} + +// updateValue will try to increment the counter identified by given key. +func updateValue(rtx *libredis.Tx, key string, expiration time.Duration) (int64, time.Duration, error) { + pipe := rtx.Pipeline() + value := pipe.Incr(key) + expire := pipe.PTTL(key) + + _, err := pipe.Exec() + if err != nil { + return 0, 0, err + } + + count, err := value.Result() + if err != nil { + return 0, 0, err + } + + ttl, err := expire.Result() + if err != nil { + return 0, 0, err + } + + // If ttl is negative, we have to define key expiration. + if ttl < 0 { + expire := rtx.Expire(key, expiration) + + ok, err := expire.Result() + if err != nil { + return count, ttl, err + } + + if !ok { + return count, ttl, errors.New("cannot configure timeout on key") + } + } + + return count, ttl, nil + +} + +// ping checks if redis is alive. +func (store *Store) ping() (bool, error) { + cmd := store.client.Ping() + + pong, err := cmd.Result() + if err != nil { + return false, errors.Wrap(err, "limiter: cannot ping redis server") + } + + return (pong == "PONG"), nil +} diff --git a/drivers/store/redis/store_test.go b/drivers/store/redis/store_test.go new file mode 100644 index 0000000..77c7c6c --- /dev/null +++ b/drivers/store/redis/store_test.go @@ -0,0 +1,62 @@ +package redis_test + +import ( + "os" + "testing" + + libredis "github.com/go-redis/redis" + "github.com/stretchr/testify/require" + + "github.com/ulule/limiter" + "github.com/ulule/limiter/drivers/store/common" + "github.com/ulule/limiter/drivers/store/redis" +) + +func TestRedisStoreSequentialAccess(t *testing.T) { + is := require.New(t) + + client, err := newRedisClient() + is.NoError(err) + is.NotNil(client) + + store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ + Prefix: "limiter:redis:sequential", + MaxRetry: 3, + }) + is.NoError(err) + is.NotNil(store) + + common.TestStoreSequentialAccess(t, store) +} + +func TestRedisStoreConcurrentAccess(t *testing.T) { + is := require.New(t) + + client, err := newRedisClient() + is.NoError(err) + is.NotNil(client) + + store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ + Prefix: "limiter:redis:concurrent", + MaxRetry: 7, + }) + is.NoError(err) + is.NotNil(store) + + common.TestStoreConcurrentAccess(t, store) +} + +func newRedisClient() (*libredis.Client, error) { + uri := "redis://localhost:6379/0" + if os.Getenv("REDIS_URI") != "" { + uri = os.Getenv("REDIS_URI") + } + + opt, err := libredis.ParseURL(uri) + if err != nil { + return nil, err + } + + client := libredis.NewClient(opt) + return client, nil +} diff --git a/examples/gin/main.go b/examples/gin/main.go new file mode 100644 index 0000000..80ef36e --- /dev/null +++ b/examples/gin/main.go @@ -0,0 +1,57 @@ +package main + +import ( + "log" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/go-redis/redis" + "github.com/ulule/limiter" + mgin "github.com/ulule/limiter/drivers/middleware/gin" + sredis "github.com/ulule/limiter/drivers/store/redis" +) + +func main() { + + // Define a limit rate to 4 requests per hour. + rate, err := limiter.NewRateFromFormatted("4-H") + if err != nil { + log.Fatal(err) + return + } + + // Create a redis client. + option, err := redis.ParseURL("redis://localhost:6379/0") + if err != nil { + log.Fatal(err) + return + } + client := redis.NewClient(option) + + // Create a store with the redis client. + store, err := sredis.NewStoreWithOptions(client, limiter.StoreOptions{ + Prefix: "limiter_gin_example", + MaxRetry: 3, + }) + if err != nil { + log.Fatal(err) + return + } + + // Create a new middleware with the limiter instance. + middleware := mgin.NewMiddleware(limiter.New(store, rate)) + + // Launch a simple server. + router := gin.Default() + router.Use(middleware) + router.GET("/", index) + log.Fatal(router.Run(":7777")) +} + +func index(c *gin.Context) { + type message struct { + Message string `json:"message"` + } + resp := message{Message: "ok"} + c.JSON(http.StatusOK, resp) +} diff --git a/examples/gjr/main.go b/examples/gjr/main.go deleted file mode 100644 index 6d5f54d..0000000 --- a/examples/gjr/main.go +++ /dev/null @@ -1,53 +0,0 @@ -package main - -import ( - "fmt" - "log" - "net/http" - - "github.com/ant0ine/go-json-rest/rest" - "github.com/garyburd/redigo/redis" - "github.com/ulule/limiter" -) - -func main() { - // 4 reqs/hour - rate, err := limiter.NewRateFromFormatted("4-H") - if err != nil { - panic(err) - } - - // Create a Redis pool. - pool := redis.NewPool(func() (redis.Conn, error) { - c, err := redis.Dial("tcp", ":6379") - if err != nil { - return nil, err - } - return c, err - }, 100) - - // Create a store with the pool. - store, err := limiter.NewRedisStoreWithOptions( - pool, - limiter.StoreOptions{Prefix: "limiter_gjr_example", MaxRetry: 3}) - - if err != nil { - panic(err) - } - - // Create API. - api := rest.NewApi() - api.Use(rest.DefaultDevStack...) - - // Add middleware with the limiter instance. - api.Use(limiter.NewGJRMiddleware(limiter.NewLimiter(store, rate))) - - // Set stupid app. - api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { - w.WriteJson(map[string]string{"message": "ok"}) - })) - - // Run server! - fmt.Println("Server is running on 7777...") - log.Fatal(http.ListenAndServe(":7777", api.MakeHandler())) -} diff --git a/examples/http/main.go b/examples/http/main.go index 6fa3497..a9e2613 100644 --- a/examples/http/main.go +++ b/examples/http/main.go @@ -5,44 +5,50 @@ import ( "log" "net/http" - "github.com/garyburd/redigo/redis" + redis "github.com/go-redis/redis" "github.com/ulule/limiter" + "github.com/ulule/limiter/drivers/middleware/stdlib" + sredis "github.com/ulule/limiter/drivers/store/redis" ) func main() { - // 4 reqs/hour + + // Define a limit rate to 4 requests per hour. rate, err := limiter.NewRateFromFormatted("4-H") if err != nil { - panic(err) + log.Fatal(err) + return } - // Create a Redis pool. - pool := redis.NewPool(func() (redis.Conn, error) { - c, err := redis.Dial("tcp", ":6379") - if err != nil { - return nil, err - } - return c, err - }, 100) - - // Create a store with the pool. - store, err := limiter.NewRedisStoreWithOptions( - pool, - limiter.StoreOptions{Prefix: "limiter_http_example", MaxRetry: 3}) + // Create a redis client. + option, err := redis.ParseURL("redis://localhost:6379/0") + if err != nil { + log.Fatal(err) + return + } + client := redis.NewClient(option) + // Create a store with the redis client. + store, err := sredis.NewStoreWithOptions(client, limiter.StoreOptions{ + Prefix: "limiter_http_example", + MaxRetry: 3, + }) if err != nil { - panic(err) + log.Fatal(err) + return } - mw := limiter.NewHTTPMiddleware(limiter.NewLimiter(store, rate)) - http.Handle("/", mw.Handler(http.HandlerFunc(index))) + // Create a new middleware with the limiter instance. + middleware := stdlib.NewMiddleware(limiter.New(store, rate)) + // Launch a simple server. + http.Handle("/", middleware.Handler(http.HandlerFunc(index))) fmt.Println("Server is running on port 7777...") log.Fatal(http.ListenAndServe(":7777", nil)) } func index(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write([]byte(`{"message": "ok"}`)) } diff --git a/gometalinter.json b/gometalinter.json new file mode 100644 index 0000000..0e96552 --- /dev/null +++ b/gometalinter.json @@ -0,0 +1,32 @@ +{ + "DisableAll": true, + "Enable": [ + "lll", + "misspell", + "gofmt", + "dupl", + "ineffassign", + "errcheck", + "gas", + "vet", + "unconvert", + "interfacer", + "deadcode", + "gocyclo", + "golint", + "goconst", + "megacheck", + "varcheck", + "structcheck" + ], + "EnableGC": true, + "Deadline": "1200s", + "Concurrency": 1, + "Vendor": true, + "VendoredLinters": true, + "Aggregate": true, + "Test": true, + "LineLength": 120, + "Cyclo": 10, + "DuplThreshold": 80 +} diff --git a/limiter.go b/limiter.go index f135c09..6e0367a 100644 --- a/limiter.go +++ b/limiter.go @@ -1,5 +1,9 @@ package limiter +import ( + "context" +) + // ----------------------------------------------------------------- // Context // ----------------------------------------------------------------- @@ -22,20 +26,20 @@ type Limiter struct { Rate Rate } -// NewLimiter returns an instance of Limiter. -func NewLimiter(store Store, rate Rate) *Limiter { +// New returns an instance of Limiter. +func New(store Store, rate Rate) *Limiter { return &Limiter{ Store: store, Rate: rate, } } -// Get returns the limit for the identifier. -func (l *Limiter) Get(key string) (Context, error) { - return l.Store.Get(key, l.Rate) +// Get returns the limit for given identifier. +func (limiter *Limiter) Get(ctx context.Context, key string) (Context, error) { + return limiter.Store.Get(ctx, key, limiter.Rate) } -// Peek returns the limit for identifier without impacting accounting -func (l *Limiter) Peek(key string) (Context, error) { - return l.Store.Peek(key, l.Rate) +// Peek returns the limit for given identifier, without modification on current values. +func (limiter *Limiter) Peek(ctx context.Context, key string) (Context, error) { + return limiter.Store.Peek(ctx, key, limiter.Rate) } diff --git a/limiter_test.go b/limiter_test.go deleted file mode 100644 index 83cfb2e..0000000 --- a/limiter_test.go +++ /dev/null @@ -1,118 +0,0 @@ -package limiter - -import ( - "math" - "math/rand" - "testing" - "time" - - "github.com/garyburd/redigo/redis" - "github.com/stretchr/testify/assert" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") - -func RandStringRunes(n int) string { - b := make([]rune, n) - for i := range b { - b[i] = letterRunes[rand.Intn(len(letterRunes))] - } - return string(b) -} - -// TestLimiterMemory tests Limiter with memory store. -func TestLimiterMemory(t *testing.T) { - rate, err := NewRateFromFormatted("3-M") - assert.Nil(t, err) - - store := NewMemoryStoreWithOptions(StoreOptions{ - Prefix: "limitertests:memory", - CleanUpInterval: 30 * time.Second, - }) - - testLimiter(t, store, rate) -} - -// TestLimiterRedis tests Limiter with Redis store. -func TestLimiterRedis(t *testing.T) { - rate, err := NewRateFromFormatted("3-M") - assert.Nil(t, err) - - randPrefix := RandStringRunes(10) - store, err := NewRedisStoreWithOptions( - newRedisPool(), - StoreOptions{Prefix: "limitertests:redis_" + randPrefix, MaxRetry: 3}) - - assert.Nil(t, err) - - testLimiter(t, store, rate) -} - -func testLimiter(t *testing.T, store Store, rate Rate) { - limiter := NewLimiter(store, rate) - - i := 1 - for i <= 5 { - if i <= 3 { - ctx, err := limiter.Peek("boo") - assert.NoError(t, err) - assert.Equal(t, int64(3-(i-1)), ctx.Remaining) - } - - ctx, err := limiter.Get("boo") - assert.NoError(t, err) - - if i <= 3 { - assert.Equal(t, int64(3), ctx.Limit) - assert.Equal(t, int64(3-i), ctx.Remaining) - assert.True(t, math.Ceil(time.Since(time.Unix(ctx.Reset, 0)).Seconds()) <= 60) - ctx, err := limiter.Peek("boo") - assert.NoError(t, err) - assert.Equal(t, int64(3-i), ctx.Remaining) - } else { - assert.Equal(t, int64(3), ctx.Limit) - assert.True(t, ctx.Remaining == 0) - assert.True(t, math.Ceil(time.Since(time.Unix(ctx.Reset, 0)).Seconds()) <= 60) - } - - i++ - } - -} - -// ----------------------------------------------------------------------------- -// Helpers -// ----------------------------------------------------------------------------- - -// newRedisPool returns -func newRedisPool() *redis.Pool { - return redis.NewPool(func() (redis.Conn, error) { - c, err := redis.Dial("tcp", ":6379") - if err != nil { - return nil, err - } - return c, err - }, 100) -} - -// newRedisLimiter returns an instance of limiter with redis backend. -func newRedisLimiter(formattedQuota string, prefix string) *Limiter { - rate, err := NewRateFromFormatted(formattedQuota) - if err != nil { - panic(err) - } - - store, err := NewRedisStoreWithOptions( - newRedisPool(), - StoreOptions{Prefix: prefix, MaxRetry: 3}) - - if err != nil { - panic(err) - } - - return NewLimiter(store, rate) -} diff --git a/middleware_gjr.go b/middleware_gjr.go deleted file mode 100644 index 400ce20..0000000 --- a/middleware_gjr.go +++ /dev/null @@ -1,45 +0,0 @@ -package limiter - -import ( - "strconv" - - "github.com/ant0ine/go-json-rest/rest" -) - -// GJRMiddleware is the go-json-rest middleware. -type GJRMiddleware struct { - Limiter *Limiter -} - -// NewGJRMiddleware returns a new instance of go-json-rest middleware. -func NewGJRMiddleware(limiter *Limiter) *GJRMiddleware { - return &GJRMiddleware{ - Limiter: limiter, - } -} - -// MiddlewareFunc is the middleware method (handler). -func (m *GJRMiddleware) MiddlewareFunc(h rest.HandlerFunc) rest.HandlerFunc { - return func(w rest.ResponseWriter, r *rest.Request) { - context, err := m.Limiter.Get(GetIPKey(r.Request)) - if err != nil { - panic(err) - } - - w.Header().Add("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10)) - w.Header().Add("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10)) - w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10)) - - // That can be useful to access rate limit context in views. - r.Env["ratelimit:limit"] = context.Limit - r.Env["ratelimit:remaining"] = context.Remaining - r.Env["ratelimit:reset"] = context.Reset - - if context.Reached { - rest.Error(w, "Limit exceeded", 429) - return - } - - h(w, r) - } -} diff --git a/middleware_gjr_test.go b/middleware_gjr_test.go deleted file mode 100644 index c06ab5a..0000000 --- a/middleware_gjr_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package limiter - -import ( - "fmt" - "math" - "runtime" - "sync" - "testing" - "time" - - "github.com/ant0ine/go-json-rest/rest" - "github.com/ant0ine/go-json-rest/rest/test" - "github.com/stretchr/testify/assert" -) - -// TestRate tests ratelimit.Rate methods. -func TestGJRMiddleware(t *testing.T) { - api := rest.NewApi() - - api.Use(NewGJRMiddleware(newRedisLimiter("10-M", "limitertests:gjr"))) - - var reset int64 - - api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { - reset = r.Env["ratelimit:reset"].(int64) - w.WriteJson(map[string]string{"message": "ok"}) - })) - - handler := api.MakeHandler() - req := test.MakeSimpleRequest("GET", "http://localhost/", nil) - req.RemoteAddr = fmt.Sprintf("178.1.2.%d:120", Random(1, 90)) - - i := 1 - for i < 20 { - recorded := test.RunRequest(t, handler, req) - assert.True(t, math.Ceil(time.Since(time.Unix(reset, 0)).Seconds()) <= 60) - if i <= 10 { - recorded.BodyIs(`{"message":"ok"}`) - recorded.HeaderIs("X-Ratelimit-Limit", "10") - recorded.HeaderIs("X-Ratelimit-Remaining", fmt.Sprintf("%d", 10-i)) - recorded.CodeIs(200) - } else { - recorded.BodyIs(`{"Error":"Limit exceeded"}`) - recorded.HeaderIs("X-Ratelimit-Limit", "10") - recorded.HeaderIs("X-Ratelimit-Remaining", "0") - recorded.CodeIs(429) - } - i++ - } -} - -// TestGJRMiddlewareWithRaceCondition test GRJ middleware under race condition. -func TestGJRMiddlewareWithRaceCondition(t *testing.T) { - runtime.GOMAXPROCS(4) - - api := rest.NewApi() - - api.Use(NewGJRMiddleware(newRedisLimiter("5-M", "limitertests:gjrrace"))) - - api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) { - w.WriteJson(map[string]string{"message": "ok"}) - })) - - handler := api.MakeHandler() - req := test.MakeSimpleRequest("GET", "http://localhost/", nil) - req.RemoteAddr = fmt.Sprintf("178.1.2.%d:180", Random(1, 90)) - - nbRequests := 100 - successCount := 0 - - var wg sync.WaitGroup - wg.Add(nbRequests) - - for i := 1; i <= nbRequests; i++ { - go func() { - recorded := test.RunRequest(t, handler, req) - if recorded.Recorder.Code == 200 { - successCount++ - } - wg.Done() - }() - } - - wg.Wait() - - assert.Equal(t, 5, successCount) -} diff --git a/middleware_http.go b/middleware_http.go deleted file mode 100644 index e327900..0000000 --- a/middleware_http.go +++ /dev/null @@ -1,38 +0,0 @@ -package limiter - -// HTTPMiddleware is the middleware for basic http.Handler. -import ( - "net/http" - "strconv" -) - -// HTTPMiddleware is the basic HTTP middleware. -type HTTPMiddleware struct { - Limiter *Limiter -} - -// NewHTTPMiddleware return a new instance of a basic HTTP middleware. -func NewHTTPMiddleware(limiter *Limiter) *HTTPMiddleware { - return &HTTPMiddleware{Limiter: limiter} -} - -// Handler the middleware handler. -func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - context, err := m.Limiter.Get(GetIPKey(r)) - if err != nil { - panic(err) - } - - w.Header().Add("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10)) - w.Header().Add("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10)) - w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10)) - - if context.Reached { - http.Error(w, "Limit exceeded", 429) - return - } - - h.ServeHTTP(w, r) - }) -} diff --git a/middleware_http_test.go b/middleware_http_test.go deleted file mode 100644 index c9cc3b9..0000000 --- a/middleware_http_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package limiter - -import ( - "fmt" - "net/http" - "net/http/httptest" - "runtime" - "sync" - "testing" - - "github.com/stretchr/testify/assert" -) - -// TestHTTPMiddleware tests the HTTP middleware. -func TestHTTPMiddleware(t *testing.T) { - req, _ := http.NewRequest("GET", "/", nil) - req.RemoteAddr = fmt.Sprintf("178.1.2.%d:100", Random(1, 90)) - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hello")) - }) - - mw := NewHTTPMiddleware(newRedisLimiter("5-M", "limitertests:http")).Handler(handler) - - i := 1 - for i <= 10 { - res := httptest.NewRecorder() - mw.ServeHTTP(res, req) - if i <= 5 { - assert.Equal(t, res.Code, 200) - } else { - assert.Equal(t, res.Code, 429) - } - i++ - } -} - -// TestHTTPMiddlewareWithRaceCondition tests the HTTP middleware under race condition. -func TestHTTPMiddlewareWithRaceCondition(t *testing.T) { - runtime.GOMAXPROCS(4) - - req, _ := http.NewRequest("GET", "/", nil) - req.RemoteAddr = fmt.Sprintf("178.1.2.%d:110", Random(1, 90)) - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hello")) - }) - - mw := NewHTTPMiddleware(newRedisLimiter("11-M", "limitertests:httprace")).Handler(handler) - - nbRequests := 100 - successCount := 0 - - var wg sync.WaitGroup - wg.Add(nbRequests) - - for i := 1; i <= nbRequests; i++ { - go func() { - res := httptest.NewRecorder() - mw.ServeHTTP(res, req) - if res.Code == 200 { - successCount++ - } - wg.Done() - }() - } - - wg.Wait() - - assert.Equal(t, 11, successCount) -} diff --git a/rate.go b/rate.go index 5b22409..f3f917c 100644 --- a/rate.go +++ b/rate.go @@ -1,10 +1,11 @@ package limiter import ( - "fmt" "strconv" "strings" "time" + + "github.com/pkg/errors" ) // Rate is the rate. @@ -20,7 +21,7 @@ func NewRateFromFormatted(formatted string) (Rate, error) { values := strings.Split(formatted, "-") if len(values) != 2 { - return rate, fmt.Errorf("Incorrect format '%s'", formatted) + return rate, errors.Errorf("incorrect format '%s'", formatted) } periods := map[string]time.Duration{ @@ -32,26 +33,21 @@ func NewRateFromFormatted(formatted string) (Rate, error) { limit, period := values[0], strings.ToUpper(values[1]) duration, ok := periods[period] - if !ok { - return rate, fmt.Errorf("Incorrect period '%s'", period) + return rate, errors.Errorf("incorrect period '%s'", period) } - var ( - p time.Duration - l int - ) - - p = 1 * duration - - l, err := strconv.Atoi(limit) + p := 1 * duration + l, err := strconv.ParseInt(limit, 10, 64) if err != nil { - return rate, fmt.Errorf("Incorrect limit '%s'", limit) + return rate, errors.Errorf("incorrect limit '%s'", limit) } - return Rate{ + rate = Rate{ Formatted: formatted, Period: p, - Limit: int64(l), - }, nil + Limit: l, + } + + return rate, nil } diff --git a/rate_test.go b/rate_test.go index 773f7b2..3078570 100644 --- a/rate_test.go +++ b/rate_test.go @@ -5,23 +5,25 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestRate tests Rate methods. func TestRate(t *testing.T) { + is := require.New(t) + expected := map[string]Rate{ - "10-S": Rate{ + "10-S": { Formatted: "10-S", Period: 1 * time.Second, Limit: int64(10), }, - "356-M": Rate{ + "356-M": { Formatted: "356-M", Period: 1 * time.Minute, Limit: int64(356), }, - "3-H": Rate{ + "3-H": { Formatted: "3-H", Period: 1 * time.Hour, Limit: int64(3), @@ -30,8 +32,8 @@ func TestRate(t *testing.T) { for k, v := range expected { r, err := NewRateFromFormatted(k) - assert.Nil(t, err) - assert.True(t, reflect.DeepEqual(v, r)) + is.NoError(err) + is.True(reflect.DeepEqual(v, r)) } wrongs := []string{ @@ -44,7 +46,7 @@ func TestRate(t *testing.T) { for _, w := range wrongs { _, err := NewRateFromFormatted(w) - assert.NotNil(t, err) + is.Error(err) } } diff --git a/scripts/lint b/scripts/lint new file mode 100755 index 0000000..c26d822 --- /dev/null +++ b/scripts/lint @@ -0,0 +1,10 @@ +#!/bin/bash + +set -eo pipefail + +for package in $(go list ./... | grep -v -E '\/(vendor|examples)\/') +do + path="${GOPATH}/src/${package}" + echo "[gometalinter] ${package}" + gometalinter --config gometalinter.json --disable=dupl "${path}" +done diff --git a/scripts/redis b/scripts/redis new file mode 100755 index 0000000..00c1f83 --- /dev/null +++ b/scripts/redis @@ -0,0 +1,65 @@ +#!/bin/bash + +set -eo pipefail + +DOCKER_REDIS_PORT=${DOCKER_REDIS_PORT:-26379} + +CONTAINER_NAME="limiter-redis" +CONTAINER_IMAGE="redis:3.2" + +do_start() { + + if [[ -n "$(docker ps -q -f name="${CONTAINER_NAME}" 2> /dev/null)" ]]; then + echo "[redis] ${CONTAINER_NAME} already started. (use --restart otherwise)" + return 0 + fi + + if [[ -n "$(docker ps -a -q -f name="${CONTAINER_NAME}" 2> /dev/null)" ]]; then + echo "[redis] erase previous configuration" + docker stop "${CONTAINER_NAME}" >/dev/null 2>&1 || true + docker rm "${CONTAINER_NAME}" >/dev/null 2>&1 || true + fi + + echo "[redis] update redis images" + docker pull ${CONTAINER_IMAGE} || true + + echo "[redis] start new ${CONTAINER_NAME} container" + docker run --name "${CONTAINER_NAME}" \ + -p ${DOCKER_REDIS_PORT}:6379 \ + -d ${CONTAINER_IMAGE} >/dev/null + +} + +do_stop() { + + echo "[redis] stop ${CONTAINER_NAME} container" + docker stop "${CONTAINER_NAME}" >/dev/null 2>&1 || true + docker rm "${CONTAINER_NAME}" >/dev/null 2>&1 || true + +} + +do_client() { + + echo "[redis] use redis-cli on ${CONTAINER_NAME}" + docker run --rm -it \ + --link "${CONTAINER_NAME}":redis \ + ${CONTAINER_IMAGE} redis-cli -h redis -p 6379 -n 1 + +} + +case "$1" in + --stop) + do_stop + ;; + --restart) + do_stop + do_start + ;; + --client) + do_client + ;; + --start | *) + do_start + ;; +esac +exit 0 diff --git a/scripts/test b/scripts/test new file mode 100755 index 0000000..8c21886 --- /dev/null +++ b/scripts/test @@ -0,0 +1,11 @@ +#!/bin/bash + +SOURCE_DIRECTORY=$(dirname "${BASH_SOURCE[0]}") +cd "${SOURCE_DIRECTORY}/.." + +if [ -z "$REDIS_DISABLE_BOOTSTRAP" ]; then + export REDIS_URI="redis://localhost:26379/1" + scripts/redis --restart +fi + +go test -race -v $(go list ./... | grep -v -E '\/(vendor|examples)\/') diff --git a/store.go b/store.go index 9ad6a74..890e84d 100644 --- a/store.go +++ b/store.go @@ -1,11 +1,16 @@ package limiter -import "time" +import ( + "context" + "time" +) // Store is the common interface for limiter stores. type Store interface { - Get(key string, rate Rate) (Context, error) - Peek(key string, rate Rate) (Context, error) + // Get returns the limit for given identifier. + Get(ctx context.Context, key string, rate Rate) (Context, error) + // Peek returns the limit for given identifier, without modification on current values. + Peek(ctx context.Context, key string, rate Rate) (Context, error) } // StoreOptions are options for store. diff --git a/store_memory.go b/store_memory.go deleted file mode 100644 index 15ea180..0000000 --- a/store_memory.go +++ /dev/null @@ -1,99 +0,0 @@ -package limiter - -import ( - "fmt" - "time" - - cache "github.com/patrickmn/go-cache" -) - -// MemoryStore is the in-memory store. -type MemoryStore struct { - Cache *cache.Cache - Prefix string -} - -// NewMemoryStore creates a new instance of memory store with defaults. -func NewMemoryStore() Store { - return NewMemoryStoreWithOptions(StoreOptions{ - Prefix: DefaultPrefix, - CleanUpInterval: DefaultCleanUpInterval, - }) -} - -// NewMemoryStoreWithOptions creates a new instance of memory store with options. -func NewMemoryStoreWithOptions(options StoreOptions) Store { - return &MemoryStore{ - Prefix: options.Prefix, - Cache: cache.New(cache.NoExpiration, options.CleanUpInterval), - } -} - -// Get implement Store.Get() method. -func (s *MemoryStore) Get(key string, rate Rate) (Context, error) { - ctx := Context{} - key = fmt.Sprintf("%s:%s", s.Prefix, key) - item, found := s.Cache.Items()[key] - ms := int64(time.Millisecond) - now := time.Now() - - if !found || item.Expired() { - s.Cache.Set(key, int64(1), rate.Period) - - return Context{ - Limit: rate.Limit, - Remaining: rate.Limit - 1, - Reset: (now.UnixNano()/ms + int64(rate.Period)/ms) / 1000, - Reached: false, - }, nil - } - - count, err := s.Cache.IncrementInt64(key, 1) - if err != nil { - return ctx, err - } - - return s.getContextFromState(now, rate, item.Expiration, count), nil -} - -// Peek implement Store.Peek() method. -func (s *MemoryStore) Peek(key string, rate Rate) (Context, error) { - ctx := Context{} - key = fmt.Sprintf("%s:%s", s.Prefix, key) - item, found := s.Cache.Items()[key] - ms := int64(time.Millisecond) - now := time.Now() - - if !found || item.Expired() { - // new or expired should show what the values "would" be but not set cache state - return Context{ - Limit: rate.Limit, - Remaining: rate.Limit, - Reset: (now.UnixNano()/ms + int64(rate.Period)/ms) / 1000, - Reached: false, - }, nil - } - - count, ok := item.Object.(int64) - if !ok { - return ctx, fmt.Errorf("key=%s count not int64", key) - } - - return s.getContextFromState(now, rate, item.Expiration, count), nil -} - -func (s *MemoryStore) getContextFromState(now time.Time, rate Rate, expiration, count int64) Context { - remaining := int64(0) - if count < rate.Limit { - remaining = rate.Limit - count - } - - expire := time.Unix(0, expiration) - - return Context{ - Limit: rate.Limit, - Remaining: remaining, - Reset: expire.Add(time.Duration(expire.Sub(now).Seconds()) * time.Second).Unix(), - Reached: count > rate.Limit, - } -} diff --git a/store_redis.go b/store_redis.go deleted file mode 100644 index 038dc37..0000000 --- a/store_redis.go +++ /dev/null @@ -1,199 +0,0 @@ -package limiter - -import ( - "fmt" - "time" - - "github.com/garyburd/redigo/redis" -) - -// RedisStoreFunc is a redis store function. -type RedisStoreFunc func(c redis.Conn, key string, rate Rate) ([]int, error) - -// RedisStore is the redis store. -type RedisStore struct { - // The prefix to use for the key. - Prefix string - - // github.com/garyburd/redigo Pool instance. - Pool *redis.Pool - - // The maximum number of retry under race conditions. - MaxRetry int -} - -// NewRedisStore returns an instance of redis store. -func NewRedisStore(pool *redis.Pool) (Store, error) { - return NewRedisStoreWithOptions(pool, StoreOptions{ - Prefix: DefaultPrefix, - MaxRetry: DefaultMaxRetry, - }) -} - -// NewRedisStoreWithOptions returns an instance of redis store with custom options. -func NewRedisStoreWithOptions(pool *redis.Pool, options StoreOptions) (Store, error) { - store := &RedisStore{ - Pool: pool, - Prefix: options.Prefix, - MaxRetry: options.MaxRetry, - } - - if _, err := store.ping(); err != nil { - return nil, err - } - - return store, nil -} - -// ping checks if redis is alive. -func (s *RedisStore) ping() (bool, error) { - conn := s.Pool.Get() - defer conn.Close() - - data, err := conn.Do("PING") - if err != nil || data == nil { - return false, err - } - - return (data == "PONG"), nil -} - -func (s RedisStore) do(f RedisStoreFunc, c redis.Conn, key string, rate Rate) ([]int, error) { - for i := 1; i <= s.MaxRetry; i++ { - values, err := f(c, key, rate) - if err == nil && len(values) != 0 { - return values, nil - } - } - return nil, fmt.Errorf("retry limit exceeded") -} - -func (s RedisStore) setRate(c redis.Conn, key string, rate Rate) ([]int, error) { - c.Send("MULTI") - c.Send("SETNX", key, 1) - return redis.Ints(c.Do("EXEC")) -} - -func (s RedisStore) updateRate(c redis.Conn, key string, rate Rate) ([]int, error) { - c.Send("MULTI") - c.Send("INCR", key) - c.Send("TTL", key) - return redis.Ints(c.Do("EXEC")) -} - -func (s RedisStore) getRate(c redis.Conn, key string, rate Rate) ([]int, error) { - c.Send("MULTI") - c.Send("GET", key) - c.Send("TTL", key) - return redis.Ints(c.Do("EXEC")) -} - -// Get returns the limit for the identifier. -func (s RedisStore) Get(key string, rate Rate) (Context, error) { - var ( - err error - values []int - ) - - ctx := Context{} - key = fmt.Sprintf("%s:%s", s.Prefix, key) - - c := s.Pool.Get() - defer c.Close() - if err := c.Err(); err != nil { - return Context{}, err - } - - c.Do("WATCH", key) - defer c.Do("UNWATCH", key) - - values, err = s.do(s.setRate, c, key, rate) - if err != nil { - return ctx, err - } - - created := (values[0] == 1) - ms := int64(time.Millisecond) - - if created { - c.Do("EXPIRE", key, rate.Period.Seconds()) - return Context{ - Limit: rate.Limit, - Remaining: rate.Limit - 1, - Reset: (time.Now().UnixNano()/ms + int64(rate.Period)/ms) / 1000, - Reached: false, - }, nil - } - - values, err = s.do(s.updateRate, c, key, rate) - if err != nil { - return ctx, err - } - - count := int64(values[0]) - ttl := int64(values[1]) - remaining := int64(0) - - if count < rate.Limit { - remaining = rate.Limit - count - } - - return Context{ - Limit: rate.Limit, - Remaining: remaining, - Reset: time.Now().Add(time.Duration(ttl) * time.Second).Unix(), - Reached: count > rate.Limit, - }, nil -} - -// Peek returns the limit for the identifier. -func (s RedisStore) Peek(key string, rate Rate) (Context, error) { - var ( - err error - values []int - ) - - ctx := Context{} - key = fmt.Sprintf("%s:%s", s.Prefix, key) - - c := s.Pool.Get() - defer c.Close() - if err := c.Err(); err != nil { - return Context{}, err - } - - c.Do("WATCH", key) - defer c.Do("UNWATCH", key) - - values, err = s.do(s.getRate, c, key, rate) - if err != nil { - return ctx, err - } - - created := (values[0] == 0) - ms := int64(time.Millisecond) - - if created { - return Context{ - Limit: rate.Limit, - Remaining: rate.Limit, - Reset: (time.Now().UnixNano()/ms + int64(rate.Period)/ms) / 1000, - Reached: false, - }, nil - } - - count := int64(values[0]) - ttl := int64(values[1]) - remaining := int64(0) - - if count < rate.Limit { - remaining = rate.Limit - count - } - - return Context{ - Limit: rate.Limit, - Remaining: remaining, - Reset: time.Now().Add(time.Duration(ttl) * time.Second).Unix(), - Reached: count > rate.Limit, - }, nil -} diff --git a/utils.go b/utils.go index e36a5dc..461563e 100644 --- a/utils.go +++ b/utils.go @@ -10,15 +10,15 @@ import ( // GetIP returns IP address from request. func GetIP(r *http.Request) net.IP { - if ip := r.Header.Get("X-Forwarded-For"); ip != "" { + ip := r.Header.Get("X-Forwarded-For") + if ip != "" { parts := strings.Split(ip, ",") - for i, part := range parts { - parts[i] = strings.TrimSpace(part) - } - return net.ParseIP(parts[0]) + part := strings.TrimSpace(parts[0]) + return net.ParseIP(part) } - if ip := r.Header.Get("X-Real-IP"); ip != "" { + ip = r.Header.Get("X-Real-IP") + if ip != "" { return net.ParseIP(ip) } diff --git a/utils_test.go b/utils_test.go index 9286f61..bfecb52 100644 --- a/utils_test.go +++ b/utils_test.go @@ -1,60 +1,127 @@ -package limiter +package limiter_test import ( + "fmt" "net" "net/http" "net/url" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ulule/limiter" ) -// TestGetIP tests GetIP() function. func TestGetIP(t *testing.T) { - // - // RemoteAddr - // - - expected := net.ParseIP("8.8.8.8") + is := require.New(t) - r := http.Request{ + request1 := &http.Request{ URL: &url.URL{Path: "/"}, Header: http.Header{}, RemoteAddr: "8.8.8.8:8888", } - ip := GetIP(&r) - assert.Equal(t, expected, ip) - - // - // X-Forwarded-For - // - - expected = net.ParseIP("9.9.9.9") - - r = http.Request{ + request2 := &http.Request{ URL: &url.URL{Path: "/foo"}, Header: http.Header{}, RemoteAddr: "8.8.8.8:8888", } + request2.Header.Add("X-Forwarded-For", "9.9.9.9, 7.7.7.7, 6.6.6.6") + + request3 := &http.Request{ + URL: &url.URL{Path: "/bar"}, + Header: http.Header{}, + RemoteAddr: "8.8.8.8:8888", + } + request3.Header.Add("X-Real-IP", "6.6.6.6") + + scenarios := []struct { + request *http.Request + expected net.IP + }{ + { + // + // Scenario #1 : RemoteAddr + // + request: request1, + expected: net.ParseIP("8.8.8.8"), + }, + { + // + // Scenario #2 : X-Forwarded-For + // + request: request2, + expected: net.ParseIP("9.9.9.9"), + }, + { + // + // Scenario #3 : X-Real-IP + // + request: request3, + expected: net.ParseIP("6.6.6.6"), + }, + } + + for i, scenario := range scenarios { + message := fmt.Sprintf("Scenario #%d", (i + 1)) + ip := limiter.GetIP(scenario.request) + is.Equal(scenario.expected, ip, message) + } +} - r.Header.Add("X-Forwarded-For", "9.9.9.9, 7.7.7.7, 6.6.6.6") - ip = GetIP(&r) - assert.Equal(t, expected, ip) +func TestGetIPKey(t *testing.T) { + is := require.New(t) - // - // X-Real-IP - // + request1 := &http.Request{ + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + RemoteAddr: "8.8.8.8:8888", + } - expected = net.ParseIP("6.6.6.6") + request2 := &http.Request{ + URL: &url.URL{Path: "/foo"}, + Header: http.Header{}, + RemoteAddr: "8.8.8.8:8888", + } + request2.Header.Add("X-Forwarded-For", "9.9.9.9, 7.7.7.7, 6.6.6.6") - r = http.Request{ + request3 := &http.Request{ URL: &url.URL{Path: "/bar"}, Header: http.Header{}, RemoteAddr: "8.8.8.8:8888", } + request3.Header.Add("X-Real-IP", "6.6.6.6") - r.Header.Add("X-Real-IP", "6.6.6.6") - ip = GetIP(&r) - assert.Equal(t, expected, ip) + scenarios := []struct { + request *http.Request + expected string + }{ + { + // + // Scenario #1 : RemoteAddr + // + request: request1, + expected: "8.8.8.8", + }, + { + // + // Scenario #2 : X-Forwarded-For + // + request: request2, + expected: "9.9.9.9", + }, + { + // + // Scenario #3 : X-Real-IP + // + request: request3, + expected: "6.6.6.6", + }, + } + + for i, scenario := range scenarios { + message := fmt.Sprintf("Scenario #%d", (i + 1)) + key := limiter.GetIPKey(scenario.request) + is.Equal(scenario.expected, key, message) + } }