Skip to content

Commit

Permalink
Enforce discard limits on readers
Browse files Browse the repository at this point in the history
This enforces limits on discard to avoid unbounded reads. Where
resources were already exhausted no further reads are done and discards
have been removed. These discards were an optimization to reuse
connections. When a stream is partially read all subsequent reads will
now return EOF errors to avoid reading in a corrupted state.
  • Loading branch information
emcfarlane committed Oct 25, 2024
1 parent 145b279 commit 73a761b
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 48 deletions.
12 changes: 4 additions & 8 deletions compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,13 @@ func (c *compressionPool) Decompress(dst *bytes.Buffer, src *bytes.Buffer, readM
}
return errorf(CodeInvalidArgument, "decompress: %w", err)
}
if readMaxBytes > 0 && bytesRead > readMaxBytes {
discardedBytes, err := io.Copy(io.Discard, decompressor)
_ = c.putDecompressor(decompressor)
if err != nil {
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", readMaxBytes, err)
}
return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, readMaxBytes)
}
if err := c.putDecompressor(decompressor); err != nil {
return errorf(CodeUnknown, "recycle decompressor: %w", err)
}
if readMaxBytes > 0 && bytesRead > readMaxBytes {
// Resource is exhausted, fail fast without reading more data from the reader.
return errorf(CodeResourceExhausted, "decompressed message size is larger than configured max %d", readMaxBytes)
}
return nil
}

