diff --git a/retry.go b/retry.go index 9273f30..299762a 100644 --- a/retry.go +++ b/retry.go @@ -1,38 +1,18 @@ package transport import ( - "crypto/x509" "net/http" - "net/url" - "regexp" "log" "time" "strconv" "math" ) -var ( - // A regular expression to match the error returned by net/http when the - // configured number of redirects is exhausted. This error isn't typed - // specifically so we resort to matching on the error string. - tooManyRedirectsRe = regexp.MustCompile(`stopped after \d+ redirects\z`) - - // A regular expression to match the error returned by net/http when the - // scheme specified in the URL is invalid. This error isn't typed - // specifically so we resort to matching on the error string. - invalidSchemeRe = regexp.MustCompile(`unsupported protocol scheme`) - - // A regular expression to match the error returned by net/http when the - // TLS certificate is not trusted. This error isn't typed - // specifically so we resort to matching on the error string. - untrustedCertificateRe = regexp.MustCompile(`certificate is not trusted`) -) - func Retry(baseTransport http.RoundTripper, maxRetries int) func(http.RoundTripper) http.RoundTripper { return func(next http.RoundTripper) http.RoundTripper { return RoundTripFunc(func(req *http.Request) (resp *http.Response, err error) { defer func() { - if isRetryable(err, resp) { + if isRetryable(resp) { for i := 1; i <= maxRetries; i++ { wait := backOff(resp, i) @@ -50,13 +30,12 @@ func Retry(baseTransport http.RoundTripper, maxRetries int) func(http.RoundTripp startTime := time.Now() resp, err = baseTransport.RoundTrip(req) - if isRetryable(err, resp) { - log.Printf("retrying %d request: %s %s", i, req.Method, req.URL) - log.Printf("response (%v): %v %s", time.Since(startTime), resp.Status, resp.Request.URL) - continue - } else { + if !isRetryable(resp) { break } + + log.Printf("retrying %d request: %s %s", i, req.Method, req.URL) + log.Printf("response (%v): %v %s", time.Since(startTime), resp.Status, resp.Request.URL) } } }() @@ -87,33 +66,11 @@ func backOff(resp *http.Response, attempt int) time.Duration { return sleep } -func isRetryable(err error, resp *http.Response) bool { +func isRetryable(resp *http.Response) bool { if resp == nil { return false } - // any error returned from Client.Do will be *url.Error - if serverErr, ok := err.(*url.Error); ok { - // Too many redirects. - if tooManyRedirectsRe.MatchString(serverErr.Error()) { - return false - } - - // Invalid protocol scheme. - if invalidSchemeRe.MatchString(serverErr.Error()) { - return false - } - - // TLS cert verification failure. - if untrustedCertificateRe.MatchString(serverErr.Error()) { - return false - } - - if _, ok := serverErr.Err.(x509.UnknownAuthorityError); ok { - return false - } - } - // 429 Too Many Requests is recoverable. Sometimes the server puts // Retry-After response header to indicate when the server is will be available again if resp.StatusCode == http.StatusTooManyRequests { diff --git a/transport.go b/transport.go index 0a9ed02..299f9a0 100644 --- a/transport.go +++ b/transport.go @@ -26,10 +26,10 @@ func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { // authClient := http.Client{ // Transport: transport.Chain( // http.DefaultTransport, -// transport.SetHeader("User-Agent", userAgent), -// transport.SetHeader("Authorization", authHeader), -// transport.SetHeader("x-extra", "value"), -// transport.TraceID, +// transport.SetHeader("User-Agent", userAgent), +// transport.SetHeader("Authorization", authHeader), +// transport.SetHeader("x-extra", "value"), +// transport.TraceID, // ), // Timeout: 15 * time.Second, // }