Skip to content

Commit

Permalink
attestedtls: improved error handling and log messages
Browse files Browse the repository at this point in the history
Signed-off-by: Simon Ott <[email protected]>
  • Loading branch information
smo4201 committed Oct 16, 2024
1 parent dc498c6 commit 837ceac
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 21 deletions.
23 changes: 21 additions & 2 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,25 @@ const (
TypeTLSCert uint32 = 5
)

func TypeToString(t uint32) string {
switch t {
case TypeError:
return "Error"
case TypeAttest:
return "Attest"
case TypeVerify:
return "Verify"
case TypeMeasure:
return "Measure"
case TypeTLSSign:
return "TLSSign"
case TypeTLSCert:
return "TLSCert"
default:
return "Unknown"
}
}

// Converts Protobuf hashtype to crypto.SignerOpts
func HashToSignerOpts(hashtype HashFunction, pssOpts *PSSOptions) (crypto.SignerOpts, error) {
var hash crypto.Hash
Expand Down Expand Up @@ -236,7 +255,7 @@ func Receive(conn net.Conn) ([]byte, uint32, error) {
payloadLen, MaxMsgLen)
}

log.Tracef("Decoded header. Expecting type %v, length %v", msgType, payloadLen)
log.Tracef("Decoded header. Type %v, length %v", TypeToString(msgType), payloadLen)

// Read payload
payload := bytes.NewBuffer(nil)
Expand Down Expand Up @@ -297,7 +316,7 @@ func Send(conn net.Conn, payload []byte, t uint32) error {
return fmt.Errorf("could only send %v of %v bytes", n, len(buf))
}

log.Tracef("Sending payload type %v length %v", t, uint32(len(payload)))
log.Tracef("Sending payload type %v length %v", TypeToString(t), uint32(len(payload)))

