Skip to content

Commit

Permalink
api: fix socket interface
Browse files Browse the repository at this point in the history
Signed-off-by: Simon Ott <[email protected]>
  • Loading branch information
smo4201 committed Mar 4, 2024
1 parent fe51b99 commit d2e890d
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package api

import (
"bytes"
"crypto"
"crypto/rsa"
"encoding/binary"
Expand Down Expand Up @@ -194,9 +195,13 @@ func SignerOptsToHash(opts crypto.SignerOpts) (HashFunction, error) {
// payload []byte -> CBOR-encoded payload
func Receive(conn net.Conn) ([]byte, uint32, error) {

err := conn.(*net.UnixConn).SetReadBuffer(MaxMsgLen)
if err != nil {
return nil, 0, fmt.Errorf("failed to socket write buffer size %v", err)
// If unix domain sockets are used, set the write buffer size
_, ok := conn.(*net.UnixConn)
if ok {
err := conn.(*net.UnixConn).SetReadBuffer(MaxMsgLen)
if err != nil {
return nil, 0, fmt.Errorf("failed to socket write buffer size %v", err)
}
}

// Read header
Expand All @@ -213,7 +218,7 @@ func Receive(conn net.Conn) ([]byte, uint32, error) {
}

// Decode header to get length and type
payloadLen := binary.BigEndian.Uint32(buf[0:4])
payloadLen := int(binary.BigEndian.Uint32(buf[0:4]))
msgType := binary.BigEndian.Uint32(buf[4:8])

if payloadLen > MaxMsgLen {
Expand All @@ -224,27 +229,37 @@ func Receive(conn net.Conn) ([]byte, uint32, error) {
log.Tracef("Decoded header. Expecting type %v, length %v", msgType, payloadLen)

// Read payload
payload := make([]byte, payloadLen)
n, err = conn.Read(payload)
if err != nil {
return nil, 0, fmt.Errorf("failed to read payload: %w", err)
}
if uint32(n) != payloadLen {
return nil, 0, fmt.Errorf("failed to read payload (received %v, expected %v bytes)",
n, payloadLen)
payload := bytes.NewBuffer(nil)
received := 0
for {
chunk := make([]byte, 128*1024)
n, err = conn.Read(chunk)
if err != nil {
return nil, 0, fmt.Errorf("failed to read payload: %w", err)
}
received += n
payload.Write(chunk[:n])

log.Tracef("Received chunk of %v bytes\n", n)

if received == payloadLen {
break
}
}

log.Tracef("Received payload length %v", payloadLen)

if msgType == TypeError {
resp := new(SocketError)
err = cbor.Unmarshal(payload, resp)
err = cbor.Unmarshal(payload.Bytes(), resp)
if err != nil {
return nil, 0, fmt.Errorf("failed to unmarshal error response")
} else {
return nil, 0, fmt.Errorf("server responded with error: %v", resp.Msg)
}
}

return payload, msgType, nil
return payload.Bytes(), msgType, nil
}

// Send sends data to a socket with the following format
Expand All @@ -259,9 +274,13 @@ func Send(conn net.Conn, payload []byte, t uint32) error {
len(payload), MaxMsgLen)
}

err := conn.(*net.UnixConn).SetWriteBuffer(MaxMsgLen)
if err != nil {
return fmt.Errorf("failed to socket write buffer size %v", err)
// If unix domain sockets are used, set the write buffer size
_, ok := conn.(*net.UnixConn)
if ok {
err := conn.(*net.UnixConn).SetWriteBuffer(MaxMsgLen)
if err != nil {
return fmt.Errorf("failed to socket write buffer size %v", err)
}
}

buf := make([]byte, 8)
Expand Down

0 comments on commit d2e890d

Please sign in to comment.