diff --git a/connect.go b/connect.go index bd8490e5..ecc0f7ce 100644 --- a/connect.go +++ b/connect.go @@ -431,9 +431,10 @@ func receiveUnaryMessage[T any](conn receiveConn, initializer maybeInitializer, if err := initializer.maybe(conn.Spec(), &msg2); err != nil { return nil, err } - if err := conn.Receive(&msg2); err == nil { - return nil, NewError(CodeUnimplemented, fmt.Errorf("unary %s has multiple messages", what)) - } else if err != nil && !errors.Is(err, io.EOF) { + if err := conn.Receive(&msg2); !errors.Is(err, io.EOF) { + if err == nil { + err = NewError(CodeUnimplemented, fmt.Errorf("unary %s has multiple messages", what)) + } return nil, err } return &msg, nil diff --git a/protocol_connect.go b/protocol_connect.go index d478d634..e3c5e4a5 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -686,7 +686,6 @@ type connectUnaryHandlerConn struct { marshaler connectUnaryMarshaler unmarshaler connectUnaryUnmarshaler responseTrailer http.Header - wroteBody bool } func (hc *connectUnaryHandlerConn) Spec() Spec { @@ -709,8 +708,7 @@ func (hc *connectUnaryHandlerConn) RequestHeader() http.Header { } func (hc *connectUnaryHandlerConn) Send(msg any) error { - hc.wroteBody = true - hc.writeResponseHeader(nil /* error */) + hc.mergeResponseHeader(nil /* error */) if err := hc.marshaler.Marshal(msg); err != nil { return err } @@ -726,8 +724,8 @@ func (hc *connectUnaryHandlerConn) ResponseTrailer() http.Header { } func (hc *connectUnaryHandlerConn) Close(err error) error { - if !hc.wroteBody { - hc.writeResponseHeader(err) + if !hc.marshaler.wroteHeader { + hc.mergeResponseHeader(err) // If the handler received a GET request and the resource hasn't changed, // return a 304. if len(hc.peer.Query) > 0 && IsNotModifiedError(err) { @@ -735,7 +733,7 @@ func (hc *connectUnaryHandlerConn) Close(err error) error { return hc.request.Body.Close() } } - if err == nil { + if err == nil || hc.marshaler.wroteHeader { return hc.request.Body.Close() } // In unary Connect, errors always use application/json. @@ -757,7 +755,7 @@ func (hc *connectUnaryHandlerConn) getHTTPMethod() string { return hc.request.Method } -func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) { +func (hc *connectUnaryHandlerConn) mergeResponseHeader(err error) { header := hc.responseWriter.Header() if hc.request.Method == http.MethodGet { // The response content varies depending on the compression that the client @@ -923,6 +921,7 @@ type connectUnaryMarshaler struct { bufferPool *bufferPool header http.Header sendMaxBytes int + wroteHeader bool } func (m *connectUnaryMarshaler) Marshal(message any) *Error { @@ -961,6 +960,7 @@ func (m *connectUnaryMarshaler) Marshal(message any) *Error { } func (m *connectUnaryMarshaler) write(data []byte) *Error { + m.wroteHeader = true payload := bytes.NewReader(data) if _, err := m.sender.Send(payload); err != nil { err = wrapIfContextError(err)