diff --git a/api/api.go b/api/api.go index 65336b8d..675eedf1 100644 --- a/api/api.go +++ b/api/api.go @@ -72,6 +72,11 @@ type TLSCertResponse struct { Certificate [][]byte `json:"certificate" cbor:"0,keyasint"` } +const ( + // Set maximum message length to 10 MB + MaxMsgLen = 1024 * 1024 * 10 +) + type HashFunction int32 const ( @@ -188,8 +193,17 @@ func SignerOptsToHash(opts crypto.SignerOpts) (HashFunction, error) { // Type uint32 -> Type of the payload // payload []byte -> CBOR-encoded payload func Receive(conn net.Conn) ([]byte, uint32, error) { + + err := conn.(*net.UnixConn).SetReadBuffer(MaxMsgLen) + if err != nil { + return nil, 0, fmt.Errorf("failed to socket write buffer size %v", err) + } + // Read header buf := make([]byte, 8) + + log.Tracef("Reading header length %v", len(buf)) + n, err := conn.Read(buf) if err != nil { return nil, 0, fmt.Errorf("failed to read header: %w", err) @@ -199,18 +213,25 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { } // Decode header to get length and type - len := binary.BigEndian.Uint32(buf[0:4]) + payloadLen := binary.BigEndian.Uint32(buf[0:4]) msgType := binary.BigEndian.Uint32(buf[4:8]) + if payloadLen > MaxMsgLen { + return nil, 0, fmt.Errorf("cannot receive: payload size %v exceeds maximum size %v", + payloadLen, MaxMsgLen) + } + + log.Tracef("Decoded header. Expecting type %v, length %v", msgType, payloadLen) + // Read payload - payload := make([]byte, len) + payload := make([]byte, payloadLen) n, err = conn.Read(payload) if err != nil { return nil, 0, fmt.Errorf("failed to read payload: %w", err) } - if uint32(n) != len { + if uint32(n) != payloadLen { return nil, 0, fmt.Errorf("failed to read payload (received %v, expected %v bytes)", - n, len) + n, payloadLen) } if msgType == TypeError { @@ -232,9 +253,23 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { // Type uint32 -> Type of the payload // payload []byte -> CBOR-encoded payload func Send(conn net.Conn, payload []byte, t uint32) error { + + if len(payload) > MaxMsgLen { + return fmt.Errorf("cannot send: payload size %v exceeds maximum size %v", + len(payload), MaxMsgLen) + } + + err := conn.(*net.UnixConn).SetWriteBuffer(MaxMsgLen) + if err != nil { + return fmt.Errorf("failed to socket write buffer size %v", err) + } + buf := make([]byte, 8) binary.BigEndian.PutUint32(buf[0:4], uint32(len(payload))) binary.BigEndian.PutUint32(buf[4:8], t) + + log.Tracef("Sending header length %v", len(buf)) + n, err := conn.Write(buf) if err != nil { return fmt.Errorf("failed to send header: %w", err) @@ -242,6 +277,9 @@ func Send(conn net.Conn, payload []byte, t uint32) error { if n != len(buf) { return fmt.Errorf("could only send %v of %v bytes", n, len(buf)) } + + log.Tracef("Sending payload type %v length %v", t, uint32(len(payload))) + n, err = conn.Write(payload) if err != nil { return fmt.Errorf("failed to send response: %w", err) diff --git a/cmcd/socket.go b/cmcd/socket.go index c3e0c889..1e9063b2 100644 --- a/cmcd/socket.go +++ b/cmcd/socket.go @@ -144,7 +144,10 @@ func attest(conn net.Conn, payload []byte, cmc *cmc.Cmc) { return } - api.Send(conn, data, api.TypeAttest) + err = api.Send(conn, data, api.TypeAttest) + if err != nil { + api.SendError(conn, "failed to send: %v", err) + } log.Debug("Prover: Finished") } @@ -181,7 +184,10 @@ func verify(conn net.Conn, payload []byte, cmc *cmc.Cmc) { return } - api.Send(conn, data, api.TypeVerify) + err = api.Send(conn, data, api.TypeVerify) + if err != nil { + api.SendError(conn, "failed to send: %v", err) + } log.Debug("Verifier: Finished") } @@ -235,7 +241,10 @@ func tlssign(conn net.Conn, payload []byte, cmc *cmc.Cmc) { return } - api.Send(conn, data, api.TypeTLSSign) + err = api.Send(conn, data, api.TypeTLSSign) + if err != nil { + api.SendError(conn, "failed to send: %v", err) + } log.Debug("Performed signing") } @@ -276,7 +285,10 @@ func tlscert(conn net.Conn, payload []byte, cmc *cmc.Cmc) { return } - api.Send(conn, data, api.TypeTLSCert) + err = api.Send(conn, data, api.TypeTLSCert) + if err != nil { + api.SendError(conn, "failed to send: %v", err) + } log.Debug("Obtained TLS cert") } diff --git a/testtool/coap.go b/testtool/coap.go index 8e9b080d..beb3b662 100644 --- a/testtool/coap.go +++ b/testtool/coap.go @@ -44,6 +44,8 @@ func init() { func (a CoapApi) generate(c *config) { + log.Tracef("Connecting via CoAP to %v", c.CmcAddr) + // Establish connection conn, err := udp.Dial(c.CmcAddr) if err != nil { @@ -96,14 +98,14 @@ func (a CoapApi) generate(c *config) { if err != nil { log.Fatalf("Failed to save attestation report as %v: %v", c.ReportFile, err) } - fmt.Println("Wrote attestation report: ", c.ReportFile) + log.Infof("Wrote attestation report: %v", c.ReportFile) // Save the nonce for the verifier os.WriteFile(c.NonceFile, nonce, 0644) if err != nil { log.Fatalf("Failed to save nonce as %v: %v", c.NonceFile, err) } - fmt.Println("Wrote nonce: ", c.NonceFile) + log.Infof("Wrote nonce: %v", c.NonceFile) } @@ -168,6 +170,9 @@ func (a CoapApi) iothub(c *config) { func verifyInternal(addr string, req *api.VerificationRequest, ) (*api.VerificationResponse, error) { + + log.Tracef("Connecting via CoAP to %v", addr) + // Establish connection conn, err := udp.Dial(addr) if err != nil { diff --git a/testtool/grpc.go b/testtool/grpc.go index d23fa35d..567c0716 100644 --- a/testtool/grpc.go +++ b/testtool/grpc.go @@ -21,7 +21,6 @@ package main import ( "context" "crypto/rand" - "fmt" "os" "time" @@ -46,6 +45,8 @@ func (a GrpcApi) generate(c *config) { ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) defer cancel() + log.Tracef("Connecting via gRPC to %v", c.CmcAddr) + conn, err := grpc.DialContext(ctx, c.CmcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) if err != nil { log.Fatalf("Failed to connect to cmcd: %v", err) @@ -76,14 +77,14 @@ func (a GrpcApi) generate(c *config) { if err != nil { log.Fatalf("Failed to save attestation report as %v: %v", c.ReportFile, err) } - fmt.Println("Wrote attestation report: ", c.ReportFile) + log.Infof("Wrote attestation report: %v", c.ReportFile) // Save the nonce for the verifier os.WriteFile(c.NonceFile, nonce, 0644) if err != nil { log.Fatalf("Failed to save nonce as %v: %v", c.NonceFile, err) } - fmt.Println("Wrote nonce: ", c.NonceFile) + log.Infof("Wrote nonce: %v", c.NonceFile) } @@ -93,6 +94,8 @@ func (a GrpcApi) verify(c *config) { ctx, cancel := context.WithTimeout(context.Background(), timeoutSec*time.Second) defer cancel() + log.Tracef("Connecting via gRPC to %v", c.CmcAddr) + conn, err := grpc.DialContext(ctx, c.CmcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock()) if err != nil { log.Fatalf("Failed to connect to cmcd: %v", err) diff --git a/testtool/publish.go b/testtool/publish.go index 2065c16b..772a20e3 100644 --- a/testtool/publish.go +++ b/testtool/publish.go @@ -111,7 +111,7 @@ func saveResult(file, addr string, result []byte) error { // Save the Attestation Result to file if file != "" { os.WriteFile(file, out.Bytes(), 0644) - fmt.Println("Wrote file ", file) + log.Infof("Wrote file %v", file) } else { log.Debug("No config file specified: will not save attestation report") } diff --git a/testtool/socket.go b/testtool/socket.go index 4e8584c3..5f969266 100644 --- a/testtool/socket.go +++ b/testtool/socket.go @@ -40,6 +40,8 @@ func init() { func (a SocketApi) generate(c *config) { + log.Tracef("Connecting via %v socket to %v", c.Network, c.CmcAddr) + // Establish connection conn, err := net.Dial(c.Network, c.CmcAddr) if err != nil { @@ -88,14 +90,14 @@ func (a SocketApi) generate(c *config) { if err != nil { log.Fatalf("Failed to save attestation report as %v: %v", c.ReportFile, err) } - fmt.Println("Wrote attestation report: ", c.ReportFile) + log.Infof("Wrote attestation report: %v", c.ReportFile) // Save the nonce for the verifier os.WriteFile(c.NonceFile, nonce, 0644) if err != nil { log.Fatalf("Failed to save nonce as %v: %v", c.NonceFile, err) } - fmt.Println("Wrote nonce: ", c.NonceFile) + log.Infof("Wrote nonce: %v", c.NonceFile) } @@ -156,6 +158,9 @@ func (a SocketApi) iothub(c *config) { func verifySocketRequest(network, addr string, req *api.VerificationRequest, ) (*api.VerificationResponse, error) { + + log.Tracef("Connecting via %v socket to %v", network, addr) + // Establish connection conn, err := net.Dial(network, addr) if err != nil {