Skip to content

Commit

Permalink
fix: ocsp revocation check with context (#235)
Browse files Browse the repository at this point in the history
This PR adds `context` to OCSP revocation check. Resolves #223

Signed-off-by: Patrick Zheng <[email protected]>
  • Loading branch information
Two-Hearts authored Oct 11, 2024
1 parent e90546b commit 3067ab1
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 27 deletions.
5 changes: 4 additions & 1 deletion .github/.codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ coverage:
status:
project:
default:
target: 89%
target: 89%
patch:
default:
target: 90%
36 changes: 24 additions & 12 deletions revocation/internal/ocsp/ocsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package ocsp

import (
"bytes"
"context"
"crypto"
"crypto/x509"
"crypto/x509/pkix"
Expand Down Expand Up @@ -52,7 +53,7 @@ const (
)

// CertCheckStatus checks the revocation status of a certificate using OCSP
func CertCheckStatus(cert, issuer *x509.Certificate, opts CertCheckStatusOptions) *result.CertRevocationResult {
func CertCheckStatus(ctx context.Context, cert, issuer *x509.Certificate, opts CertCheckStatusOptions) *result.CertRevocationResult {
if !Supported(cert) {
// OCSP not enabled for this certificate.
return &result.CertRevocationResult{
Expand All @@ -65,7 +66,7 @@ func CertCheckStatus(cert, issuer *x509.Certificate, opts CertCheckStatusOptions

serverResults := make([]*result.ServerResult, len(ocspURLs))
for serverIndex, server := range ocspURLs {
serverResult := checkStatusFromServer(cert, issuer, server, opts)
serverResult := checkStatusFromServer(ctx, cert, issuer, server, opts)
if serverResult.Result == result.ResultOK ||
serverResult.Result == result.ResultRevoked ||
(serverResult.Result == result.ResultUnknown && errors.Is(serverResult.Error, UnknownStatusError{})) {
Expand All @@ -84,15 +85,15 @@ func Supported(cert *x509.Certificate) bool {
return cert != nil && len(cert.OCSPServer) > 0
}

func checkStatusFromServer(cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) *result.ServerResult {
func checkStatusFromServer(ctx context.Context, cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) *result.ServerResult {
// Check valid server
if serverURL, err := url.Parse(server); err != nil || !strings.EqualFold(serverURL.Scheme, "http") {
// This function is only able to check servers that are accessible via HTTP
return toServerResult(server, GenericError{Err: fmt.Errorf("OCSPServer protocol %s is not supported", serverURL.Scheme)})
}

// Create OCSP Request
resp, err := executeOCSPCheck(cert, issuer, server, opts)
resp, err := executeOCSPCheck(ctx, cert, issuer, server, opts)
if err != nil {
// If there is a server error, attempt all servers before determining what to return
// to the user
Expand Down Expand Up @@ -142,7 +143,7 @@ func extensionsToMap(extensions []pkix.Extension) map[string][]byte {
return extensionMap
}

func executeOCSPCheck(cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) (*ocsp.Response, error) {
func executeOCSPCheck(ctx context.Context, cert, issuer *x509.Certificate, server string, opts CertCheckStatusOptions) (*ocsp.Response, error) {
// TODO: Look into other alternatives for specifying the Hash
// https://github.com/notaryproject/notation-core-go/issues/139
// The following do not support SHA256 hashes:
Expand All @@ -168,18 +169,25 @@ func executeOCSPCheck(cert, issuer *x509.Certificate, server string, opts CertCh
if err != nil {
return nil, GenericError{Err: err}
}
resp, err = opts.HTTPClient.Get(reqURL)
var httpReq *http.Request
httpReq, err = http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
if err != nil {
return nil, err
}
resp, err = opts.HTTPClient.Do(httpReq)
} else {
resp, err = postRequest(ocspRequest, server, opts.HTTPClient)
resp, err = postRequest(ctx, ocspRequest, server, opts.HTTPClient)
}
} else {
resp, err = postRequest(ocspRequest, server, opts.HTTPClient)
resp, err = postRequest(ctx, ocspRequest, server, opts.HTTPClient)
}

if err != nil {
var urlErr *url.Error
if errors.As(err, &urlErr) && urlErr.Timeout() {
return nil, TimeoutError{}
return nil, TimeoutError{
timeout: opts.HTTPClient.Timeout,
}
}
return nil, GenericError{Err: err}
}
Expand Down Expand Up @@ -210,9 +218,13 @@ func executeOCSPCheck(cert, issuer *x509.Certificate, server string, opts CertCh
return ocsp.ParseResponseForCert(body, cert, issuer)
}

func postRequest(req []byte, server string, httpClient *http.Client) (*http.Response, error) {
reader := bytes.NewReader(req)
return httpClient.Post(server, "application/ocsp-request", reader)
func postRequest(ctx context.Context, req []byte, server string, httpClient *http.Client) (*http.Response, error) {
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, server, bytes.NewReader(req))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/ocsp-request")
return httpClient.Do(httpReq)
}

func toServerResult(server string, err error) *result.ServerResult {
Expand Down
84 changes: 72 additions & 12 deletions revocation/internal/ocsp/ocsp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
package ocsp

import (
"context"
"crypto/x509"
"fmt"
"net/http"
"net/url"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -75,6 +77,7 @@ func TestCheckStatus(t *testing.T) {
ocspServer := revokableCertTuple.Cert.OCSPServer[0]
revokableChain := []*x509.Certificate{revokableCertTuple.Cert, revokableIssuerTuple.Cert}
testChain := []testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple}
ctx := context.Background()

t.Run("check non-revoked cert", func(t *testing.T) {
client := testhelper.MockClient(testChain, []ocsp.ResponseStatus{ocsp.Good}, nil, true)
Expand All @@ -83,7 +86,7 @@ func TestCheckStatus(t *testing.T) {
HTTPClient: client,
}

certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts)
certResult := CertCheckStatus(ctx, revokableChain[0], revokableChain[1], opts)
expectedCertResults := []*result.CertRevocationResult{getOKCertResult(ocspServer)}
validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t)
})
Expand All @@ -94,7 +97,7 @@ func TestCheckStatus(t *testing.T) {
HTTPClient: client,
}

certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts)
certResult := CertCheckStatus(ctx, revokableChain[0], revokableChain[1], opts)
expectedCertResults := []*result.CertRevocationResult{{
Result: result.ResultUnknown,
ServerResults: []*result.ServerResult{
Expand All @@ -110,7 +113,7 @@ func TestCheckStatus(t *testing.T) {
HTTPClient: client,
}

certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts)
certResult := CertCheckStatus(ctx, revokableChain[0], revokableChain[1], opts)
expectedCertResults := []*result.CertRevocationResult{{
Result: result.ResultRevoked,
ServerResults: []*result.ServerResult{
Expand All @@ -127,13 +130,13 @@ func TestCheckStatus(t *testing.T) {
HTTPClient: client,
}

certResult := CertCheckStatus(revokableChain[0], revokableChain[1], opts)
certResult := CertCheckStatus(ctx, revokableChain[0], revokableChain[1], opts)
expectedCertResults := []*result.CertRevocationResult{getOKCertResult(ocspServer)}
validateEquivalentCertResults([]*result.CertRevocationResult{certResult}, expectedCertResults, t)
})

t.Run("certificate doesn't support OCSP", func(t *testing.T) {
ocspResult := CertCheckStatus(&x509.Certificate{}, revokableIssuerTuple.Cert, CertCheckStatusOptions{})
ocspResult := CertCheckStatus(ctx, &x509.Certificate{}, revokableIssuerTuple.Cert, CertCheckStatusOptions{})
expectedResult := &result.CertRevocationResult{
Result: result.ResultNonRevokable,
ServerResults: []*result.ServerResult{toServerResult("", NoServerError{})},
Expand All @@ -146,10 +149,11 @@ func TestCheckStatus(t *testing.T) {
func TestCheckStatusFromServer(t *testing.T) {
revokableCertTuple := testhelper.GetRevokableRSALeafCertificate()
revokableIssuerTuple := testhelper.GetRSARootCertificate()
ctx := context.Background()

t.Run("server url is not http", func(t *testing.T) {
server := "https://example.com"
serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{})
serverResult := checkStatusFromServer(ctx, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{})
expectedResult := toServerResult(server, GenericError{Err: fmt.Errorf("OCSPServer protocol %s is not supported", "https")})
if serverResult.Result != expectedResult.Result {
t.Errorf("Expected Result to be %s, but got %s", expectedResult.Result, serverResult.Result)
Expand All @@ -166,7 +170,7 @@ func TestCheckStatusFromServer(t *testing.T) {

t.Run("request error", func(t *testing.T) {
server := "http://example.com"
serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{
serverResult := checkStatusFromServer(ctx, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{
HTTPClient: &http.Client{
Transport: &failedTransport{},
},
Expand All @@ -180,29 +184,85 @@ func TestCheckStatusFromServer(t *testing.T) {
t.Run("ocsp expired", func(t *testing.T) {
client := testhelper.MockClient([]testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple}, []ocsp.ResponseStatus{ocsp.Good}, nil, true)
server := "http://example.com/expired_ocsp"
serverResult := checkStatusFromServer(revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{
serverResult := checkStatusFromServer(ctx, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{
HTTPClient: client,
})
errorMessage := "expired OCSP response"
if !strings.Contains(serverResult.Error.Error(), errorMessage) {
t.Errorf("Expected Error to contain %v, but got %v", errorMessage, serverResult.Error)
}
})

t.Run("ocsp request roundtrip failed", func(t *testing.T) {
client := testhelper.MockClient([]testhelper.RSACertTuple{revokableCertTuple, revokableIssuerTuple}, []ocsp.ResponseStatus{ocsp.Good}, nil, true)
server := "http://example.com"
serverResult := checkStatusFromServer(nil, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{
HTTPClient: client,
})
errorMessage := "net/http: nil Context"
if !strings.Contains(serverResult.Error.Error(), errorMessage) {
t.Errorf("Expected Error to contain %v, but got %v", errorMessage, serverResult.Error)
}
})

t.Run("ocsp request roundtrip timeout", func(t *testing.T) {
server := "http://example.com"
serverResult := checkStatusFromServer(ctx, revokableCertTuple.Cert, revokableIssuerTuple.Cert, server, CertCheckStatusOptions{
HTTPClient: &http.Client{
Timeout: 1 * time.Second,
Transport: &failedTransport{
timeout: true,
},
},
})
errorMessage := "exceeded timeout threshold of 1.00 seconds for OCSP check"
if !strings.Contains(serverResult.Error.Error(), errorMessage) {
t.Errorf("Expected Error to contain %v, but got %v", errorMessage, serverResult.Error)
}
})
}

func TestPostRequest(t *testing.T) {
t.Run("failed to generate request", func(t *testing.T) {
_, err := postRequest(nil, nil, "http://example.com", &http.Client{
Transport: &failedTransport{},
})
expectedErrMsg := "net/http: nil Context"
if err == nil || err.Error() != expectedErrMsg {
t.Errorf("Expected error %s, but got %s", expectedErrMsg, err)
}
})

t.Run("failed to execute request", func(t *testing.T) {
_, err := postRequest(nil, "http://example.com", &http.Client{
_, err := postRequest(context.Background(), nil, "http://example.com", &http.Client{
Transport: &failedTransport{},
})
if err == nil {
t.Errorf("Expected error, but got nil")
expectedErrMsg := "Post \"http://example.com\": failed to execute request"
if err == nil || err.Error() != expectedErrMsg {
t.Errorf("Expected error %s, but got %s", expectedErrMsg, err)
}
})
}

type failedTransport struct{}
type testTimeoutError struct{}

func (e testTimeoutError) Error() string {
return "test timeout"
}

func (e testTimeoutError) Timeout() bool {
return true
}

type failedTransport struct {
timeout bool
}

func (f *failedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if f.timeout {
return nil, &url.Error{
Err: testTimeoutError{},
}
}
return nil, fmt.Errorf("failed to execute request")
}
4 changes: 3 additions & 1 deletion revocation/ocsp/ocsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package ocsp

import (
"context"
"crypto/x509"
"errors"
"net/http"
Expand Down Expand Up @@ -61,12 +62,13 @@ func CheckStatus(opts Options) ([]*result.CertRevocationResult, error) {

// Check status for each cert in cert chain
var wg sync.WaitGroup
ctx := context.Background()
for i, cert := range opts.CertChain[:len(opts.CertChain)-1] {
wg.Add(1)
// Assume cert chain is accurate and next cert in chain is the issuer
go func(i int, cert *x509.Certificate) {
defer wg.Done()
certResults[i] = ocsp.CertCheckStatus(cert, opts.CertChain[i+1], certCheckStatusOptions)
certResults[i] = ocsp.CertCheckStatus(ctx, cert, opts.CertChain[i+1], certCheckStatusOptions)
}(i, cert)
}
// Last is root cert, which will never be revoked by OCSP
Expand Down
2 changes: 1 addition & 1 deletion revocation/revocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func (r *revocation) ValidateContext(ctx context.Context, validateContextOpts Va
}
}()

ocspResult := ocsp.CertCheckStatus(cert, certChain[i+1], ocspOpts)
ocspResult := ocsp.CertCheckStatus(ctx, cert, certChain[i+1], ocspOpts)
if ocspResult != nil && ocspResult.Result == result.ResultUnknown && crl.Supported(cert) {
// try CRL check if OCSP serverResult is unknown
serverResult := crl.CertCheckStatus(ctx, cert, certChain[i+1], crlOpts)
Expand Down

0 comments on commit 3067ab1

Please sign in to comment.