diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a1f59b3..b8aa040 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,6 @@ jobs: strategy: matrix: go: - - 1.12.x - 1.13.x - 1.14.x - 1.15.x diff --git a/README.md b/README.md index 83ebc98..4d214c0 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ func myHandlerFunc(w http.ResponseWriter, r *http.Request) { } func main() { - store, err := memstore.New(65536) + store, err := memstore.NewCtx(65536) if err != nil { log.Fatal(err) } @@ -74,12 +74,12 @@ func main() { MaxRate: throttled.PerMin(20), MaxBurst: 5, } - rateLimiter, err := throttled.NewGCRARateLimiter(store, quota) + rateLimiter, err := throttled.NewGCRARateLimiterCtx(store, quota) if err != nil { log.Fatal(err) } - httpRateLimiter := throttled.HTTPRateLimiter{ + httpRateLimiter := throttled.HTTPRateLimiterCtx{ RateLimiter: rateLimiter, VaryBy: &throttled.VaryBy{Path: true}, } @@ -89,6 +89,16 @@ func main() { } ``` +### Upgrading to `context.Context` aware version of `throttled` +To upgrade to the new `context.Context` aware version of `throttled`, update the package to the latest version and replace the following function with their context-aware equivalent: +- `memstore.New` => `memstore.NewCtx` +- `goredisstore.New` => `goredisstore.NewCtx` +- `redigostore.New` => `redigostore.NewCtx` +- `throttled.NewGCRARateLimiter` => `throttled.NewGCRARateLimiterCtx` +- `throttled.HTTPRateLimiter` => `throttled.HTTPRateLimiterCtx` + +Please note that not all stores make use of the passed `context.Context` yet. + ## Related Projects See [throttled/gcra][throttled-gcra] for a list of other projects related to diff --git a/deprecated.go b/deprecated.go index 8a2e61e..514115c 100644 --- a/deprecated.go +++ b/deprecated.go @@ -1,6 +1,7 @@ package throttled import ( + "context" "net/http" "time" ) @@ -65,7 +66,7 @@ func RateLimit(q Quota, vary *VaryBy, store GCRAStore) *Throttler { } rate := Rate{period: period / time.Duration(count)} - limiter, err := NewGCRARateLimiter(store, RateQuota{rate, count - 1}) + limiter, err := NewGCRARateLimiterCtx(WrapStoreWithContext(store), RateQuota{rate, count - 1}) // This panic in unavoidable because the original interface does // not support returning an error. @@ -87,3 +88,122 @@ func RateLimit(q Quota, vary *VaryBy, store GCRAStore) *Throttler { type Store interface { GCRAStore } + +// HTTPRateLimiter is an adapter for HTTPRateLimiterCtx to provide backwards +// compatibility. +// +// Deprecated: Use HTTPRateLimiterCtx instead. If the used RateLimiter does +// not implement RateLimiterCtx, wrap it with WrapRateLimiterWithContext(). +type HTTPRateLimiter struct { + // DeniedHandler is called if the request is disallowed. If it is + // nil, the DefaultDeniedHandler variable is used. + DeniedHandler http.Handler + + // Error is called if the RateLimiter returns an error. If it is + // nil, the DefaultErrorFunc is used. + Error func(w http.ResponseWriter, r *http.Request, err error) + + // Limiter is call for each request to determine whether the + // request is permitted and update internal state. It must be set. + RateLimiter RateLimiter + + // VaryBy is called for each request to generate a key for the + // limiter. If it is nil, all requests use an empty string key. + VaryBy interface { + Key(*http.Request) string + } +} + +// RateLimit provides an adapter for HTTPRateLimiterCtx.RateLimit. +// +// Deprecated: Use HTTPRateLimiterCtx instead +func (t *HTTPRateLimiter) RateLimit(h http.Handler) http.Handler { + l := HTTPRateLimiterCtx{ + DeniedHandler: t.DeniedHandler, + Error: t.Error, + RateLimiter: WrapRateLimiterWithContext(t.RateLimiter), + VaryBy: t.VaryBy, + } + return l.RateLimit(h) +} + +// GCRAStore is the version of GCRAStoreCtx that is not aware of context. +// +// Deprecated: Implement GCRAStoreCtx instead. +type GCRAStore interface { + GetWithTime(key string) (int64, time.Time, error) + SetIfNotExistsWithTTL(key string, value int64, ttl time.Duration) (bool, error) + CompareAndSwapWithTTL(key string, old, new int64, ttl time.Duration) (bool, error) +} + +// NewGCRARateLimiter is a backwards compatible adapter for NewGCRARateLimiterCtx. +// +// Deprecated: Use NewGCRARateLimiterCtx instead. If the used store does +// not implement GCRAStoreCtx, wrap it with WrapStoreWithContext(). +func NewGCRARateLimiter(st GCRAStore, quota RateQuota) (*GCRARateLimiterCtx, error) { + return NewGCRARateLimiterCtx(WrapStoreWithContext(st), quota) +} + +// A RateLimiter manages limiting the rate of actions by key. +// +// Deprecated: Use RateLimiterCtx instead. +type RateLimiter interface { + // RateLimit checks whether a particular key has exceeded a rate + // limit. It also returns a RateLimitResult to provide additional + // information about the state of the RateLimiter. + // + // If the rate limit has not been exceeded, the underlying storage + // is updated by the supplied quantity. For example, a quantity of + // 1 might be used to rate limit a single request while a greater + // quantity could rate limit based on the size of a file upload in + // megabytes. If quantity is 0, no update is performed allowing + // you to "peek" at the state of the RateLimiter for a given key. + RateLimit(key string, quantity int) (bool, RateLimitResult, error) +} + +// RateLimit is provided as a backwards compatible variant of RateLimitCtx. +// +// Deprecated: Use RateLimitCtx instead. +func (g *GCRARateLimiterCtx) RateLimit(key string, quantity int) (bool, RateLimitResult, error) { + return g.RateLimitCtx(context.Background(), key, quantity) +} + +// WrapStoreWithContext can be used to use GCRAStore in a place where a GCRAStoreCtx is required. +func WrapStoreWithContext(store GCRAStore) GCRAStoreCtx { + return gcraStoreCtxAdapter{ + gcraStore: store, + } +} + +// WrapRateLimiterWithContext can be used to use RateLimiter in a place where a RateLimiterCtx is required. +func WrapRateLimiterWithContext(rateLimier RateLimiter) RateLimiterCtx { + return rateLimiterCtxAdapter{ + rateLimiter: rateLimier, + } +} + +// gcraStoreCtxAdapter is an adapter that is used to use a GCRAStore where a GCRAStoreCtx is required. +type gcraStoreCtxAdapter struct { + gcraStore GCRAStore +} + +func (g gcraStoreCtxAdapter) GetWithTime(_ context.Context, key string) (int64, time.Time, error) { + return g.gcraStore.GetWithTime(key) +} + +func (g gcraStoreCtxAdapter) SetIfNotExistsWithTTL(_ context.Context, key string, value int64, ttl time.Duration) (bool, error) { + return g.gcraStore.SetIfNotExistsWithTTL(key, value, ttl) +} + +func (g gcraStoreCtxAdapter) CompareAndSwapWithTTL(_ context.Context, key string, old, new int64, ttl time.Duration) (bool, error) { + return g.gcraStore.CompareAndSwapWithTTL(key, old, new, ttl) +} + +// rateLimiterCtxAdapter is an adapter that is used to use a RateLimiter where a RateLimiterCtx is required. +type rateLimiterCtxAdapter struct { + rateLimiter RateLimiter +} + +func (r rateLimiterCtxAdapter) RateLimitCtx(_ context.Context, key string, quantity int) (bool, RateLimitResult, error) { + return r.rateLimiter.RateLimit(key, quantity) +} diff --git a/example_test.go b/example_test.go index f5784c9..2a17759 100644 --- a/example_test.go +++ b/example_test.go @@ -1,6 +1,7 @@ package throttled_test import ( + "context" "fmt" "log" "net/http" @@ -13,11 +14,11 @@ var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("hi there!")) }) -// ExampleHTTPRateLimiter demonstrates the usage of HTTPRateLimiter +// ExampleHTTPRateLimiter demonstrates the usage of HTTPRateLimiterCtx // for rate-limiting access to an http.Handler to 20 requests per path // per minute with a maximum burst of 5 requests. -func ExampleHTTPRateLimiter() { - store, err := memstore.New(65536) +func ExampleHTTPRateLimiterCtx() { + store, err := memstore.NewCtx(65536) if err != nil { log.Fatal(err) } @@ -25,12 +26,12 @@ func ExampleHTTPRateLimiter() { // Maximum burst of 5 which refills at 20 tokens per minute. quota := throttled.RateQuota{MaxRate: throttled.PerMin(20), MaxBurst: 5} - rateLimiter, err := throttled.NewGCRARateLimiter(store, quota) + rateLimiter, err := throttled.NewGCRARateLimiterCtx(store, quota) if err != nil { log.Fatal(err) } - httpRateLimiter := throttled.HTTPRateLimiter{ + httpRateLimiter := throttled.HTTPRateLimiterCtx{ RateLimiter: rateLimiter, VaryBy: &throttled.VaryBy{Path: true}, } @@ -38,11 +39,11 @@ func ExampleHTTPRateLimiter() { http.ListenAndServe(":8080", httpRateLimiter.RateLimit(myHandler)) } -// Demonstrates direct use of GCRARateLimiter's RateLimit function (and the +// Demonstrates direct use of GCRARateLimiterCtx's RateLimit function (and the // more general RateLimiter interface). This should be used anywhere where // granular control over rate limiting is required. -func ExampleGCRARateLimiter() { - store, err := memstore.New(65536) +func ExampleGCRARateLimiterCtx() { + store, err := memstore.NewCtx(65536) if err != nil { log.Fatal(err) } @@ -50,7 +51,7 @@ func ExampleGCRARateLimiter() { // Maximum burst of 5 which refills at 1 token per hour. quota := throttled.RateQuota{MaxRate: throttled.PerHour(1), MaxBurst: 5} - rateLimiter, err := throttled.NewGCRARateLimiter(store, quota) + rateLimiter, err := throttled.NewGCRARateLimiterCtx(store, quota) if err != nil { log.Fatal(err) } @@ -65,7 +66,7 @@ func ExampleGCRARateLimiter() { for i := 0; i < 20; i++ { bucket := fmt.Sprintf("by-order:%v", i/10) - limited, result, err := rateLimiter.RateLimit(bucket, 1) + limited, result, err := rateLimiter.RateLimitCtx(context.Background(), bucket, 1) if err != nil { log.Fatal(err) } diff --git a/go.mod b/go.mod index c9a669a..24be143 100644 --- a/go.mod +++ b/go.mod @@ -4,17 +4,11 @@ go 1.13 require ( github.com/go-redis/redis v6.15.8+incompatible - github.com/golang/protobuf v1.4.2 // indirect - github.com/gomodule/redigo v1.8.4 - github.com/google/go-cmp v0.5.0 // indirect + github.com/go-redis/redis/v8 v8.4.2 + github.com/gomodule/redigo v2.0.0+incompatible github.com/hashicorp/golang-lru v0.5.4 github.com/kr/pretty v0.1.0 // indirect - github.com/onsi/ginkgo v1.10.1 // indirect - github.com/onsi/gomega v1.7.0 // indirect - github.com/stretchr/testify v1.5.1 - golang.org/x/net v0.0.0-20190923162816-aa69164e4478 // indirect - golang.org/x/sys v0.0.0-20191010194322-b09406accb47 // indirect + github.com/stretchr/testify v1.6.1 golang.org/x/text v0.3.7 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect - gopkg.in/yaml.v2 v2.2.7 // indirect ) diff --git a/go.sum b/go.sum index 55a0ccb..e2619d8 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,18 @@ +github.com/cespare/xxhash/v2 v2.1.1 h1:6MnRN8NT7+YBpUIWxHtefFZOKTAPgGjpQSxqLNn0+qY= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-redis/redis v6.15.8+incompatible h1:BKZuG6mCnRj5AOaWJXoCgf6rqTYnYJLe4en2hxT7r9o= github.com/go-redis/redis v6.15.8+incompatible/go.mod h1:NAIEuMOZ/fxfXJIrKDQDz8wamY7mA7PouImQ2Jvg6kA= +github.com/go-redis/redis/v8 v8.4.2 h1:gKRo1KZ+O3kXRfxeRblV5Tr470d2YJZJVIAv2/S8960= +github.com/go-redis/redis/v8 v8.4.2/go.mod h1:A1tbYoHSa1fXwN+//ljcCYYJeLmVrwL9hbQN45Jdy0M= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -12,44 +21,62 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/gomodule/redigo v1.8.4 h1:Z5JUg94HMTR1XpwBaSH4vq3+PNSIykBLxMdglbw10gg= -github.com/gomodule/redigo v1.8.4/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0 h1:/QaMHBdZ26BB3SSst0Iwl10Epc+xhTquomWX0oZEB6w= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.3 h1:x95R7cp+rSeeqAMI2knLtQ0DKlaBhv2NrtrOvafPHRo= +github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= -github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/nxadm/tail v1.4.4 h1:DQuhQpB1tVlglWS2hLQ5OV6B5r8aGxSrPc5Qo6uTN78= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.10.1 h1:q/mM8GF/n0shIN8SaAZ0V+jnLPzen6WIVZdiwrRlMlo= -github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/gomega v1.7.0 h1:XPnZz8VVBHjVsy1vzJmRwIcSwiUO+JFfrv/xGiigmME= -github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.2 h1:8mVmC9kjFFmA8H4pKMUhcblgifdkOIXPvbhN1T36q1M= +github.com/onsi/ginkgo v1.14.2/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.10.3 h1:gph6h/qe9GSUw1NhH1gp+qb+h8rXD8Cy60Z32Qw3ELA= +github.com/onsi/gomega v1.10.3/go.mod h1:V9xEwhxec5O8UDM77eCW8vLymOMltsqPVYWrpDsH8xc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/otel v0.14.0 h1:YFBEfjCk9MTjaytCNSUkp9Q8lF7QJezA06T71FbQxLQ= +go.opentelemetry.io/otel v0.14.0/go.mod h1:vH5xEuwy7Rts0GNtsCW3HYQoZDY+OmBJ6t1bFGGlxgw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190923162816-aa69164e4478 h1:l5EDrHhldLYb3ZRHDUhXF7Om7MvYXnkV9/iQNo1lX6g= -golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 h1:wBouT66WTYFXdxfVdz9sVWARVd/2vfGcmI45D2gj45M= +golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY= -golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200519105757-fe76b779f299/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -65,11 +92,11 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= -gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/http.go b/http.go index 4c513a8..fc6409d 100644 --- a/http.go +++ b/http.go @@ -22,8 +22,8 @@ var ( } ) -// HTTPRateLimiter faciliates using a Limiter to limit HTTP requests. -type HTTPRateLimiter struct { +// HTTPRateLimiterCtx faciliates using a Limiter to limit HTTP requests. +type HTTPRateLimiterCtx struct { // DeniedHandler is called if the request is disallowed. If it is // nil, the DefaultDeniedHandler variable is used. DeniedHandler http.Handler @@ -34,7 +34,7 @@ type HTTPRateLimiter struct { // Limiter is call for each request to determine whether the // request is permitted and update internal state. It must be set. - RateLimiter RateLimiter + RateLimiter RateLimiterCtx // VaryBy is called for each request to generate a key for the // limiter. If it is nil, all requests use an empty string key. @@ -49,7 +49,7 @@ type HTTPRateLimiter struct { // X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset and // Retry-After headers will be written to the response based on the // values in the RateLimitResult. -func (t *HTTPRateLimiter) RateLimit(h http.Handler) http.Handler { +func (t *HTTPRateLimiterCtx) RateLimit(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if t.RateLimiter == nil { t.error(w, r, errors.New("You must set a RateLimiter on HTTPRateLimiter")) @@ -60,7 +60,7 @@ func (t *HTTPRateLimiter) RateLimit(h http.Handler) http.Handler { k = t.VaryBy.Key(r) } - limited, context, err := t.RateLimiter.RateLimit(k, 1) + limited, context, err := t.RateLimiter.RateLimitCtx(r.Context(), k, 1) if err != nil { t.error(w, r, err) @@ -81,7 +81,7 @@ func (t *HTTPRateLimiter) RateLimit(h http.Handler) http.Handler { }) } -func (t *HTTPRateLimiter) error(w http.ResponseWriter, r *http.Request, err error) { +func (t *HTTPRateLimiterCtx) error(w http.ResponseWriter, r *http.Request, err error) { e := t.Error if e == nil { e = DefaultError diff --git a/http_test.go b/http_test.go index 2f11374..f8a3752 100644 --- a/http_test.go +++ b/http_test.go @@ -1,6 +1,7 @@ package throttled_test import ( + "context" "errors" "net/http" "net/http/httptest" @@ -13,7 +14,7 @@ import ( type stubLimiter struct { } -func (sl *stubLimiter) RateLimit(key string, quantity int) (bool, throttled.RateLimitResult, error) { +func (sl *stubLimiter) RateLimitCtx(_ context.Context, key string, quantity int) (bool, throttled.RateLimitResult, error) { switch key { case "limit": result := throttled.RateLimitResult{ @@ -50,7 +51,7 @@ type httpTestCase struct { } func TestHTTPRateLimiter(t *testing.T) { - limiter := throttled.HTTPRateLimiter{ + limiter := throttled.HTTPRateLimiterCtx{ RateLimiter: &stubLimiter{}, VaryBy: &pathGetter{}, } @@ -67,7 +68,7 @@ func TestHTTPRateLimiter(t *testing.T) { } func TestCustomHTTPRateLimiterHandlers(t *testing.T) { - limiter := throttled.HTTPRateLimiter{ + limiter := throttled.HTTPRateLimiterCtx{ RateLimiter: &stubLimiter{}, VaryBy: &pathGetter{}, DeniedHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -112,7 +113,7 @@ func runHTTPTestCases(t *testing.T, h http.Handler, cs []httpTestCase) { } func BenchmarkHTTPRateLimiter(b *testing.B) { - limiter := throttled.HTTPRateLimiter{ + limiter := throttled.HTTPRateLimiterCtx{ RateLimiter: &stubLimiter{}, VaryBy: &pathGetter{}, } diff --git a/rate.go b/rate.go index 12135eb..b54547e 100644 --- a/rate.go +++ b/rate.go @@ -1,6 +1,7 @@ package throttled import ( + "context" "fmt" "time" ) @@ -11,9 +12,9 @@ const ( maxCASAttempts = 10 ) -// A RateLimiter manages limiting the rate of actions by key. -type RateLimiter interface { - // RateLimit checks whether a particular key has exceeded a rate +// A RateLimiterCtx manages limiting the rate of actions by key. +type RateLimiterCtx interface { + // RateLimitCtx checks whether a particular key has exceeded a rate // limit. It also returns a RateLimitResult to provide additional // information about the state of the RateLimiter. // @@ -23,7 +24,7 @@ type RateLimiter interface { // quantity could rate limit based on the size of a file upload in // megabytes. If quantity is 0, no update is performed allowing // you to "peek" at the state of the RateLimiter for a given key. - RateLimit(key string, quantity int) (bool, RateLimitResult, error) + RateLimitCtx(ctx context.Context, key string, quantity int) (bool, RateLimitResult, error) } // RateLimitResult represents the state of the RateLimiter for a @@ -116,11 +117,11 @@ func PerDay(n int) Rate { return Rate{24 * time.Hour / time.Duration(n), n} } // PerDuration represents a number of requests per provided duration. func PerDuration(n int, d time.Duration) Rate { return Rate{d / time.Duration(n), n} } -// GCRARateLimiter is a RateLimiter that users the generic cell-rate +// GCRARateLimiterCtx is a RateLimiter that users the generic cell-rate // algorithm. The algorithm has been slightly modified from its usual // form to support limiting with an additional quantity parameter, such // as for limiting the number of bytes uploaded. -type GCRARateLimiter struct { +type GCRARateLimiterCtx struct { limit int // Think of the DVT as our flexibility: @@ -133,20 +134,20 @@ type GCRARateLimiter struct { // think of it as how frequently the bucket leaks one unit. emissionInterval time.Duration - store GCRAStore + store GCRAStoreCtx // Maximum number of times to retry SetIfNotExists/CompareAndSwap operations // before returning an error. maxCASAttemptsLimit int } -// NewGCRARateLimiter creates a GCRARateLimiter. quota.Count defines +// NewGCRARateLimiterCtx creates a GCRARateLimiterCtx. quota.Count defines // the maximum number of requests permitted in an instantaneous burst // and quota.Count / quota.Period defines the maximum sustained // rate. For example, PerMin(60) permits 60 requests instantly per key // followed by one request per second indefinitely whereas PerSec(1) // only permits one request per second with no tolerance for bursts. -func NewGCRARateLimiter(st GCRAStore, quota RateQuota) (*GCRARateLimiter, error) { +func NewGCRARateLimiterCtx(st GCRAStoreCtx, quota RateQuota) (*GCRARateLimiterCtx, error) { if quota.MaxBurst < 0 { return nil, fmt.Errorf("invalid RateQuota %#v; MaxBurst must be greater than zero", quota) } @@ -154,7 +155,7 @@ func NewGCRARateLimiter(st GCRAStore, quota RateQuota) (*GCRARateLimiter, error) return nil, fmt.Errorf("invalid RateQuota %#v; MaxRate must be greater than zero", quota) } - return &GCRARateLimiter{ + return &GCRARateLimiterCtx{ delayVariationTolerance: quota.MaxRate.period * (time.Duration(quota.MaxBurst) + 1), emissionInterval: quota.MaxRate.period, limit: quota.MaxBurst + 1, @@ -165,11 +166,11 @@ func NewGCRARateLimiter(st GCRAStore, quota RateQuota) (*GCRARateLimiter, error) // SetMaxCASAttemptsLimit allows you to set the maxCASAttempts limit. This is set to 10 // be default. -func (g *GCRARateLimiter) SetMaxCASAttemptsLimit(limit int) { +func (g *GCRARateLimiterCtx) SetMaxCASAttemptsLimit(limit int) { g.maxCASAttemptsLimit = limit } -// RateLimit checks whether a particular key has exceeded a rate +// RateLimitCtx checks whether a particular key has exceeded a rate // limit. It also returns a RateLimitResult to provide additional // information about the state of the RateLimiter. // @@ -179,7 +180,7 @@ func (g *GCRARateLimiter) SetMaxCASAttemptsLimit(limit int) { // quantity could rate limit based on the size of a file upload in // megabytes. If quantity is 0, no update is performed allowing you // to "peek" at the state of the RateLimiter for a given key. -func (g *GCRARateLimiter) RateLimit(key string, quantity int) (bool, RateLimitResult, error) { +func (g *GCRARateLimiterCtx) RateLimitCtx(ctx context.Context, key string, quantity int) (bool, RateLimitResult, error) { var tat, newTat, now time.Time var ttl time.Duration rlc := RateLimitResult{Limit: g.limit, RetryAfter: -1} @@ -193,7 +194,7 @@ func (g *GCRARateLimiter) RateLimit(key string, quantity int) (bool, RateLimitRe // tat refers to the theoretical arrival time that would be expected // from equally spaced requests at exactly the rate limit. - tatVal, now, err = g.store.GetWithTime(key) + tatVal, now, err = g.store.GetWithTime(ctx, key) if err != nil { return false, rlc, err } @@ -225,9 +226,9 @@ func (g *GCRARateLimiter) RateLimit(key string, quantity int) (bool, RateLimitRe ttl = newTat.Sub(now) if tatVal == -1 { - updated, err = g.store.SetIfNotExistsWithTTL(key, newTat.UnixNano(), ttl) + updated, err = g.store.SetIfNotExistsWithTTL(ctx, key, newTat.UnixNano(), ttl) } else { - updated, err = g.store.CompareAndSwapWithTTL(key, tatVal, newTat.UnixNano(), ttl) + updated, err = g.store.CompareAndSwapWithTTL(ctx, key, tatVal, newTat.UnixNano(), ttl) } if err != nil { diff --git a/rate_test.go b/rate_test.go index 84b37cd..7b61cd1 100644 --- a/rate_test.go +++ b/rate_test.go @@ -1,6 +1,7 @@ package throttled_test import ( + "context" "testing" "time" @@ -12,29 +13,29 @@ import ( const deniedStatus = 429 type testStore struct { - store throttled.GCRAStore + store throttled.GCRAStoreCtx clock time.Time failUpdates bool } -func (ts *testStore) GetWithTime(key string) (int64, time.Time, error) { - v, _, e := ts.store.GetWithTime(key) +func (ts *testStore) GetWithTime(ctx context.Context, key string) (int64, time.Time, error) { + v, _, e := ts.store.GetWithTime(ctx, key) return v, ts.clock, e } -func (ts *testStore) SetIfNotExistsWithTTL(key string, value int64, ttl time.Duration) (bool, error) { +func (ts *testStore) SetIfNotExistsWithTTL(ctx context.Context, key string, value int64, ttl time.Duration) (bool, error) { if ts.failUpdates { return false, nil } - return ts.store.SetIfNotExistsWithTTL(key, value, ttl) + return ts.store.SetIfNotExistsWithTTL(ctx, key, value, ttl) } -func (ts *testStore) CompareAndSwapWithTTL(key string, old, new int64, ttl time.Duration) (bool, error) { +func (ts *testStore) CompareAndSwapWithTTL(ctx context.Context, key string, old, new int64, ttl time.Duration) (bool, error) { if ts.failUpdates { return false, nil } - return ts.store.CompareAndSwapWithTTL(key, old, new, ttl) + return ts.store.CompareAndSwapWithTTL(ctx, key, old, new, ttl) } func TestRateLimit(t *testing.T) { @@ -71,13 +72,13 @@ func TestRateLimit(t *testing.T) { 15: {start.Add(15000 * time.Millisecond), 6, 5, 0, -1, true}, } - mst, err := memstore.New(0) + mst, err := memstore.NewCtx(0) if err != nil { t.Fatal(err) } st := testStore{store: mst} - rl, err := throttled.NewGCRARateLimiter(&st, rq) + rl, err := throttled.NewGCRARateLimiterCtx(&st, rq) if err != nil { t.Fatal(err) } @@ -86,7 +87,7 @@ func TestRateLimit(t *testing.T) { for i, c := range cases { st.clock = c.now - limited, context, err := rl.RateLimit("foo", c.volume) + limited, context, err := rl.RateLimitCtx(context.Background(), "foo", c.volume) if err != nil { t.Fatalf("%d: %#v", i, err) } @@ -116,19 +117,19 @@ func TestRateLimit(t *testing.T) { func TestRateLimitCustomPeriod(t *testing.T) { period := 10 * time.Millisecond rq := throttled.RateQuota{throttled.PerDuration(3, period), 0} - mst, err := memstore.New(27) + mst, err := memstore.NewCtx(27) if err != nil { t.Fatal(err) } st := testStore{store: mst} - rl, err := throttled.NewGCRARateLimiter(&st, rq) + rl, err := throttled.NewGCRARateLimiterCtx(&st, rq) if err != nil { t.Fatal(err) } for i := 0; i < 27; i++ { - limited, _, err := rl.RateLimit("bar", 1) + limited, _, err := rl.RateLimitCtx(context.Background(), "bar", 1) if err != nil { t.Fatal(err) } @@ -145,29 +146,29 @@ func TestRateLimitCustomPeriod(t *testing.T) { func TestRateLimitUpdateFailures(t *testing.T) { rq := throttled.RateQuota{MaxRate: throttled.PerSec(1), MaxBurst: 1} - mst, err := memstore.New(0) + mst, err := memstore.NewCtx(0) if err != nil { t.Fatal(err) } st := testStore{store: mst, failUpdates: true} - rl, err := throttled.NewGCRARateLimiter(&st, rq) + rl, err := throttled.NewGCRARateLimiterCtx(&st, rq) if err != nil { t.Fatal(err) } - if _, _, err := rl.RateLimit("foo", 1); err == nil { + if _, _, err := rl.RateLimitCtx(context.Background(), "foo", 1); err == nil { t.Error("Expected limiting to fail when store updates fail") } } func TestRateLimitUpdateFailuresWithRetryLimitSetToTwo(t *testing.T) { rq := throttled.RateQuota{MaxRate: throttled.PerSec(1), MaxBurst: 1} - mst, err := memstore.New(0) + mst, err := memstore.NewCtx(0) if err != nil { t.Fatal(err) } st := testStore{store: mst, failUpdates: true} - rl, err := throttled.NewGCRARateLimiter(&st, rq) + rl, err := throttled.NewGCRARateLimiterCtx(&st, rq) if err != nil { t.Fatal(err) } @@ -184,19 +185,19 @@ func TestRateLimitUpdateFailuresWithRetryLimitSetToTwo(t *testing.T) { func BenchmarkRateLimit(b *testing.B) { limit := 5 rq := throttled.RateQuota{MaxRate: throttled.PerSec(1000), MaxBurst: limit - 1} - mst, err := memstore.New(0) + mst, err := memstore.NewCtx(0) if err != nil { b.Fatal(err) } st := testStore{store: mst} - rl, err := throttled.NewGCRARateLimiter(&st, rq) + rl, err := throttled.NewGCRARateLimiterCtx(&st, rq) if err != nil { b.Fatal(err) } b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, err = rl.RateLimit("foo", 1) + _, _, err = rl.RateLimitCtx(context.Background(), "foo", 1) } _ = err } diff --git a/store.go b/store.go index a26bbc2..77407e3 100644 --- a/store.go +++ b/store.go @@ -1,12 +1,13 @@ package throttled import ( + "context" "time" ) -// GCRAStore is the interface to implement to store state for a GCRA -// rate limiter -type GCRAStore interface { +// GCRAStoreCtx is the interface to implement to store state for a GCRA +// rate limiter that uses a context.Context +type GCRAStoreCtx interface { // GetWithTime returns the value of the key if it is in the store // or -1 if it does not exist. It also returns the current time at // the Store. The time must be representable as a positive int64 @@ -16,13 +17,13 @@ type GCRAStore interface { // share the same clock. Using separate clocks will work if the // skew is small but not recommended in practice unless you're // lucky enough to be hooked up to GPS or atomic clocks. - GetWithTime(key string) (int64, time.Time, error) + GetWithTime(ctx context.Context, key string) (int64, time.Time, error) // SetIfNotExistsWithTTL sets the value of key only if it is not // already set in the store it returns whether a new value was // set. If the store supports expiring keys and a new value was // set, the key will expire after the provided ttl. - SetIfNotExistsWithTTL(key string, value int64, ttl time.Duration) (bool, error) + SetIfNotExistsWithTTL(ctx context.Context, key string, value int64, ttl time.Duration) (bool, error) // CompareAndSwapWithTTL atomically compares the value at key to // the old value. If it matches, it sets it to the new value and @@ -30,5 +31,5 @@ type GCRAStore interface { // exist in the store, it returns false with no error. If the // store supports expiring keys and the swap succeeded, the key // will expire after the provided ttl. - CompareAndSwapWithTTL(key string, old, new int64, ttl time.Duration) (bool, error) + CompareAndSwapWithTTL(ctx context.Context, key string, old, new int64, ttl time.Duration) (bool, error) } diff --git a/store/deprecated.go b/store/deprecated.go index 2351d37..f3aad1b 100644 --- a/store/deprecated.go +++ b/store/deprecated.go @@ -3,7 +3,7 @@ package store // import "github.com/throttled/throttled/v2/store" import ( "github.com/gomodule/redigo/redis" - + "github.com/throttled/throttled/v2" "github.com/throttled/throttled/v2/store/memstore" "github.com/throttled/throttled/v2/store/redigostore" ) @@ -11,7 +11,7 @@ import ( // NewMemStore initializes a new memory-based store. // // Deprecated: Use github.com/throttled/throttled/v2/store/memstore instead. -func NewMemStore(maxKeys int) *memstore.MemStore { +func NewMemStore(maxKeys int) throttled.Store { st, err := memstore.New(maxKeys) if err != nil { // As of this writing, `lru.New` can only return an error if you pass @@ -24,7 +24,7 @@ func NewMemStore(maxKeys int) *memstore.MemStore { // NewRedisStore initializes a new Redigo-based store. // // Deprecated: Use github.com/throttled/throttled/v2/store/redigostore instead. -func NewRedisStore(pool *redis.Pool, keyPrefix string, db int) *redigostore.RedigoStore { +func NewRedisStore(pool *redis.Pool, keyPrefix string, db int) throttled.Store { st, err := redigostore.New(pool, keyPrefix, db) if err != nil { // As of this writing, creating a Redis store never returns an error diff --git a/store/goredisstore.v8/goredisstore.go b/store/goredisstore.v8/goredisstore.go new file mode 100644 index 0000000..4a773a6 --- /dev/null +++ b/store/goredisstore.v8/goredisstore.go @@ -0,0 +1,128 @@ +// Package goredisstore offers Redis-based store implementation for throttled using v8 of go-redis. +package goredisstore // import "github.com/throttled/throttled/v2/store/goredisstore" + +import ( + "context" + "strings" + "time" + + "github.com/go-redis/redis/v8" +) + +const ( + redisCASMissingKey = "key does not exist" + redisCASScript = ` +local v = redis.call('get', KEYS[1]) +if v == false then + return redis.error_reply("key does not exist") +end +if v ~= ARGV[1] then + return 0 +end +redis.call('setex', KEYS[1], ARGV[3], ARGV[2]) +return 1 +` +) + +// GoRedisStore implements a Redis-based store using go-redis v8. +type GoRedisStore struct { + client redis.UniversalClient + prefix string +} + +// NewCtx creates a new Redis-based store, using the provided pool to get +// its connections. The keys will have the specified keyPrefix, which +// may be an empty string, and the database index specified by db will +// be selected to store the keys. Any updating operations will reset +// the key TTL to the provided value rounded down to the nearest +// second. Depends on Redis 2.6+ for EVAL support. +func NewCtx(client redis.UniversalClient, keyPrefix string) (*GoRedisStore, error) { + return &GoRedisStore{ + client: client, + prefix: keyPrefix, + }, nil +} + +// GetWithTime returns the value of the key if it is in the store +// or -1 if it does not exist. It also returns the current time at +// the redis server to microsecond precision. +func (r *GoRedisStore) GetWithTime(ctx context.Context, key string) (int64, time.Time, error) { + key = r.prefix + key + + pipe := r.client.Pipeline() + timeCmd := pipe.Time(ctx) + getKeyCmd := pipe.Get(ctx, key) + _, err := pipe.Exec(ctx) + + now, err := timeCmd.Result() + if err != nil { + return 0, now, err + } + + v, err := getKeyCmd.Int64() + if err == redis.Nil { + return -1, now, nil + } else if err != nil { + return 0, now, err + } + + return v, now, nil +} + +// SetIfNotExistsWithTTL sets the value of key only if it is not +// already set in the store it returns whether a new value was set. +// If a new value was set, the ttl in the key is also set, though this +// operation is not performed atomically. +func (r *GoRedisStore) SetIfNotExistsWithTTL(ctx context.Context, key string, value int64, ttl time.Duration) (bool, error) { + key = r.prefix + key + + updated, err := r.client.SetNX(ctx, key, value, 0).Result() + if err != nil { + return false, err + } + + // An `EXPIRE 0` will delete the key immediately, so make sure that we set + // expiry for a minimum of one second out so that our results stay in the + // store. + if ttl < 1*time.Second { + ttl = 1 * time.Second + } + + err = r.client.Expire(ctx, key, ttl).Err() + return updated, err +} + +// CompareAndSwapWithTTL atomically compares the value at key to the +// old value. If it matches, it sets it to the new value and returns +// true. Otherwise, it returns false. If the key does not exist in the +// store, it returns false with no error. If the swap succeeds, the +// ttl for the key is updated atomically. +func (r *GoRedisStore) CompareAndSwapWithTTL(ctx context.Context, key string, old, new int64, ttl time.Duration) (bool, error) { + key = r.prefix + key + + ttlSeconds := int(ttl.Seconds()) + + // An `EXPIRE 0` will delete the key immediately, so make sure that we set + // expiry for a minimum of one second out so that our results stay in the + // store. + if ttlSeconds < 1 { + ttlSeconds = 1 + } + + // result will be 0 or 1 + result, err := r.client.Eval(ctx, redisCASScript, []string{key}, old, new, ttlSeconds).Result() + + var swapped bool + if s, ok := result.(int64); ok { + swapped = s == 1 + } // if not ok, zero value of swapped is false. + + if err != nil { + if strings.Contains(err.Error(), redisCASMissingKey) { + return false, nil + } + return false, err + } + + return swapped, nil +} diff --git a/store/goredisstore.v8/goredisstore_test.go b/store/goredisstore.v8/goredisstore_test.go new file mode 100644 index 0000000..8bd6376 --- /dev/null +++ b/store/goredisstore.v8/goredisstore_test.go @@ -0,0 +1,95 @@ +package goredisstore_test + +import ( + "context" + "log" + "testing" + "time" + + "github.com/go-redis/redis/v8" + "github.com/throttled/throttled/v2" + "github.com/throttled/throttled/v2/store/goredisstore.v8" + "github.com/throttled/throttled/v2/store/storetest" +) + +const ( + redisTestDB = 1 + redisTestPrefix = "throttled-go-redis:" +) + +// Demonstrates that how to initialize a RateLimiter with redis +// using go-redis library. +func ExampleNew() { + // import "github.com/go-redis/redis" + + // Initialize a redis client using go-redis + client := redis.NewClient(&redis.Options{ + PoolSize: 10, // default + IdleTimeout: 30 * time.Second, + Addr: "localhost:6379", + Password: "", // no password set + DB: 0, // use default DB + }) + + // Setup store + store, err := goredisstore.NewCtx(client, "throttled:") + if err != nil { + log.Fatal(err) + } + + // Setup quota + quota := throttled.RateQuota{MaxRate: throttled.PerMin(20), MaxBurst: 5} + + // Then, use store and quota as arguments for NewGCRARateLimiter() + throttled.NewGCRARateLimiterCtx(store, quota) +} + +func TestRedisStore(t *testing.T) { + c, st := setupRedis(t, 0) + defer c.Close() + defer clearRedis(c) + + clearRedis(c) + storetest.TestGCRAStoreCtx(t, st) + storetest.TestGCRAStoreTTLCtx(t, st) +} + +func BenchmarkRedisStore(b *testing.B) { + c, st := setupRedis(b, 0) + defer c.Close() + defer clearRedis(c) + + storetest.BenchmarkGCRAStoreCtx(b, st) +} + +func clearRedis(c *redis.Client) error { + keys, err := c.Keys(context.Background(), redisTestPrefix+"*").Result() + if err != nil { + return err + } + + return c.Del(context.Background(), keys...).Err() +} + +func setupRedis(tb testing.TB, ttl time.Duration) (*redis.Client, *goredisstore.GoRedisStore) { + client := redis.NewClient(&redis.Options{ + PoolSize: 10, // default + IdleTimeout: 30 * time.Second, + Addr: "localhost:6379", + Password: "", // no password set + DB: redisTestDB, // use default DB + }) + + if err := client.Ping(context.Background()).Err(); err != nil { + client.Close() + tb.Skip("redis server not available on localhost port 6379") + } + + st, err := goredisstore.NewCtx(client, redisTestPrefix) + if err != nil { + client.Close() + tb.Fatal(err) + } + + return client, st +} diff --git a/store/goredisstore/goredisstore.go b/store/goredisstore/goredisstore.go index 6f9f5dc..8a9fa32 100644 --- a/store/goredisstore/goredisstore.go +++ b/store/goredisstore/goredisstore.go @@ -2,6 +2,7 @@ package goredisstore // import "github.com/throttled/throttled/v2/store/goredisstore" import ( + "github.com/throttled/throttled/v2" "strings" "time" @@ -42,6 +43,12 @@ func New(client redis.UniversalClient, keyPrefix string) (*GoRedisStore, error) }, nil } +// NewCtx is the version of New that can be used with a context-aware ratelimiter. +func NewCtx(client redis.UniversalClient, keyPrefix string) (throttled.GCRAStoreCtx, error) { + st, err := New(client, keyPrefix) + return throttled.WrapStoreWithContext(st), err +} + // GetWithTime returns the value of the key if it is in the store // or -1 if it does not exist. It also returns the current time at // the redis server to microsecond precision. diff --git a/store/goredisstore/goredisstore_test.go b/store/goredisstore/goredisstore_test.go index be51520..39b06e0 100644 --- a/store/goredisstore/goredisstore_test.go +++ b/store/goredisstore/goredisstore_test.go @@ -31,7 +31,7 @@ func ExampleNew() { }) // Setup store - store, err := goredisstore.New(client, "throttled:") + store, err := goredisstore.NewCtx(client, "throttled:") if err != nil { log.Fatal(err) } @@ -40,7 +40,7 @@ func ExampleNew() { quota := throttled.RateQuota{MaxRate: throttled.PerMin(20), MaxBurst: 5} // Then, use store and quota as arguments for NewGCRARateLimiter() - throttled.NewGCRARateLimiter(store, quota) + throttled.NewGCRARateLimiterCtx(store, quota) } func TestRedisStore(t *testing.T) { @@ -49,8 +49,8 @@ func TestRedisStore(t *testing.T) { defer clearRedis(c) clearRedis(c) - storetest.TestGCRAStore(t, st) - storetest.TestGCRAStoreTTL(t, st) + storetest.TestGCRAStoreCtx(t, st) + storetest.TestGCRAStoreTTLCtx(t, st) } func BenchmarkRedisStore(b *testing.B) { @@ -58,7 +58,7 @@ func BenchmarkRedisStore(b *testing.B) { defer c.Close() defer clearRedis(c) - storetest.BenchmarkGCRAStore(b, st) + storetest.BenchmarkGCRAStoreCtx(b, st) } func clearRedis(c *redis.Client) error { @@ -70,7 +70,7 @@ func clearRedis(c *redis.Client) error { return c.Del(keys...).Err() } -func setupRedis(tb testing.TB, ttl time.Duration) (*redis.Client, *goredisstore.GoRedisStore) { +func setupRedis(tb testing.TB, ttl time.Duration) (*redis.Client, throttled.GCRAStoreCtx) { client := redis.NewClient(&redis.Options{ PoolSize: 10, // default IdleTimeout: 30 * time.Second, @@ -84,7 +84,7 @@ func setupRedis(tb testing.TB, ttl time.Duration) (*redis.Client, *goredisstore. tb.Skip("redis server not available on localhost port 6379") } - st, err := goredisstore.New(client, redisTestPrefix) + st, err := goredisstore.NewCtx(client, redisTestPrefix) if err != nil { client.Close() tb.Fatal(err) diff --git a/store/memstore/memstore.go b/store/memstore/memstore.go index e58e3bd..8b6a4c0 100644 --- a/store/memstore/memstore.go +++ b/store/memstore/memstore.go @@ -2,6 +2,7 @@ package memstore // import "github.com/throttled/throttled/v2/store/memstore" import ( + "github.com/throttled/throttled/v2" "sync" "sync/atomic" "time" @@ -47,6 +48,12 @@ func New(maxKeys int) (*MemStore, error) { return m, nil } +// NewCtx is the version of New that can be used with a context-aware ratelimiter. +func NewCtx(maxKeys int) (throttled.GCRAStoreCtx, error) { + st, err := New(maxKeys) + return throttled.WrapStoreWithContext(st), err +} + // SetTimeNow makes this store use the given function instead of time.Now(). // This is useful for unit tests that use a simulated wallclock. func (ms *MemStore) SetTimeNow(timeNow func() time.Time) { diff --git a/store/memstore/memstore_test.go b/store/memstore/memstore_test.go index 9f988f8..1774822 100644 --- a/store/memstore/memstore_test.go +++ b/store/memstore/memstore_test.go @@ -8,33 +8,33 @@ import ( ) func TestMemStoreLRU(t *testing.T) { - st, err := memstore.New(10) + st, err := memstore.NewCtx(10) if err != nil { t.Fatal(err) } - storetest.TestGCRAStore(t, st) + storetest.TestGCRAStoreCtx(t, st) } func TestMemStoreUnlimited(t *testing.T) { - st, err := memstore.New(10) + st, err := memstore.NewCtx(10) if err != nil { t.Fatal(err) } - storetest.TestGCRAStore(t, st) + storetest.TestGCRAStoreCtx(t, st) } func BenchmarkMemStoreLRU(b *testing.B) { - st, err := memstore.New(10) + st, err := memstore.NewCtx(10) if err != nil { b.Fatal(err) } - storetest.BenchmarkGCRAStore(b, st) + storetest.BenchmarkGCRAStoreCtx(b, st) } func BenchmarkMemStoreUnlimited(b *testing.B) { - st, err := memstore.New(0) + st, err := memstore.NewCtx(0) if err != nil { b.Fatal(err) } - storetest.BenchmarkGCRAStore(b, st) + storetest.BenchmarkGCRAStoreCtx(b, st) } diff --git a/store/redigostore/redigostore.go b/store/redigostore/redigostore.go index 79726ab..1a96f5d 100644 --- a/store/redigostore/redigostore.go +++ b/store/redigostore/redigostore.go @@ -2,6 +2,7 @@ package redigostore // import "github.com/throttled/throttled/v2/store/redigostore" import ( + "github.com/throttled/throttled/v2" "strings" "time" @@ -52,6 +53,12 @@ func New(pool RedigoPool, keyPrefix string, db int) (*RedigoStore, error) { }, nil } +// NewCtx is the version of New that can be used with a context-aware ratelimiter. +func NewCtx(pool *redis.Pool, keyPrefix string, db int) (throttled.GCRAStoreCtx, error) { + st, err := New(pool, keyPrefix, db) + return throttled.WrapStoreWithContext(st), err +} + // GetWithTime returns the value of the key if it is in the store // or -1 if it does not exist. It also returns the current time at // the redis server to microsecond precision. diff --git a/store/redigostore/redigostore_test.go b/store/redigostore/redigostore_test.go index a42da0a..0b623a2 100644 --- a/store/redigostore/redigostore_test.go +++ b/store/redigostore/redigostore_test.go @@ -1,6 +1,7 @@ package redigostore_test import ( + "github.com/throttled/throttled/v2" "testing" "time" @@ -36,8 +37,8 @@ func TestRedisStore(t *testing.T) { defer clearRedis(c) clearRedis(c) - storetest.TestGCRAStore(t, st) - storetest.TestGCRAStoreTTL(t, st) + storetest.TestGCRAStoreCtx(t, st) + storetest.TestGCRAStoreTTLCtx(t, st) } func BenchmarkRedisStore(b *testing.B) { @@ -45,7 +46,7 @@ func BenchmarkRedisStore(b *testing.B) { defer c.Close() defer clearRedis(c) - storetest.BenchmarkGCRAStore(b, st) + storetest.BenchmarkGCRAStoreCtx(b, st) } func clearRedis(c redis.Conn) error { @@ -61,7 +62,7 @@ func clearRedis(c redis.Conn) error { return nil } -func setupRedis(tb testing.TB, ttl time.Duration) (redis.Conn, *redigostore.RedigoStore) { +func setupRedis(tb testing.TB, ttl time.Duration) (redis.Conn, throttled.GCRAStoreCtx) { pool := getPool() c := pool.Get() @@ -75,7 +76,7 @@ func setupRedis(tb testing.TB, ttl time.Duration) (redis.Conn, *redigostore.Redi tb.Fatal(err) } - st, err := redigostore.New(pool, redisTestPrefix, redisTestDB) + st, err := redigostore.NewCtx(pool, redisTestPrefix, redisTestDB) if err != nil { c.Close() tb.Fatal(err) diff --git a/store/storetest/deprecated.go b/store/storetest/deprecated.go new file mode 100644 index 0000000..4a7daf5 --- /dev/null +++ b/store/storetest/deprecated.go @@ -0,0 +1,27 @@ +package storetest + +import ( + "github.com/throttled/throttled/v2" + "testing" +) + +// TestGCRAStore provides an adapter for TestGCRAStoreCtx +// +// Deprecated: implement GCRAStoreCtx and use TestGCRAStoreCtx instead. +func TestGCRAStore(t *testing.T, st throttled.GCRAStore) { + TestGCRAStoreCtx(t, throttled.WrapStoreWithContext(st)) +} + +// TestGCRAStoreTTL provides an adapter for TestGCRAStoreTTLCtx +// +// Deprecated: implement GCRAStoreCtx and use TestGCRAStoreTTLCtx instead. +func TestGCRAStoreTTL(t *testing.T, st throttled.GCRAStore) { + TestGCRAStoreTTLCtx(t, throttled.WrapStoreWithContext(st)) +} + +// BenchmarkGCRAStore provides an adapter for BenchmarkGCRAStoreCtx +// +// Deprecated: implement GCRAStoreCtx and use BenchmarkGCRAStoreCtx instead. +func BenchmarkGCRAStore(b *testing.B, st throttled.GCRAStore) { + BenchmarkGCRAStoreCtx(b, throttled.WrapStoreWithContext(st)) +} diff --git a/store/storetest/storetest.go b/store/storetest/storetest.go index 27696a2..c6cab39 100644 --- a/store/storetest/storetest.go +++ b/store/storetest/storetest.go @@ -2,6 +2,7 @@ package storetest // import "github.com/throttled/throttled/v2/store/storetest" import ( + "context" "math/rand" "strconv" "sync/atomic" @@ -11,12 +12,14 @@ import ( "github.com/throttled/throttled/v2" ) -// TestGCRAStore tests the behavior of a GCRAStore implementation for +// TestGCRAStoreCtx tests the behavior of a GCRAStore implementation for // compliance with the throttled API. It does not require support // for TTLs. -func TestGCRAStore(t *testing.T, st throttled.GCRAStore) { +func TestGCRAStoreCtx(t *testing.T, st throttled.GCRAStoreCtx) { + ctx := context.Background() + // GetWithTime a missing key - if have, _, err := st.GetWithTime("foo"); err != nil { + if have, _, err := st.GetWithTime(ctx, "foo"); err != nil { t.Fatal(err) } else if have != -1 { t.Errorf("expected GetWithTime to return -1 for a missing key but got %d", have) @@ -25,7 +28,7 @@ func TestGCRAStore(t *testing.T, st throttled.GCRAStore) { // SetIfNotExists on a new key want := int64(1) - if set, err := st.SetIfNotExistsWithTTL("foo", want, 0); err != nil { + if set, err := st.SetIfNotExistsWithTTL(ctx, "foo", want, 0); err != nil { t.Fatal(err) } else if !set { t.Errorf("expected SetIfNotExists on an empty key to succeed") @@ -33,7 +36,7 @@ func TestGCRAStore(t *testing.T, st throttled.GCRAStore) { before := time.Now() - if have, now, err := st.GetWithTime("foo"); err != nil { + if have, now, err := st.GetWithTime(ctx, "foo"); err != nil { t.Fatal(err) } else if have != want { t.Errorf("expected GetWithTime to return %d but got %d", want, have) @@ -50,27 +53,27 @@ func TestGCRAStore(t *testing.T, st throttled.GCRAStore) { } // SetIfNotExists on an existing key - if set, err := st.SetIfNotExistsWithTTL("foo", 123, 0); err != nil { + if set, err := st.SetIfNotExistsWithTTL(ctx, "foo", 123, 0); err != nil { t.Fatal(err) } else if set { t.Errorf("expected SetIfNotExists on an existing key to fail") } - if have, _, err := st.GetWithTime("foo"); err != nil { + if have, _, err := st.GetWithTime(ctx, "foo"); err != nil { t.Fatal(err) } else if have != want { t.Errorf("expected GetWithTime to return %d but got %d", want, have) } // SetIfNotExists on a different key - if set, err := st.SetIfNotExistsWithTTL("bar", 456, 0); err != nil { + if set, err := st.SetIfNotExistsWithTTL(ctx, "bar", 456, 0); err != nil { t.Fatal(err) } else if !set { t.Errorf("expected SetIfNotExists on an empty key to succeed") } // Returns the false on a missing key - if swapped, err := st.CompareAndSwapWithTTL("baz", 1, 2, 0); err != nil { + if swapped, err := st.CompareAndSwapWithTTL(ctx, "baz", 1, 2, 0); err != nil { t.Fatal(err) } else if swapped { t.Errorf("expected CompareAndSwap to fail on a missing key") @@ -79,43 +82,44 @@ func TestGCRAStore(t *testing.T, st throttled.GCRAStore) { // Test a successful CAS want = int64(2) - if swapped, err := st.CompareAndSwapWithTTL("foo", 1, want, 0); err != nil { + if swapped, err := st.CompareAndSwapWithTTL(ctx, "foo", 1, want, 0); err != nil { t.Fatal(err) } else if !swapped { t.Errorf("expected CompareAndSwap to succeed") } - if have, _, err := st.GetWithTime("foo"); err != nil { + if have, _, err := st.GetWithTime(ctx, "foo"); err != nil { t.Fatal(err) } else if have != want { t.Errorf("expected GetWithTime to return %d but got %d", want, have) } // Test an unsuccessful CAS - if swapped, err := st.CompareAndSwapWithTTL("foo", 1, 2, 0); err != nil { + if swapped, err := st.CompareAndSwapWithTTL(ctx, "foo", 1, 2, 0); err != nil { t.Fatal(err) } else if swapped { t.Errorf("expected CompareAndSwap to fail") } - if have, _, err := st.GetWithTime("foo"); err != nil { + if have, _, err := st.GetWithTime(ctx, "foo"); err != nil { t.Fatal(err) } else if have != want { t.Errorf("expected GetWithTime to return %d but got %d", want, have) } } -// TestGCRAStoreTTL tests the behavior of TTLs in a GCRAStore implementation. -func TestGCRAStoreTTL(t *testing.T, st throttled.GCRAStore) { +// TestGCRAStoreTTLCtx tests the behavior of TTLs in a GCRAStore implementation. +func TestGCRAStoreTTLCtx(t *testing.T, st throttled.GCRAStoreCtx) { ttl := time.Second want := int64(1) key := "ttl" + ctx := context.Background() - if _, err := st.SetIfNotExistsWithTTL(key, want, ttl); err != nil { + if _, err := st.SetIfNotExistsWithTTL(ctx, key, want, ttl); err != nil { t.Fatal(err) } - if have, _, err := st.GetWithTime(key); err != nil { + if have, _, err := st.GetWithTime(ctx, key); err != nil { t.Fatal(err) } else if have != want { t.Errorf("expected GetWithTime to return %d, got %d", want, have) @@ -124,17 +128,17 @@ func TestGCRAStoreTTL(t *testing.T, st throttled.GCRAStore) { // I can't think of a generic way to test expiration without a sleep time.Sleep(ttl + time.Millisecond) - if have, _, err := st.GetWithTime(key); err != nil { + if have, _, err := st.GetWithTime(ctx, key); err != nil { t.Fatal(err) } else if have != -1 { t.Errorf("expected GetWithTime to fail on an expired key but got %d", have) } } -// BenchmarkGCRAStore runs parallel benchmarks against a GCRAStore implementation. +// BenchmarkGCRAStoreCtx runs parallel benchmarks against a GCRAStore implementation. // Aside from being useful for performance testing, this is useful for finding // race conditions with the Go race detector. -func BenchmarkGCRAStore(b *testing.B, st throttled.GCRAStore) { +func BenchmarkGCRAStoreCtx(b *testing.B, st throttled.GCRAStoreCtx) { seed := int64(42) var attempts, updates int64 @@ -145,21 +149,22 @@ func BenchmarkGCRAStore(b *testing.B, st throttled.GCRAStore) { gen := rand.New(rand.NewSource(seedValue)) for pb.Next() { + ctx := context.Background() key := strconv.FormatInt(gen.Int63n(50), 10) var v int64 var updated bool - v, _, err := st.GetWithTime(key) + v, _, err := st.GetWithTime(ctx, key) if v == -1 { - updated, err = st.SetIfNotExistsWithTTL(key, gen.Int63(), 0) + updated, err = st.SetIfNotExistsWithTTL(ctx, key, gen.Int63(), 0) if err != nil { b.Error(err) } } else if err != nil { b.Error(err) } else { - updated, err = st.CompareAndSwapWithTTL(key, v, gen.Int63(), 0) + updated, err = st.CompareAndSwapWithTTL(ctx, key, v, gen.Int63(), 0) if err != nil { b.Error(err) }