Skip to content

Commit

Permalink
Use io.ReadFull to ensure we read our ideal chunk size. (#3138)
Browse files Browse the repository at this point in the history
  • Loading branch information
tylerwilliams authored Dec 21, 2022
1 parent 6cc1934 commit a70dbc1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
32 changes: 19 additions & 13 deletions enterprise/server/util/cacheproxy/cacheproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,6 @@ func (c *CacheProxy) GetMulti(ctx context.Context, req *dcpb.GetMultiRequest) (*
return rsp, nil
}

type streamWriter struct {
stream dcpb.DistributedCache_ReadServer
}

func (w *streamWriter) Write(buf []byte) (int, error) {
err := w.stream.Send(&dcpb.ReadResponse{
Data: buf,
})
return len(buf), err
}

func (c *CacheProxy) Read(req *dcpb.ReadRequest, stream dcpb.DistributedCache_ReadServer) error {
ctx, err := c.readWriteContext(stream.Context())
if err != nil {
Expand All @@ -326,8 +315,25 @@ func (c *CacheProxy) Read(req *dcpb.ReadRequest, stream dcpb.DistributedCache_Re
bufSize = resourceSize
}
copyBuf := c.bufferPool.Get(bufSize)
_, err = io.CopyBuffer(&streamWriter{stream}, reader, copyBuf[:bufSize])
c.bufferPool.Put(copyBuf)
defer c.bufferPool.Put(copyBuf)

buf := copyBuf[:bufSize]
for {
n, err := io.ReadFull(reader, buf)
if err == io.EOF {
break
} else if err == io.ErrUnexpectedEOF {
if err := stream.Send(&dcpb.ReadResponse{Data: buf[:n]}); err != nil {
return err
}
} else {
if err := stream.Send(&dcpb.ReadResponse{Data: buf}); err != nil {
return err
}
continue
}
}

c.log.Debugf("Read(%q) succeeded (user prefix: %s)", ResourceIsolationString(rn), up)
return err
}
Expand Down
33 changes: 20 additions & 13 deletions server/remote_cache/byte_stream_server/byte_stream_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,6 @@ func checkReadPreconditions(req *bspb.ReadRequest) error {
return nil
}

type streamWriter struct {
stream bspb.ByteStream_ReadServer
}

func (w *streamWriter) Write(buf []byte) (int, error) {
err := w.stream.Send(&bspb.ReadResponse{
Data: buf,
})
return len(buf), err
}

// `Read()` is used to retrieve the contents of a resource as a sequence
// of bytes. The bytes are returned in a sequence of responses, and the
// responses are delivered as the results of a server-side streaming FUNC (S *BYTESTREAMSERVER).
Expand Down Expand Up @@ -152,8 +141,26 @@ func (s *ByteStreamServer) Read(req *bspb.ReadRequest, stream bspb.ByteStream_Re

copyBuf := s.bufferPool.Get(bufSize)
defer s.bufferPool.Put(copyBuf)
n, err := io.CopyBuffer(&streamWriter{stream}, reader, copyBuf[:bufSize])
downloadTracker.CloseWithBytesTransferred(n, r.GetCompressor())

buf := copyBuf[:bufSize]
bytesTransferred := 0
for {
n, err := io.ReadFull(reader, buf)
bytesTransferred += n
if err == io.EOF {
break
} else if err == io.ErrUnexpectedEOF {
if err := stream.Send(&bspb.ReadResponse{Data: buf[:n]}); err != nil {
return err
}
} else {
if err := stream.Send(&bspb.ReadResponse{Data: buf}); err != nil {
return err
}
continue
}
}
downloadTracker.CloseWithBytesTransferred(int64(bytesTransferred), r.GetCompressor())
return err
}

Expand Down

0 comments on commit a70dbc1

Please sign in to comment.