diff --git a/client.go b/client.go index 299cce4..c6508e4 100644 --- a/client.go +++ b/client.go @@ -32,8 +32,7 @@ const ( // Client - base api client type Client struct { - clientMu sync.Mutex // clientMu protects the client during calls that modify the CheckRedirect func. - client *http.Client // HTTP client used to communicate with the API. + client *http.Client // HTTP client used to communicate with the API. apiBase *url.URL // apiBase the base used when communicating with the API. apiVersion string // apiVersion the version used when communicating with the API. @@ -181,6 +180,15 @@ func (c *Client) newRequest(method, path string, body interface{}) (*http.Reques func (c *Client) do(ctx context.Context, req *http.Request, v interface{}) (*Response, error) { req = req.WithContext(ctx) + + // If we've hit rate limit, don't make further requests before Reset time. + if err := c.checkRateLimitBeforeDo(req); err != nil { + return &Response{ + Response: err.Response, + Rate: err.Rate, + }, err + } + resp, err := c.client.Do(req) if err != nil { select { @@ -193,6 +201,10 @@ func (c *Client) do(ctx context.Context, req *http.Request, v interface{}) (*Res response := newResponse(resp) + c.rateMu.Lock() + c.rateLimits = response.Rate + c.rateMu.Unlock() + err = checkResponse(resp) if err != nil { defer resp.Body.Close() @@ -209,6 +221,33 @@ func (c *Client) do(ctx context.Context, req *http.Request, v interface{}) (*Res return response, err } +// checkRateLimitBeforeDo does not make any network calls, but uses existing knowledge from +// current client state in order to quickly check if *RateLimitError can be immediately returned +// from Client.do, and if so, returns it so that Client.do can skip making a network API call unnecessarily. +// Otherwise it returns nil, and Client.do should proceed normally. +func (c *Client) checkRateLimitBeforeDo(req *http.Request) *RateLimitError { + c.rateMu.Lock() + rate := c.rateLimits + c.rateMu.Unlock() + if rate.Remaining == 0 && rate.RetryAfter != nil && time.Now().Before(time.Now().Add(*rate.RetryAfter)) { + // Create a fake response. + resp := &http.Response{ + Status: http.StatusText(http.StatusForbidden), + StatusCode: http.StatusForbidden, + Request: req, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("")), + } + return &RateLimitError{ + Rate: rate, + Response: resp, + Message: fmt.Sprintf("API rate limit of %v still exceeded until %v, not making remote request.", rate.Limit, rate.RetryAfter), + } + } + + return nil +} + // newResponse creates a new Response for the provided http.Response. func newResponse(r *http.Response) *Response { response := &Response{Response: r} diff --git a/client_test.go b/client_test.go index 25284ad..2b18595 100644 --- a/client_test.go +++ b/client_test.go @@ -345,3 +345,51 @@ func TestWillHandleAPIRateError(t *testing.T) { assert.Equal(t, "GET https://connect.mailerlite.com/api/subscribers: 429 Too Many Attempts. [retry after 59s]", err.Error()) } + +func TestWillHandleAPIRateErrorAndNoRemoteCall(t *testing.T) { + client := mailerlite.NewClient(testKey) + + header := http.Header{} + header.Set(mailerlite.HeaderRateLimit, "120") + header.Set(mailerlite.HeaderRateRemaining, "0") + header.Set(mailerlite.HeaderRateRetryAfter, "59") + + testClient := NewTestClient(func(req *http.Request) *http.Response { + res := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Request: req, + Header: header, + Body: io.NopCloser(strings.NewReader(`{"message": "Too Many Attempts."}`)), + } + + return res + }) + + ctx := context.TODO() + + client.SetHttpClient(testClient) + + listOptions := &mailerlite.ListSubscriberOptions{} + + _, res, err := client.Subscriber.List(ctx, listOptions) + + assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) + assert.IsType(t, &mailerlite.RateLimitError{}, err) + + retryAfter := time.Duration(59) * time.Second + + if err, ok := err.(*mailerlite.RateLimitError); ok { + assert.Equal(t, "Too Many Attempts.", err.Message) + assert.Equal(t, 0, err.Rate.Remaining) + assert.Equal(t, &retryAfter, err.Rate.RetryAfter) + } + + assert.Equal(t, "GET https://connect.mailerlite.com/api/subscribers: 429 Too Many Attempts. [retry after 59s]", err.Error()) + + _, res, err = client.Subscriber.List(ctx, listOptions) + assert.Equal(t, http.StatusForbidden, res.StatusCode) + assert.IsType(t, &mailerlite.RateLimitError{}, err) + + assert.Equal(t, "GET https://connect.mailerlite.com/api/subscribers: 403 API rate limit of 120 still exceeded until 59s, not making remote request. [retry after 59s]", err.Error()) + +}