n, err = conn.Write(payload)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion attestationreport/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func (hws *hwSigner) SignPayload(payload []byte, alg jose.SignatureAlgorithm) ([
}
hashed := hasher.Sum(nil)

// sign payload
// Sign payload
switch alg {
case jose.ES256, jose.ES384, jose.ES512:
// Obtain signature
Expand Down
8 changes: 4 additions & 4 deletions attestedtls/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error {
if err != nil {
return fmt.Errorf("failed to send skip client Attestation: %w", err)
}
log.Debug("Skipping client-side attestation")
log.Debug("Skipping client-side attestation: no attestation report generation required")
}

// Fetch attestation report from listener
Expand Down Expand Up @@ -96,10 +96,10 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error {
// optional: attest server
if cc.Attest == Attest_Mutual || cc.Attest == Attest_Server {
// Obtain own attestation report from local cmcd
log.Trace("Attesting the Server")
log.Trace("Listener: Fetching attestation report from cmcd")
resp, err := cc.CmcApi.obtainAR(cc, chbindings)
if err != nil {
return fmt.Errorf("could not obtain AR of Listener : %w", err)
return fmt.Errorf("could not obtain listener attestation report: %w", err)
}

// Send own attestation report to dialer. This is done asynchronously to
Expand Down Expand Up @@ -132,7 +132,7 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error {
// optional: Wait for attestation report from client
if cc.Attest == Attest_Mutual || cc.Attest == Attest_Client {
// Verify AR from dialer with own channel bindings
log.Trace("Verifying attestation report from dialer...")
log.Trace("Listener: Verifying attestation report from dialer...")
err = cc.CmcApi.verifyAR(chbindings, report, cc)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion attestedtls/libapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (a LibApi) obtainAR(cc CmcConfig, chbindings []byte) ([]byte, error) {
log.Debug("Prover: Signing Attestation Report")
signedReport, err := generate.Sign(report, cc.Cmc.Drivers[0], cc.Cmc.Serializer)
if err != nil {
return nil, errors.New("prover: failed to sign Attestion Report ")
return nil, fmt.Errorf("prover: failed to sign attestation reoprt: %w", err)
}

return signedReport, nil
Expand Down
70 changes: 57 additions & 13 deletions attestedtls/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ func init() {
func (a SocketApi) obtainAR(cc CmcConfig, chbindings []byte) ([]byte, error) {

// Establish connection
log.Tracef("Contacting cmcd via %v on %v", cc.Network, cc.CmcAddr)
log.Tracef("Sending attestation request to cmcd via %v on %v", cc.Network, cc.CmcAddr)
conn, err := net.Dial(cc.Network, cc.CmcAddr)
if err != nil {
return nil, fmt.Errorf("error dialing: %w", err)
return nil, fmt.Errorf("error dialing cmcd: %w", err)
}

req := &api.AttestationRequest{
Expand All @@ -63,19 +63,30 @@ func (a SocketApi) obtainAR(cc CmcConfig, chbindings []byte) ([]byte, error) {
// Send request
err = api.Send(conn, payload, api.TypeAttest)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
return nil, fmt.Errorf("failed to send request to cmcd: %w", err)
}

// Read reply
payload, _, err = api.Receive(conn)
payload, mtype, err := api.Receive(conn)
if err != nil {
log.Fatalf("failed to receive: %v", err)
return nil, fmt.Errorf("failed to receive from cmcd: %w", err)
}

if mtype == api.TypeError {
resp := new(api.SocketError)
err = cbor.Unmarshal(payload, resp)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal error response from cmcd: %w", err)
}
return nil, fmt.Errorf("received error from cmcd: %v", resp.Msg)
} else if mtype != api.TypeAttest {
return nil, fmt.Errorf("unexpected response type %v from cmcd", api.TypeToString(mtype))
}

resp := new(api.AttestationResponse)
err = cbor.Unmarshal(payload, resp)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal body: %w", err)
return nil, fmt.Errorf("failed to unmarshal cmcd attestation response body: %w", err)
}

return resp.AttestationReport, nil
Expand All @@ -85,7 +96,7 @@ func (a SocketApi) obtainAR(cc CmcConfig, chbindings []byte) ([]byte, error) {
func (a SocketApi) verifyAR(chbindings, report []byte, cc CmcConfig) error {

// Establish connection
log.Tracef("Contacting cmcd via %v on %v", cc.Network, cc.CmcAddr)
log.Tracef("Sending verification request to cmcd via %v on %v", cc.Network, cc.CmcAddr)
conn, err := net.Dial(cc.Network, cc.CmcAddr)
if err != nil {
return fmt.Errorf("error dialing: %w", err)
Expand All @@ -106,20 +117,31 @@ func (a SocketApi) verifyAR(chbindings, report []byte, cc CmcConfig) error {
// Perform Verify request
err = api.Send(conn, payload, api.TypeVerify)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
return fmt.Errorf("failed to send request to cmcd: %w", err)
}

// Read reply
payload, _, err = api.Receive(conn)
payload, mtype, err := api.Receive(conn)
if err != nil {
log.Fatalf("failed to receive: %v", err)
return fmt.Errorf("failed to receive from cmcd: %v", err)
}

if mtype == api.TypeError {
resp := new(api.SocketError)
err = cbor.Unmarshal(payload, resp)
if err != nil {
return fmt.Errorf("failed to unmarshal error response from cmcd: %w", err)
}
return fmt.Errorf("received error from cmcd: %v", resp.Msg)
} else if mtype != api.TypeVerify {
return fmt.Errorf("unexpected response type %v from cmcd", api.TypeToString(mtype))
}

// Unmarshal verify response
var verifyResp api.VerificationResponse
err = cbor.Unmarshal(payload, &verifyResp)
if err != nil {
return fmt.Errorf("failed to unmarshal response: %w", err)
return fmt.Errorf("failed to unmarshal cmcd verify response: %w", err)
}

// Parse VerificationResult
Expand Down Expand Up @@ -180,11 +202,22 @@ func (a SocketApi) fetchSignature(cc CmcConfig, digest []byte, opts crypto.Signe
}

// Read reply
payload, _, err = api.Receive(conn)
payload, mtype, err := api.Receive(conn)
if err != nil {
log.Fatalf("failed to receive: %v", err)
}

if mtype == api.TypeError {
resp := new(api.SocketError)
err = cbor.Unmarshal(payload, resp)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal error response from cmcd: %w", err)
}
return nil, fmt.Errorf("received error from cmcd: %v", resp.Msg)
} else if mtype != api.TypeTLSSign {
return nil, fmt.Errorf("unexpected response type %v from cmcd", api.TypeToString(mtype))
}

// Unmarshal sign response
var signResp api.TLSSignResponse
err = cbor.Unmarshal(payload, &signResp)
Expand Down Expand Up @@ -223,11 +256,22 @@ func (a SocketApi) fetchCerts(cc CmcConfig) ([][]byte, error) {
}

// Read reply
payload, _, err = api.Receive(conn)
payload, mtype, err := api.Receive(conn)
if err != nil {
log.Fatalf("failed to receive: %v", err)
}

if mtype == api.TypeError {
resp := new(api.SocketError)
err = cbor.Unmarshal(payload, resp)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal error response from cmcd: %w", err)
}
return nil, fmt.Errorf("received error from cmcd: %v", resp.Msg)
} else if mtype != api.TypeTLSCert {
return nil, fmt.Errorf("unexpected response type %v from cmcd", api.TypeToString(mtype))
}

// Unmarshal cert response
var certResp api.TLSCertResponse
err = cbor.Unmarshal(payload, &certResp)
Expand Down

0 comments on commit 837ceac

Please sign in to comment.