Expand Down
7 changes: 1 addition & 6 deletions connect_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,6 @@ func TestHandlerWithReadMaxBytes(t *testing.T) {
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes)))
})
t.Run("read_max_large", func(t *testing.T) {
t.Parallel()
Expand All @@ -1206,16 +1205,14 @@ func TestHandlerWithReadMaxBytes(t *testing.T) {
}
// Serializes to much larger than readMaxBytes (5 MiB)
pingRequest := &pingv1.PingRequest{Text: strings.Repeat("abcde", 1024*1024)}
expectedSize := proto.Size(pingRequest)
// With gzip request compression, the error should indicate the envelope size (before decompression) is too large.
if compressed {
expectedSize = gzipCompressedSize(t, pingRequest)
expectedSize := gzipCompressedSize(t, pingRequest)
assert.True(t, expectedSize > readMaxBytes, assert.Sprintf("expected compressed size %d > %d", expectedSize, readMaxBytes))
}
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes))
})
}
newHTTP2Server := func(t *testing.T) *memhttp.Server {
Expand Down Expand Up @@ -1378,7 +1375,6 @@ func TestClientWithReadMaxBytes(t *testing.T) {
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
assert.True(t, strings.HasSuffix(err.Error(), fmt.Sprintf("message size %d is larger than configured max %d", proto.Size(pingRequest), readMaxBytes)))
})
t.Run("read_max_large", func(t *testing.T) {
t.Parallel()
Expand All @@ -1397,7 +1393,6 @@ func TestClientWithReadMaxBytes(t *testing.T) {
_, err := client.Ping(context.Background(), connect.NewRequest(pingRequest))
assert.NotNil(t, err, assert.Sprintf("expected non-nil error for large message"))
assert.Equal(t, connect.CodeOf(err), connect.CodeResourceExhausted)
assert.Equal(t, err.Error(), fmt.Sprintf("resource_exhausted: message size %d is larger than configured max %d", expectedSize, readMaxBytes))
})
}
t.Run("connect", func(t *testing.T) {
Expand Down
33 changes: 14 additions & 19 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,13 @@ type envelopeReader struct {
compressionPool *compressionPool
bufferPool *bufferPool
readMaxBytes int
isEOF bool
}

func (r *envelopeReader) Unmarshal(message any) *Error {
if r.isEOF {
return NewError(CodeInternal, io.EOF)
}
buffer := r.bufferPool.Get()
var dontRelease *bytes.Buffer
defer func() {
Expand All @@ -240,25 +244,20 @@ func (r *envelopeReader) Unmarshal(message any) *Error {
}()

env := &envelope{Data: buffer}
err := r.Read(env)
switch {
case err == nil && env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil:
if err := r.Read(env); err != nil {
// Mark the reader as EOF so that subsequent reads return EOF.
r.isEOF = true
return err
}
if env.IsSet(flagEnvelopeCompressed) && r.compressionPool == nil {
return errorf(
CodeInternal,
"protocol error: sent compressed message without compression support",
)
case err == nil &&
(env.Flags == 0 || env.Flags == flagEnvelopeCompressed) &&
env.Data.Len() == 0:
} else if (env.Flags == 0 || env.Flags == flagEnvelopeCompressed) && env.Data.Len() == 0 {
// This is a standard message (because none of the top 7 bits are set) and
// there's no data, so the zero value of the message is correct.
return nil
case err != nil && errors.Is(err, io.EOF):
// The stream has ended. Propagate the EOF to the caller.
return err
case err != nil:
// Something's wrong.
return err
}

data := env.Data
Expand Down Expand Up @@ -317,7 +316,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// The stream ended cleanly. That's expected, but we need to propagate an EOF
// to the user so that they know that the stream has ended. We shouldn't
// add any alarming text about protocol errors, though.
return NewError(CodeUnknown, err)
return NewError(CodeInternal, err)
}
err = wrapIfMaxBytesError(err, "read 5 byte message prefix")
err = wrapIfContextDone(r.ctx, err)
Expand All @@ -332,12 +331,8 @@ func (r *envelopeReader) Read(env *envelope) *Error {
}
size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
n, err := io.CopyN(io.Discard, r.reader, size)
r.bytesRead += n
if err != nil && !errors.Is(err, io.EOF) {
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", r.readMaxBytes, err)
}
return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", size, r.readMaxBytes)
// Resource is exhausted, fail fast without reading more data from the stream.
return errorf(CodeResourceExhausted, "received message size %d is larger than configured max %d", size, r.readMaxBytes)
}
// We've read the prefix, so we know how many bytes to expect.
// CopyN will return an error if it doesn't read the requested
Expand Down
8 changes: 4 additions & 4 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,12 @@ func isCommaOrSpace(c rune) bool {
}

func discard(reader io.Reader) (int64, error) {
if lr, ok := reader.(*io.LimitedReader); ok {
return io.Copy(io.Discard, lr)
}
// We don't want to get stuck throwing data away forever, so limit how much
// we're willing to do here.
lr := &io.LimitedReader{R: reader, N: discardLimit}
lr, ok := reader.(*io.LimitedReader)
if !ok {
lr = &io.LimitedReader{R: reader, N: discardLimit}
}
return io.Copy(io.Discard, lr)
}

Expand Down
14 changes: 5 additions & 9 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -1088,19 +1088,19 @@ type connectUnaryUnmarshaler struct {
codec Codec
compressionPool *compressionPool
bufferPool *bufferPool
alreadyRead bool
readMaxBytes int
isEOF bool
}

func (u *connectUnaryUnmarshaler) Unmarshal(message any) *Error {
return u.UnmarshalFunc(message, u.codec.Unmarshal)
}

func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]byte, any) error) *Error {
if u.alreadyRead {
if u.isEOF {
return NewError(CodeInternal, io.EOF)
}
u.alreadyRead = true
u.isEOF = true
data := u.bufferPool.Get()
defer u.bufferPool.Put(data)
reader := u.reader
Expand All @@ -1118,12 +1118,8 @@ func (u *connectUnaryUnmarshaler) UnmarshalFunc(message any, unmarshal func([]by
return errorf(CodeUnknown, "read message: %w", err)
}
if u.readMaxBytes > 0 && bytesRead > int64(u.readMaxBytes) {
// Attempt to read to end in order to allow connection re-use
discardedBytes, err := io.Copy(io.Discard, u.reader)
if err != nil {
return errorf(CodeResourceExhausted, "message is larger than configured max %d - unable to determine message size: %w", u.readMaxBytes, err)
}
return errorf(CodeResourceExhausted, "message size %d is larger than configured max %d", bytesRead+discardedBytes, u.readMaxBytes)
// Resource is exhausted, fail fast without reading more data from the stream.
return errorf(CodeResourceExhausted, "message size is larger than configured max %d", u.readMaxBytes)
}
if data.Len() > 0 && u.compressionPool != nil {
decompressed := u.bufferPool.Get()
Expand Down
4 changes: 2 additions & 2 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ func (g *grpcClient) NewConn(
}
} else {
conn.readTrailers = func(_ *grpcUnmarshaler, call *duplexHTTPCall) http.Header {
// To access HTTP trailers, we need to read the body to EOF.
_, _ = discard(call)
// Caller must guarantee the body is read to EOF to access
// trailers.
return call.ResponseTrailer()
}
}
Expand Down

0 comments on commit 73a761b

Please sign in to comment.