diff --git a/README.md b/README.md index a1485d9..a983617 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ _Dead simple rate limit middleware for Go._ Using [Go Modules](https://github.com/golang/go/wiki/Modules) ```bash -$ go get github.com/ulule/limiter/v3@v3.5.0 +$ go get github.com/ulule/limiter/v3@v3.7.1 ``` ## Usage @@ -79,7 +79,6 @@ import "github.com/ulule/limiter/v3/drivers/store/redis" store, err := redis.NewStoreWithOptions(pool, limiter.StoreOptions{ Prefix: "your_own_prefix", - MaxRetry: 4, }) if err != nil { panic(err) diff --git a/drivers/store/redis/store.go b/drivers/store/redis/store.go index 97cd778..d3483dc 100644 --- a/drivers/store/redis/store.go +++ b/drivers/store/redis/store.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" libredis "github.com/go-redis/redis/v8" @@ -55,18 +56,19 @@ type Client interface { type Store struct { // Prefix used for the key. Prefix string - // deprecated, this option make no sense when all operations were atomic // MaxRetry is the maximum number of retry under race conditions. + // Deprecated: this option is no longer required since all operations are atomic now. MaxRetry int // client used to communicate with redis server. client Client - // luaIncrSHA is the SHA of increase and expire key script + // luaMutex is a mutex used to avoid concurrent access on luaIncrSHA and luaPeekSHA. + luaMutex sync.RWMutex + // luaLoaded is used for CAS and reduce pressure on luaMutex. + luaLoaded uint32 + // luaIncrSHA is the SHA of increase and expire key script. luaIncrSHA string - // luaPeekSHA is the SHA of peek and expire key script + // luaPeekSHA is the SHA of peek and expire key script. luaPeekSHA string - // hasLuaScriptLoaded was used to check whether the lua script was loaded or not - hasLuaScriptLoaded bool - mu sync.Mutex } // NewStore returns an instance of redis store with defaults. @@ -81,125 +83,173 @@ func NewStore(client Client) (limiter.Store, error) { // NewStoreWithOptions returns an instance of redis store with options. func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) { store := &Store{ - client: client, - Prefix: options.Prefix, - MaxRetry: options.MaxRetry, - hasLuaScriptLoaded: false, + client: client, + Prefix: options.Prefix, + MaxRetry: options.MaxRetry, } - if store.MaxRetry <= 0 { - store.MaxRetry = 1 - } - if err := store.preloadLuaScripts(context.Background()); err != nil { + err := store.preloadLuaScripts(context.Background()) + if err != nil { return nil, err } - return store, nil -} -// preloadLuaScripts would preload the incr and peek lua script -func (store *Store) preloadLuaScripts(ctx context.Context) error { - store.mu.Lock() - defer store.mu.Unlock() - if store.hasLuaScriptLoaded { - return nil - } - incrLuaSHA, err := store.client.ScriptLoad(ctx, luaIncrScript).Result() - if err != nil { - return errors.Wrap(err, "failed to load incr lua script") - } - peekLuaSHA, err := store.client.ScriptLoad(ctx, luaPeekScript).Result() - if err != nil { - return errors.Wrap(err, "failed to load peek lua script") - } - store.luaIncrSHA = incrLuaSHA - store.luaPeekSHA = peekLuaSHA - store.hasLuaScriptLoaded = true - return nil + return store, nil } // Get returns the limit for given identifier. func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { key = fmt.Sprintf("%s:%s", store.Prefix, key) - cmd := store.evalSHA(ctx, store.luaIncrSHA, []string{key}, 1, rate.Period.Milliseconds()) + cmd := store.evalSHA(ctx, store.getLuaIncrSHA, []string{key}, 1, rate.Period.Milliseconds()) count, ttl, err := parseCountAndTTL(cmd) if err != nil { return limiter.Context{}, err } + now := time.Now() expiration := now.Add(rate.Period) if ttl > 0 { expiration = now.Add(time.Duration(ttl) * time.Millisecond) } + return common.GetContextFromState(now, rate, expiration, count), nil } // Peek returns the limit for given identifier, without modification on current values. func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { key = fmt.Sprintf("%s:%s", store.Prefix, key) - cmd := store.evalSHA(ctx, store.luaPeekSHA, []string{key}) + cmd := store.evalSHA(ctx, store.getLuaPeekSHA, []string{key}) count, ttl, err := parseCountAndTTL(cmd) if err != nil { return limiter.Context{}, err } + now := time.Now() expiration := now.Add(rate.Period) if ttl > 0 { expiration = now.Add(time.Duration(ttl) * time.Millisecond) } + return common.GetContextFromState(now, rate, expiration, count), nil } // Reset returns the limit for given identifier which is set to zero. func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { key = fmt.Sprintf("%s:%s", store.Prefix, key) - if _, err := store.client.Del(ctx, key).Result(); err != nil { + _, err := store.client.Del(ctx, key).Result() + if err != nil { return limiter.Context{}, err } + count := int64(0) now := time.Now() expiration := now.Add(rate.Period) + return common.GetContextFromState(now, rate, expiration, count), nil } -// evalSHA eval the redis lua sha and load the script if missing -func (store *Store) evalSHA(ctx context.Context, sha string, keys []string, args ...interface{}) *libredis.Cmd { - cmd := store.client.EvalSha(ctx, sha, keys, args...) - if err := cmd.Err(); err != nil { - if !isLuaScriptGone(err) { - return cmd - } - store.mu.Lock() - store.hasLuaScriptLoaded = false - store.mu.Unlock() - if err := store.preloadLuaScripts(ctx); err != nil { - cmd = libredis.NewCmd(ctx) - cmd.SetErr(err) - return cmd - } - cmd = store.client.EvalSha(ctx, sha, keys) - } - return cmd +// preloadLuaScripts preloads the "incr" and "peek" lua scripts. +func (store *Store) preloadLuaScripts(ctx context.Context) error { + // Verify if we need to load lua scripts. + // Inspired by sync.Once. + if atomic.LoadUint32(&store.luaLoaded) == 0 { + return store.loadLuaScripts(ctx) + } + return nil +} + +// reloadLuaScripts forces a reload of "incr" and "peek" lua scripts. +func (store *Store) reloadLuaScripts(ctx context.Context) error { + // Reset lua scripts loaded state. + // Inspired by sync.Once. + atomic.StoreUint32(&store.luaLoaded, 0) + return store.loadLuaScripts(ctx) } -// isLuaScriptGone check whether the error was no script or no +// loadLuaScripts load "incr" and "peek" lua scripts. +// WARNING: Please use preloadLuaScripts or reloadLuaScripts, instead of this one. +func (store *Store) loadLuaScripts(ctx context.Context) error { + store.luaMutex.Lock() + defer store.luaMutex.Unlock() + + // Check if scripts are already loaded. + if atomic.LoadUint32(&store.luaLoaded) != 0 { + return nil + } + + luaIncrSHA, err := store.client.ScriptLoad(ctx, luaIncrScript).Result() + if err != nil { + return errors.Wrap(err, `failed to load "incr" lua script`) + } + + luaPeekSHA, err := store.client.ScriptLoad(ctx, luaPeekScript).Result() + if err != nil { + return errors.Wrap(err, `failed to load "peek" lua script`) + } + + store.luaIncrSHA = luaIncrSHA + store.luaPeekSHA = luaPeekSHA + + atomic.StoreUint32(&store.luaLoaded, 1) + + return nil +} + +// getLuaIncrSHA returns a "thread-safe" value for luaIncrSHA. +func (store *Store) getLuaIncrSHA() string { + store.luaMutex.RLock() + defer store.luaMutex.RUnlock() + return store.luaIncrSHA +} + +// getLuaPeekSHA returns a "thread-safe" value for luaPeekSHA. +func (store *Store) getLuaPeekSHA() string { + store.luaMutex.RLock() + defer store.luaMutex.RUnlock() + return store.luaPeekSHA +} + +// evalSHA eval the redis lua sha and load the scripts if missing. +func (store *Store) evalSHA(ctx context.Context, getSha func() string, + keys []string, args ...interface{}) *libredis.Cmd { + + cmd := store.client.EvalSha(ctx, getSha(), keys, args...) + err := cmd.Err() + if err == nil || !isLuaScriptGone(err) { + return cmd + } + + err = store.reloadLuaScripts(ctx) + if err != nil { + cmd = libredis.NewCmd(ctx) + cmd.SetErr(err) + return cmd + } + + return store.client.EvalSha(ctx, getSha(), keys, args...) +} + +// isLuaScriptGone returns if the error is a missing lua script from redis server. func isLuaScriptGone(err error) bool { return strings.HasPrefix(err.Error(), "NOSCRIPT") } -// parseCountAndTTL parse count and ttl from lua script output +// parseCountAndTTL parse count and ttl from lua script output. func parseCountAndTTL(cmd *libredis.Cmd) (int64, int64, error) { - ret, err := cmd.Result() + result, err := cmd.Result() if err != nil { - return 0, 0, err + return 0, 0, errors.Wrap(err, "an error has occurred with redis command") } - if fields, ok := ret.([]interface{}); !ok || len(fields) != 2 { - return 0, 0, errors.New("two elements in array was expected") + + fields, ok := result.([]interface{}) + if !ok || len(fields) != 2 { + return 0, 0, errors.New("two elements in result were expected") } - fields := ret.([]interface{}) + count, ok1 := fields[0].(int64) ttl, ok2 := fields[1].(int64) if !ok1 || !ok2 { - return 0, 0, errors.New("type of the count and ttl should be number") + return 0, 0, errors.New("type of the count and/or ttl should be number") } + return count, ttl, nil } diff --git a/drivers/store/redis/store_test.go b/drivers/store/redis/store_test.go index 8a615dd..5673b35 100644 --- a/drivers/store/redis/store_test.go +++ b/drivers/store/redis/store_test.go @@ -22,8 +22,7 @@ func TestRedisStoreSequentialAccess(t *testing.T) { is.NotNil(client) store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ - Prefix: "limiter:redis:sequential", - MaxRetry: 3, + Prefix: "limiter:redis:sequential-test", }) is.NoError(err) is.NotNil(store) @@ -39,8 +38,7 @@ func TestRedisStoreConcurrentAccess(t *testing.T) { is.NotNil(client) store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ - Prefix: "limiter:redis:concurrent", - MaxRetry: 7, + Prefix: "limiter:redis:concurrent-test", }) is.NoError(err) is.NotNil(store) @@ -94,40 +92,49 @@ func TestRedisClientExpiration(t *testing.T) { is.Greater(actual, expected) } -func newRedisClient() (*libredis.Client, error) { - uri := "redis://localhost:6379/0" - if os.Getenv("REDIS_URI") != "" { - uri = os.Getenv("REDIS_URI") - } +func BenchmarkRedisStoreSequentialAccess(b *testing.B) { + is := require.New(b) - opt, err := libredis.ParseURL(uri) - if err != nil { - return nil, err - } + client, err := newRedisClient() + is.NoError(err) + is.NotNil(client) - client := libredis.NewClient(opt) - return client, nil + store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ + Prefix: "limiter:redis:sequential-benchmark", + }) + is.NoError(err) + is.NotNil(store) + + tests.BenchmarkStoreSequentialAccess(b, store) } -func BenchmarkGet(b *testing.B) { +func BenchmarkRedisStoreConcurrentAccess(b *testing.B) { is := require.New(b) + client, err := newRedisClient() is.NoError(err) is.NotNil(client) + store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{ - Prefix: "limiter:redis:benchmark", - MaxRetry: 3, + Prefix: "limiter:redis:concurrent-benchmark", }) is.NoError(err) is.NotNil(store) - limiter := limiter.New(store, limiter.Rate{ - Limit: 100000, - Period: 10 * time.Second, - }) - for i := 0; i < b.N; i++ { - lctx, err := limiter.Get(context.TODO(), "foo") - is.NoError(err) - is.NotZero(lctx) + tests.BenchmarkStoreConcurrentAccess(b, store) +} + +func newRedisClient() (*libredis.Client, error) { + uri := "redis://localhost:6379/0" + if os.Getenv("REDIS_URI") != "" { + uri = os.Getenv("REDIS_URI") + } + + opt, err := libredis.ParseURL(uri) + if err != nil { + return nil, err } + + client := libredis.NewClient(opt) + return client, nil } diff --git a/drivers/store/tests/tests.go b/drivers/store/tests/tests.go index a856e51..b08c086 100644 --- a/drivers/store/tests/tests.go +++ b/drivers/store/tests/tests.go @@ -142,3 +142,41 @@ func TestStoreConcurrentAccess(t *testing.T, store limiter.Store) { } wg.Wait() } + +// BenchmarkStoreSequentialAccess executes a benchmark against a store without parallel setting. +func BenchmarkStoreSequentialAccess(b *testing.B, store limiter.Store) { + is := require.New(b) + ctx := context.Background() + + limiter := limiter.New(store, limiter.Rate{ + Limit: 100000, + Period: 10 * time.Second, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + lctx, err := limiter.Get(ctx, "foo") + is.NoError(err) + is.NotZero(lctx) + } +} + +// BenchmarkStoreConcurrentAccess executes a benchmark against a store with parallel setting. +func BenchmarkStoreConcurrentAccess(b *testing.B, store limiter.Store) { + is := require.New(b) + ctx := context.Background() + + limiter := limiter.New(store, limiter.Rate{ + Limit: 100000, + Period: 10 * time.Second, + }) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + lctx, err := limiter.Get(ctx, "foo") + is.NoError(err) + is.NotZero(lctx) + } + }) +} diff --git a/go.sum b/go.sum index 17576a4..fe21c2e 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,7 @@ github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD87 github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= -github.com/go-redis/redis/v8 v8.3.2 h1:1bJscgN2yGtKLW6MsTRosa2LHyeq94j0hnNAgRZzj/M= -github.com/go-redis/redis/v8 v8.3.2/go.mod h1:jszGxBCez8QA1HWSmQxJO9Y82kNibbUmeYhKWrBejTU= +github.com/go-redis/redis/v8 v8.3.3 h1:e0CL9fsFDK92pkIJH2XAeS/NwO2VuIOAoJvI6yktZFk= github.com/go-redis/redis/v8 v8.3.3/go.mod h1:jszGxBCez8QA1HWSmQxJO9Y82kNibbUmeYhKWrBejTU= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -87,8 +86,7 @@ github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.16.0 h1:9zAqOYLl8Tuy3E5R6ckzGDJ1g8+pw15oQp2iL9Jl6gQ= -github.com/valyala/fasthttp v1.16.0/go.mod h1:YOKImeEosDdBPnxc0gy7INqi3m1zK6A+xl6TwOBhHCA= +github.com/valyala/fasthttp v1.17.0 h1:P8/koH4aSnJ4xbd0cUUFEGQs3jQqIxoDDyRQrUiAkqg= github.com/valyala/fasthttp v1.17.0/go.mod h1:jjraHZVbKOXftJfsOYoAjaeygpj5hr8ermTRJNroD7A= github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio= go.opentelemetry.io/otel v0.13.0 h1:2isEnyzjjJZq6r2EKMsFj4TxiQiexsM04AVhwbR/oBA= @@ -99,10 +97,9 @@ golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9 golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= -golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM= -golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 h1:wBouT66WTYFXdxfVdz9sVWARVd/2vfGcmI45D2gj45M= golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 h1:5kGOVHlq0euqwzgTC9Vu15p6fV1Wi0ArVi8da2urnVg= golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -118,8 +115,6 @@ golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980 h1:OjiUf46hAmXblsZdnoSXsEUSKU8r1UEzcL5RVZ4gO9Y= -golang.org/x/sys v0.0.0-20200602225109-6fdc65e7d980/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= diff --git a/store.go b/store.go index a9799d7..c797680 100644 --- a/store.go +++ b/store.go @@ -21,6 +21,7 @@ type StoreOptions struct { Prefix string // MaxRetry is the maximum number of retry under race conditions. + // Deprecated: this option is no longer required since all operations are atomic now. MaxRetry int // CleanUpInterval is the interval for cleanup.