diff --git a/api/api.go b/api/api.go index 675eedf1..6db92e73 100644 --- a/api/api.go +++ b/api/api.go @@ -18,6 +18,7 @@ package api import ( + "bytes" "crypto" "crypto/rsa" "encoding/binary" @@ -194,9 +195,13 @@ func SignerOptsToHash(opts crypto.SignerOpts) (HashFunction, error) { // payload []byte -> CBOR-encoded payload func Receive(conn net.Conn) ([]byte, uint32, error) { - err := conn.(*net.UnixConn).SetReadBuffer(MaxMsgLen) - if err != nil { - return nil, 0, fmt.Errorf("failed to socket write buffer size %v", err) + // If unix domain sockets are used, set the write buffer size + _, ok := conn.(*net.UnixConn) + if ok { + err := conn.(*net.UnixConn).SetReadBuffer(MaxMsgLen) + if err != nil { + return nil, 0, fmt.Errorf("failed to socket write buffer size %v", err) + } } // Read header @@ -213,7 +218,7 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { } // Decode header to get length and type - payloadLen := binary.BigEndian.Uint32(buf[0:4]) + payloadLen := int(binary.BigEndian.Uint32(buf[0:4])) msgType := binary.BigEndian.Uint32(buf[4:8]) if payloadLen > MaxMsgLen { @@ -224,19 +229,29 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { log.Tracef("Decoded header. Expecting type %v, length %v", msgType, payloadLen) // Read payload - payload := make([]byte, payloadLen) - n, err = conn.Read(payload) - if err != nil { - return nil, 0, fmt.Errorf("failed to read payload: %w", err) - } - if uint32(n) != payloadLen { - return nil, 0, fmt.Errorf("failed to read payload (received %v, expected %v bytes)", - n, payloadLen) + payload := bytes.NewBuffer(nil) + received := 0 + for { + chunk := make([]byte, 128*1024) + n, err = conn.Read(chunk) + if err != nil { + return nil, 0, fmt.Errorf("failed to read payload: %w", err) + } + received += n + payload.Write(chunk[:n]) + + log.Tracef("Received chunk of %v bytes\n", n) + + if received == payloadLen { + break + } } + log.Tracef("Received payload length %v", payloadLen) + if msgType == TypeError { resp := new(SocketError) - err = cbor.Unmarshal(payload, resp) + err = cbor.Unmarshal(payload.Bytes(), resp) if err != nil { return nil, 0, fmt.Errorf("failed to unmarshal error response") } else { @@ -244,7 +259,7 @@ func Receive(conn net.Conn) ([]byte, uint32, error) { } } - return payload, msgType, nil + return payload.Bytes(), msgType, nil } // Send sends data to a socket with the following format @@ -259,9 +274,13 @@ func Send(conn net.Conn, payload []byte, t uint32) error { len(payload), MaxMsgLen) } - err := conn.(*net.UnixConn).SetWriteBuffer(MaxMsgLen) - if err != nil { - return fmt.Errorf("failed to socket write buffer size %v", err) + // If unix domain sockets are used, set the write buffer size + _, ok := conn.(*net.UnixConn) + if ok { + err := conn.(*net.UnixConn).SetWriteBuffer(MaxMsgLen) + if err != nil { + return fmt.Errorf("failed to socket write buffer size %v", err) + } } buf := make([]byte, 8)