From 4271830b1a783493556d9770b458755712472745 Mon Sep 17 00:00:00 2001 From: Chris Schinnerl Date: Wed, 6 Dec 2023 13:21:00 +0100 Subject: [PATCH] client: extract getting the size from the seeker into helper function --- worker/client/client.go | 44 +++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/worker/client/client.go b/worker/client/client.go index f275bcdaf..af435c86e 100644 --- a/worker/client/client.go +++ b/worker/client/client.go @@ -196,18 +196,8 @@ func (c *Client) UploadMultipartUploadPart(ctx context.Context, r io.Reader, buc req.SetBasicAuth("", c.c.WithContext(ctx).Password) if opts.ContentLength != 0 { req.ContentLength = opts.ContentLength - } else { - if s, ok := r.(io.Seeker); ok { - length, err := s.Seek(0, io.SeekEnd) - if err != nil { - return nil, err - } - _, err = s.Seek(0, io.SeekStart) - if err != nil { - return nil, err - } - req.ContentLength = length - } + } else if req.ContentLength, err = sizeFromSeeker(r); err != nil { + return nil, fmt.Errorf("failed to get content length from seeker: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -242,18 +232,8 @@ func (c *Client) UploadObject(ctx context.Context, r io.Reader, bucket, path str req.SetBasicAuth("", c.c.WithContext(ctx).Password) if opts.ContentLength != 0 { req.ContentLength = opts.ContentLength - } else { - if s, ok := r.(io.Seeker); ok { - length, err := s.Seek(0, io.SeekEnd) - if err != nil { - return nil, err - } - _, err = s.Seek(0, io.SeekStart) - if err != nil { - return nil, err - } - req.ContentLength = length - } + } else if req.ContentLength, err = sizeFromSeeker(r); err != nil { + return nil, fmt.Errorf("failed to get content length from seeker: %w", err) } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -299,3 +279,19 @@ func (c *Client) object(ctx context.Context, bucket, path string, opts api.Downl } return resp.Body, resp.Header, err } + +func sizeFromSeeker(r io.Reader) (int64, error) { + s, ok := r.(io.Seeker) + if !ok { + return 0, nil + } + size, err := s.Seek(0, io.SeekEnd) + if err != nil { + return 0, err + } + _, err = s.Seek(0, io.SeekStart) + if err != nil { + return 0, err + } + return size, nil +}