diff --git a/api/api.go b/api/api.go index cf2b9e8b..d420534c 100644 --- a/api/api.go +++ b/api/api.go @@ -124,6 +124,25 @@ const ( TypeTLSCert uint32 = 5 ) +func TypeToString(t uint32) string { + switch t { + case TypeError: + return "Error" + case TypeAttest: + return "Attest" + case TypeVerify: + return "Verify" + case TypeMeasure: + return "Measure" + case TypeTLSSign: + return "TLSSign" + case TypeTLSCert: + return "TLSCert" + default: + return "Unknown" + } +} + // Converts Protobuf hashtype to crypto.SignerOpts func HashToSignerOpts(hashtype HashFunction, pssOpts *PSSOptions) (crypto.SignerOpts, error) { var hash crypto.Hash @@ -236,7 +255,7 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { payloadLen, MaxMsgLen) } - log.Tracef("Decoded header. Expecting type %v, length %v", msgType, payloadLen) + log.Tracef("Decoded header. Type %v, length %v", TypeToString(msgType), payloadLen) // Read payload payload := bytes.NewBuffer(nil) @@ -297,7 +316,7 @@ func Send(conn net.Conn, payload []byte, t uint32) error { return fmt.Errorf("could only send %v of %v bytes", n, len(buf)) } - log.Tracef("Sending payload type %v length %v", t, uint32(len(payload))) + log.Tracef("Sending payload type %v length %v", TypeToString(t), uint32(len(payload))) n, err = conn.Write(payload) if err != nil { diff --git a/attestationreport/json.go b/attestationreport/json.go index 18e53163..d8dcdf10 100644 --- a/attestationreport/json.go +++ b/attestationreport/json.go @@ -338,7 +338,7 @@ func (hws *hwSigner) SignPayload(payload []byte, alg jose.SignatureAlgorithm) ([ } hashed := hasher.Sum(nil) - // sign payload + // Sign payload switch alg { case jose.ES256, jose.ES384, jose.ES512: // Obtain signature diff --git a/attestedtls/attestation.go b/attestedtls/attestation.go index 982cd850..e122b099 100644 --- a/attestedtls/attestation.go +++ b/attestedtls/attestation.go @@ -56,7 +56,7 @@ func attestDialer(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { if err != nil { return fmt.Errorf("failed to send skip client Attestation: %w", err) } - log.Debug("Skipping client-side attestation") + log.Debug("Skipping client-side attestation: no attestation report generation required") } // Fetch attestation report from listener @@ -96,10 +96,10 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { // optional: attest server if cc.Attest == Attest_Mutual || cc.Attest == Attest_Server { // Obtain own attestation report from local cmcd - log.Trace("Attesting the Server") + log.Trace("Listener: Fetching attestation report from cmcd") resp, err := cc.CmcApi.obtainAR(cc, chbindings) if err != nil { - return fmt.Errorf("could not obtain AR of Listener : %w", err) + return fmt.Errorf("could not obtain listener attestation report: %w", err) } // Send own attestation report to dialer. This is done asynchronously to @@ -132,7 +132,7 @@ func attestListener(conn *tls.Conn, chbindings []byte, cc CmcConfig) error { // optional: Wait for attestation report from client if cc.Attest == Attest_Mutual || cc.Attest == Attest_Client { // Verify AR from dialer with own channel bindings - log.Trace("Verifying attestation report from dialer...") + log.Trace("Listener: Verifying attestation report from dialer...") err = cc.CmcApi.verifyAR(chbindings, report, cc) if err != nil { return err diff --git a/attestedtls/libapi.go b/attestedtls/libapi.go index 7fe5e3c8..95d00080 100644 --- a/attestedtls/libapi.go +++ b/attestedtls/libapi.go @@ -57,7 +57,7 @@ func (a LibApi) obtainAR(cc CmcConfig, chbindings []byte) ([]byte, error) { log.Debug("Prover: Signing Attestation Report") signedReport, err := generate.Sign(report, cc.Cmc.Drivers[0], cc.Cmc.Serializer) if err != nil { - return nil, errors.New("prover: failed to sign Attestion Report ") + return nil, fmt.Errorf("prover: failed to sign attestation reoprt: %w", err) } return signedReport, nil diff --git a/attestedtls/socket.go b/attestedtls/socket.go index d571a5be..bfa007f5 100644 --- a/attestedtls/socket.go +++ b/attestedtls/socket.go @@ -43,10 +43,10 @@ func init() { func (a SocketApi) obtainAR(cc CmcConfig, chbindings []byte) ([]byte, error) { // Establish connection - log.Tracef("Contacting cmcd via %v on %v", cc.Network, cc.CmcAddr) + log.Tracef("Sending attestation request to cmcd via %v on %v", cc.Network, cc.CmcAddr) conn, err := net.Dial(cc.Network, cc.CmcAddr) if err != nil { - return nil, fmt.Errorf("error dialing: %w", err) + return nil, fmt.Errorf("error dialing cmcd: %w", err) } req := &api.AttestationRequest{ @@ -63,19 +63,30 @@ func (a SocketApi) obtainAR(cc CmcConfig, chbindings []byte) ([]byte, error) { // Send request err = api.Send(conn, payload, api.TypeAttest) if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) + return nil, fmt.Errorf("failed to send request to cmcd: %w", err) } // Read reply - payload, _, err = api.Receive(conn) + payload, mtype, err := api.Receive(conn) if err != nil { - log.Fatalf("failed to receive: %v", err) + return nil, fmt.Errorf("failed to receive from cmcd: %w", err) + } + + if mtype == api.TypeError { + resp := new(api.SocketError) + err = cbor.Unmarshal(payload, resp) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal error response from cmcd: %w", err) + } + return nil, fmt.Errorf("received error from cmcd: %v", resp.Msg) + } else if mtype != api.TypeAttest { + return nil, fmt.Errorf("unexpected response type %v from cmcd", api.TypeToString(mtype)) } resp := new(api.AttestationResponse) err = cbor.Unmarshal(payload, resp) if err != nil { - return nil, fmt.Errorf("failed to unmarshal body: %w", err) + return nil, fmt.Errorf("failed to unmarshal cmcd attestation response body: %w", err) } return resp.AttestationReport, nil @@ -85,7 +96,7 @@ func (a SocketApi) obtainAR(cc CmcConfig, chbindings []byte) ([]byte, error) { func (a SocketApi) verifyAR(chbindings, report []byte, cc CmcConfig) error { // Establish connection - log.Tracef("Contacting cmcd via %v on %v", cc.Network, cc.CmcAddr) + log.Tracef("Sending verification request to cmcd via %v on %v", cc.Network, cc.CmcAddr) conn, err := net.Dial(cc.Network, cc.CmcAddr) if err != nil { return fmt.Errorf("error dialing: %w", err) @@ -106,20 +117,31 @@ func (a SocketApi) verifyAR(chbindings, report []byte, cc CmcConfig) error { // Perform Verify request err = api.Send(conn, payload, api.TypeVerify) if err != nil { - return fmt.Errorf("failed to send request: %w", err) + return fmt.Errorf("failed to send request to cmcd: %w", err) } // Read reply - payload, _, err = api.Receive(conn) + payload, mtype, err := api.Receive(conn) if err != nil { - log.Fatalf("failed to receive: %v", err) + return fmt.Errorf("failed to receive from cmcd: %v", err) + } + + if mtype == api.TypeError { + resp := new(api.SocketError) + err = cbor.Unmarshal(payload, resp) + if err != nil { + return fmt.Errorf("failed to unmarshal error response from cmcd: %w", err) + } + return fmt.Errorf("received error from cmcd: %v", resp.Msg) + } else if mtype != api.TypeVerify { + return fmt.Errorf("unexpected response type %v from cmcd", api.TypeToString(mtype)) } // Unmarshal verify response var verifyResp api.VerificationResponse err = cbor.Unmarshal(payload, &verifyResp) if err != nil { - return fmt.Errorf("failed to unmarshal response: %w", err) + return fmt.Errorf("failed to unmarshal cmcd verify response: %w", err) } // Parse VerificationResult @@ -180,11 +202,22 @@ func (a SocketApi) fetchSignature(cc CmcConfig, digest []byte, opts crypto.Signe } // Read reply - payload, _, err = api.Receive(conn) + payload, mtype, err := api.Receive(conn) if err != nil { log.Fatalf("failed to receive: %v", err) } + if mtype == api.TypeError { + resp := new(api.SocketError) + err = cbor.Unmarshal(payload, resp) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal error response from cmcd: %w", err) + } + return nil, fmt.Errorf("received error from cmcd: %v", resp.Msg) + } else if mtype != api.TypeTLSSign { + return nil, fmt.Errorf("unexpected response type %v from cmcd", api.TypeToString(mtype)) + } + // Unmarshal sign response var signResp api.TLSSignResponse err = cbor.Unmarshal(payload, &signResp) @@ -223,11 +256,22 @@ func (a SocketApi) fetchCerts(cc CmcConfig) ([][]byte, error) { } // Read reply - payload, _, err = api.Receive(conn) + payload, mtype, err := api.Receive(conn) if err != nil { log.Fatalf("failed to receive: %v", err) } + if mtype == api.TypeError { + resp := new(api.SocketError) + err = cbor.Unmarshal(payload, resp) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal error response from cmcd: %w", err) + } + return nil, fmt.Errorf("received error from cmcd: %v", resp.Msg) + } else if mtype != api.TypeTLSCert { + return nil, fmt.Errorf("unexpected response type %v from cmcd", api.TypeToString(mtype)) + } + // Unmarshal cert response var certResp api.TLSCertResponse err = cbor.Unmarshal(payload, &certResp)