diff --git a/.github/.codecov.yml b/.github/.codecov.yml index 5e90d523..191a126b 100644 --- a/.github/.codecov.yml +++ b/.github/.codecov.yml @@ -15,4 +15,7 @@ coverage: status: project: default: - target: 89% \ No newline at end of file + target: 89% + patch: + default: + target: 90% \ No newline at end of file diff --git a/revocation/internal/ocsp/ocsp.go b/revocation/internal/ocsp/ocsp.go index 25410ed3..80174cf7 100644 --- a/revocation/internal/ocsp/ocsp.go +++ b/revocation/internal/ocsp/ocsp.go @@ -17,6 +17,7 @@ package ocsp import ( "bytes" + "context" "crypto" "crypto/x509" "crypto/x509/pkix" @@ -52,7 +53,7 @@ const ( ) // CertCheckStatus checks the revocation status of a certificate using OCSP -func CertCheckStatus(cert, issuer *x509.Certificate, opts CertCheckStatusOptions) *result.CertRevocationResult { +func CertCheckStatus(ctx context.Context, cert, issuer *x509.Certificate, opts CertCheckStatusOptions) *result.CertRevocationResult { if !Supported(cert) { // OCSP not enabled for this certificate. return &result.CertRevocationResult{ @@ -65,7 +66,7 @@ func CertCheckStatus(cert, issuer *x509.Certificate, opts CertCheckStatusOptions serverResults := make([]*result.ServerResult, len(ocspURLs)) for serverIndex, server := range ocspURLs { - serverResult := checkStatusFromServer(cert, issuer, server, opts) + serverResult := checkStatusFromServer(ctx, cert, issuer, server, opts) if serverResult.Result == result.ResultOK || serverResult.Result == result.ResultRevoked || (serverResult.Result == result.ResultUnknown && errors.Is(serverResult.Error, UnknownStatusError{})) { @@ -84,7 +85,7 @@ func Supported(cert *x509.Certificate) bool { return cert != nil && len(cert.OCSPServer) > 0 } -func checkStatusFromServer(cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) *result.ServerResult { +func checkStatusFromServer(ctx context.Context, cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) *result.ServerResult { // Check valid server if serverURL, err := url.Parse(server); err != nil || !strings.EqualFold(serverURL.Scheme, "http") { // This function is only able to check servers that are accessible via HTTP @@ -92,7 +93,7 @@ func checkStatusFromServer(cert, issuer *x509.Certificate, server string, opts C } // Create OCSP Request - resp, err := executeOCSPCheck(cert, issuer, server, opts) + resp, err := executeOCSPCheck(ctx, cert, issuer, server, opts) if err != nil { // If there is a server error, attempt all servers before determining what to return // to the user @@ -142,7 +143,7 @@ func extensionsToMap(extensions []pkix.Extension) map[string][]byte { return extensionMap } -func executeOCSPCheck(cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) (*ocsp.Response, error) { +func executeOCSPCheck(ctx context.Context, cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) (*ocsp.Response, error) { // TODO: Look into other alternatives for specifying the Hash // https://github.com/notaryproject/notation-core-go/issues/139 // The following do not support SHA256 hashes: @@ -168,18 +169,25 @@ func executeOCSPCheck(cert, issuer *x509.Certificate, server string, opts CertCh if err != nil { return nil, GenericError{Err: err} } - resp, err = opts.HTTPClient.Get(reqURL) + var httpReq *http.Request + httpReq, err = http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return nil, err + } + resp, err = opts.HTTPClient.Do(httpReq) } else { - resp, err = postRequest(ocspRequest, server, opts.HTTPClient) + resp, err = postRequest(ctx, ocspRequest, server, opts.HTTPClient) } } else { - resp, err = postRequest(ocspRequest, server, opts.HTTPClient) + resp, err = postRequest(ctx, ocspRequest, server, opts.HTTPClient) } if err != nil { var urlErr *url.Error if errors.As(err, &urlErr) && urlErr.Timeout() { - return nil, TimeoutError{} + return nil, TimeoutError{ + timeout: opts.HTTPClient.Timeout, + } } return nil, GenericError{Err: err} } @@ -210,9 +218,13 @@ func executeOCSPCheck(cert, issuer *x509.Certificate, server string, opts CertCh return ocsp.ParseResponseForCert(body, cert, issuer) } -func postRequest(req []byte, server string, httpClient *http.Client) (*http.Response, error) { - reader := bytes.NewReader(req) - return httpClient.Post(server, "application/ocsp-request", reader) +func postRequest(ctx context.Context, req []byte, server string, httpClient *http.Client) (*http.Response, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, server, bytes.NewReader(req)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Content-Type", "application/ocsp-request") + return httpClient.Do(httpReq) } func toServerResult(server string, err error) *result.ServerResult { diff --git a/revocation/internal/ocsp/ocsp_test.go b/revocation/internal/ocsp/ocsp_test.go index f1f0f118..a10902a7 100644 --- a/revocation/internal/ocsp/ocsp_test.go +++ b/revocation/internal/ocsp/ocsp_test.go @@ -14,9 +14,11 @@ package ocsp import ( + "context" "crypto/x509" "fmt" "net/http" + "net/url" "strings" "testing" "time" @@ -75,6 +77,7 @@ func TestCheckStatus(t *testing.T) { ocspServer := revokableCertTuple.Cert.OCSPServer[0] revokableChain := []*x509.Certificate{revokableCertTuple.Cert, revokableIssuerTuple.Cert} testChain := []testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple} + ctx := context.Background() t.Run("check non-revoked cert", func(t *testing.T) { client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good}, nil, true) @@ -83,7 +86,7 @@ func TestCheckStatus(t *testing.T) { HTTPClient: client, } - certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts) + certResult := CertCheckStatus(ctx, revokableChain[0], revokableChain[1], opts) expectedCertResults := []*result.CertRevocationResult{getOKCertResult(ocspServer)} validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) }) @@ -94,7 +97,7 @@ func TestCheckStatus(t *testing.T) { HTTPClient: client, } - certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts) + certResult := CertCheckStatus(ctx, revokableChain[0], revokableChain[1], opts) expectedCertResults := []*result.CertRevocationResult{{ Result: result.ResultUnknown, ServerResults: []*result.ServerResult{ @@ -110,7 +113,7 @@ func TestCheckStatus(t *testing.T) { HTTPClient: client, } - certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts) + certResult := CertCheckStatus(ctx, revokableChain[0], revokableChain[1], opts) expectedCertResults := []*result.CertRevocationResult{{ Result: result.ResultRevoked, ServerResults: []*result.ServerResult{ @@ -127,13 +130,13 @@ func TestCheckStatus(t *testing.T) { HTTPClient: client, } - certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts) + certResult := CertCheckStatus(ctx, revokableChain[0], revokableChain[1], opts) expectedCertResults := []*result.CertRevocationResult{getOKCertResult(ocspServer)} validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t) }) t.Run("certificate doesn't support OCSP", func(t *testing.T) { - ocspResult := CertCheckStatus(&x509.Certificate{}, revokableIssuerTuple.Cert, CertCheckStatusOptions{}) + ocspResult := CertCheckStatus(ctx, &x509.Certificate{}, revokableIssuerTuple.Cert, CertCheckStatusOptions{}) expectedResult := &result.CertRevocationResult{ Result: result.ResultNonRevokable, ServerResults: []*result.ServerResult{toServerResult("", NoServerError{})}, @@ -146,10 +149,11 @@ func TestCheckStatus(t *testing.T) { func TestCheckStatusFromServer(t *testing.T) { revokableCertTuple := testhelper.GetRevokableRSALeafCertificate() revokableIssuerTuple := testhelper.GetRSARootCertificate() + ctx := context.Background() t.Run("server url is not http", func(t *testing.T) { server := "https://example.com" - serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{}) + serverResult := checkStatusFromServer(ctx, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{}) expectedResult := toServerResult(server, GenericError{Err: fmt.Errorf("OCSPServer protocol %s is not supported", "https")}) if serverResult.Result != expectedResult.Result { t.Errorf("Expected Result to be %s, but got %s", expectedResult.Result, serverResult.Result) @@ -166,7 +170,7 @@ func TestCheckStatusFromServer(t *testing.T) { t.Run("request error", func(t *testing.T) { server := "http://example.com" - serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{ + serverResult := checkStatusFromServer(ctx, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{ HTTPClient: &http.Client{ Transport: &failedTransport{}, }, @@ -180,7 +184,7 @@ func TestCheckStatusFromServer(t *testing.T) { t.Run("ocsp expired", func(t *testing.T) { client := testhelper.MockClient([]testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple}, []ocsp.ResponseStatus{ocsp.Good}, nil, true) server := "http://example.com/expired_ocsp" - serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{ + serverResult := checkStatusFromServer(ctx, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{ HTTPClient: client, }) errorMessage := "expired OCSP response" @@ -188,21 +192,77 @@ func TestCheckStatusFromServer(t *testing.T) { t.Errorf("Expected Error to contain %v, but got %v", errorMessage, serverResult.Error) } }) + + t.Run("ocsp request roundtrip failed", func(t *testing.T) { + client := testhelper.MockClient([]testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple}, []ocsp.ResponseStatus{ocsp.Good}, nil, true) + server := "http://example.com" + serverResult := checkStatusFromServer(nil, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{ + HTTPClient: client, + }) + errorMessage := "net/http: nil Context" + if !strings.Contains(serverResult.Error.Error(), errorMessage) { + t.Errorf("Expected Error to contain %v, but got %v", errorMessage, serverResult.Error) + } + }) + + t.Run("ocsp request roundtrip timeout", func(t *testing.T) { + server := "http://example.com" + serverResult := checkStatusFromServer(ctx, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{ + HTTPClient: &http.Client{ + Timeout: 1 * time.Second, + Transport: &failedTransport{ + timeout: true, + }, + }, + }) + errorMessage := "exceeded timeout threshold of 1.00 seconds for OCSP check" + if !strings.Contains(serverResult.Error.Error(), errorMessage) { + t.Errorf("Expected Error to contain %v, but got %v", errorMessage, serverResult.Error) + } + }) } func TestPostRequest(t *testing.T) { + t.Run("failed to generate request", func(t *testing.T) { + _, err := postRequest(nil, nil, "http://example.com", &http.Client{ + Transport: &failedTransport{}, + }) + expectedErrMsg := "net/http: nil Context" + if err == nil || err.Error() != expectedErrMsg { + t.Errorf("Expected error %s, but got %s", expectedErrMsg, err) + } + }) + t.Run("failed to execute request", func(t *testing.T) { - _, err := postRequest(nil, "http://example.com", &http.Client{ + _, err := postRequest(context.Background(), nil, "http://example.com", &http.Client{ Transport: &failedTransport{}, }) - if err == nil { - t.Errorf("Expected error, but got nil") + expectedErrMsg := "Post \"http://example.com\": failed to execute request" + if err == nil || err.Error() != expectedErrMsg { + t.Errorf("Expected error %s, but got %s", expectedErrMsg, err) } }) } -type failedTransport struct{} +type testTimeoutError struct{} + +func (e testTimeoutError) Error() string { + return "test timeout" +} + +func (e testTimeoutError) Timeout() bool { + return true +} + +type failedTransport struct { + timeout bool +} func (f *failedTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if f.timeout { + return nil, &url.Error{ + Err: testTimeoutError{}, + } + } return nil, fmt.Errorf("failed to execute request") } diff --git a/revocation/ocsp/ocsp.go b/revocation/ocsp/ocsp.go index c2274f75..2a524b5c 100644 --- a/revocation/ocsp/ocsp.go +++ b/revocation/ocsp/ocsp.go @@ -16,6 +16,7 @@ package ocsp import ( + "context" "crypto/x509" "errors" "net/http" @@ -61,12 +62,13 @@ func CheckStatus(opts Options) ([]*result.CertRevocationResult, error) { // Check status for each cert in cert chain var wg sync.WaitGroup + ctx := context.Background() for i, cert := range opts.CertChain[:len(opts.CertChain)-1] { wg.Add(1) // Assume cert chain is accurate and next cert in chain is the issuer go func(i int, cert *x509.Certificate) { defer wg.Done() - certResults[i] = ocsp.CertCheckStatus(cert, opts.CertChain[i+1], certCheckStatusOptions) + certResults[i] = ocsp.CertCheckStatus(ctx, cert, opts.CertChain[i+1], certCheckStatusOptions) }(i, cert) } // Last is root cert, which will never be revoked by OCSP diff --git a/revocation/revocation.go b/revocation/revocation.go index f249d915..69103735 100644 --- a/revocation/revocation.go +++ b/revocation/revocation.go @@ -209,7 +209,7 @@ func (r *revocation) ValidateContext(ctx context.Context, validateContextOpts Va } }() - ocspResult := ocsp.CertCheckStatus(cert, certChain[i+1], ocspOpts) + ocspResult := ocsp.CertCheckStatus(ctx, cert, certChain[i+1], ocspOpts) if ocspResult != nil && ocspResult.Result == result.ResultUnknown && crl.Supported(cert) { // try CRL check if OCSP serverResult is unknown serverResult := crl.CertCheckStatus(ctx, cert, certChain[i+1], crlOpts)