From 755bc7a9455f71978fa9ccf68324249f70919c1b Mon Sep 17 00:00:00 2001
From: Or Shachar
Date: Tue, 21 Nov 2023 14:31:20 +0200
Subject: [PATCH] Single call to S3 (#2)

---
 cachers/s3.go | 218 ++++++++++++++++++++------------------------
 1 file changed, 87 insertions(+), 131 deletions(-)

diff --git a/cachers/s3.go b/cachers/s3.go
index 1a2dac6..6fa1286 100644
--- a/cachers/s3.go
+++ b/cachers/s3.go
@@ -3,7 +3,6 @@ package cachers
 import (
 	"bytes"
 	"context"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -18,6 +17,10 @@ import (
 	"github.com/aws/aws-sdk-go-v2/service/s3"
 )
 
+const (
+	OutputIDKey = "outputid"
+)
+
 type S3Cache struct {
 	Bucket   string
 	cfg      *aws.Config
@@ -69,12 +72,12 @@ func NewS3Cache(bucketName string, cfg *aws.Config, cacheKey string, disk *DiskC
 	return cache
 }
 
-func (c *S3Cache) client() (*s3.Client, error) {
-	if c.s3Client != nil {
-		return c.s3Client, nil
+func (s *S3Cache) client() (*s3.Client, error) {
+	if s.s3Client != nil {
+		return s.s3Client, nil
 	}
-	c.s3Client = s3.NewFromConfig(*c.cfg)
-	return c.s3Client, nil
+	s.s3Client = s3.NewFromConfig(*s.cfg)
+	return s.s3Client, nil
 }
 
 func isNotFoundError(err error) bool {
@@ -88,158 +91,111 @@ func isNotFoundError(err error) bool {
 	return false
 }
 
-func (c *S3Cache) Get(ctx context.Context, actionID string) (outputID, diskPath string, err error) {
-	outputID, diskPath, err = c.diskCache.Get(ctx, actionID)
+func (s *S3Cache) Get(ctx context.Context, actionID string) (outputID, diskPath string, err error) {
+	outputID, diskPath, err = s.diskCache.Get(ctx, actionID)
 	if err == nil && outputID != "" {
 		return outputID, diskPath, nil
 	}
-	client, err := c.client()
+	client, err := s.client()
 	if err != nil {
-		if c.verbose {
+		if s.verbose {
 			log.Printf("error getting S3 client: %v", err)
 		}
 		return "", "", err
 	}
-	actionKey := c.actionKey(actionID)
-	result, err := client.GetObject(ctx, &s3.GetObjectInput{
-		Bucket: &c.Bucket,
+	actionKey := s.actionKey(actionID)
+
+	outputResult, getOutputErr := client.GetObject(ctx, &s3.GetObjectInput{
+		Bucket: &s.Bucket,
 		Key:    &actionKey,
 	})
-	// handle object not found
-	if isNotFoundError(err) {
+	if isNotFoundError(getOutputErr) {
+		// handle object not found
		return "", "", nil
-	} else if err != nil {
-		if c.verbose {
-			log.Printf("error S3 get for %s: %v", actionKey, err)
+	} else if getOutputErr != nil {
+		if s.verbose {
+			log.Printf("error S3 get for %s: %v", actionKey, getOutputErr)
 		}
-		return "", "", fmt.Errorf("unexpected S3 get for %s: %v", actionKey, err)
+		return "", "", fmt.Errorf("unexpected S3 get for %s: %v", actionKey, getOutputErr)
 	}
-	defer result.Body.Close()
-	var av ActionValue
-	body, err := io.ReadAll(result.Body)
-	if err != nil {
-		return "", "", err
+	defer outputResult.Body.Close()
+	contentSize := outputResult.ContentLength
+	outputID, ok := outputResult.Metadata[OutputIDKey]
+	if !ok || outputID == "" {
+		return "", "", fmt.Errorf("outputID not found in metadata for %s", actionKey)
 	}
-	if err := json.Unmarshal(body, &av); err != nil {
-		if c.verbose {
-			log.Printf("error unmarshalling JSON for %s: %v", actionKey, err)
+	downloadFunc := func() error {
+		diskPath, err = s.diskCache.Put(ctx, actionID, outputID, contentSize, outputResult.Body)
+		if err != nil {
+			return err
 		}
-		return "", "", err
+		return nil
 	}
-
-	outputID = av.OutputID
-
-	var putBody io.Reader
-	if av.Size == 0 {
-		putBody = bytes.NewReader(nil)
-		diskPath, err = c.diskCache.Put(ctx, actionID, outputID, av.Size, putBody)
-	} else {
-		outputKey := c.outputKey(outputID)
-		outputResult, getOutputErr := client.GetObject(ctx, &s3.GetObjectInput{
-			Bucket: &c.Bucket,
-			Key:    &outputKey,
-		})
-		if isNotFoundError(getOutputErr) {
-			// handle object not found
-			return "", "", nil
-		} else if getOutputErr != nil {
-			if c.verbose {
-				log.Printf("error S3 get for %s: %v", outputKey, getOutputErr)
-			}
-			return "", "", fmt.Errorf("unexpected S3 get for %s: %v", outputKey, getOutputErr)
-		}
-		downloadFunc := func() error {
-			defer outputResult.Body.Close()
-			putBody = outputResult.Body
-			diskPath, err = c.diskCache.Put(ctx, actionID, outputID, av.Size, putBody)
-			if err != nil {
-				return err
-			}
-			return nil
-		}
-		if c.verbose {
-			speed, err := DoAndMeasureSpeed(av.Size, downloadFunc)
-			if err == nil {
-				c.downloadStatsChan <- Stats{
-					Bytes: outputResult.ContentLength,
-					Speed: speed,
-				}
-			} else {
-				log.Printf("error downloading %s: %v", outputKey, err)
+	if s.verbose {
+		var speed float64
+		speed, err = DoAndMeasureSpeed(contentSize, downloadFunc)
+		if err == nil {
+			s.downloadStatsChan <- Stats{
+				Bytes: contentSize,
+				Speed: speed,
 			}
 		} else {
-			err = downloadFunc()
+			log.Printf("error downloading %s: %v", actionKey, err)
 		}
+	} else {
+		err = downloadFunc()
 	}
-	return outputID, diskPath, err
-}
 
-func (c *S3Cache) actionKey(actionID string) string {
-	return fmt.Sprintf("%s/actions/%s", c.prefix, actionID)
+	return outputID, diskPath, err
 }
 
-func (c *S3Cache) outputKey(outputID string) string {
-	return fmt.Sprintf("%s/output/%s", c.prefix, outputID)
+func (s *S3Cache) actionKey(actionID string) string {
+	return fmt.Sprintf("%s/%s", s.prefix, actionID)
 }
 
-func (c *S3Cache) Put(ctx context.Context, actionID, outputID string, size int64, body io.Reader) (diskPath string, _ error) {
+func (s *S3Cache) Put(ctx context.Context, actionID, outputID string, size int64, body io.Reader) (diskPath string, _ error) {
 	// Write to disk locally as we write it remotely, as we need to guarantee
 	// it's on disk locally for the caller.
-	var readerForDisk io.Reader
-	var readerForS3 bytes.Buffer
+	var bytesReaderForDisk io.Reader
+	var bytesBufferForS3 bytes.Buffer
 	if size == 0 {
-		// Special case the empty file so NewRequest sets "Content-Length: 0",
-		// as opposed to thinking we didn't set it and not being able to sniff its size
-		// from the type.
-		readerForDisk = bytes.NewReader(nil)
+		// Leave bytesBufferForS3 at its empty zero value so the S3 put sends "Content-Length: 0".
+		bytesReaderForDisk = bytes.NewReader(nil)
 	} else {
-		readerForDisk = io.TeeReader(body, &readerForS3)
+		bytesReaderForDisk = io.TeeReader(body, &bytesBufferForS3)
 	}
-	diskPath, err := c.diskCache.Put(ctx, actionID, outputID, size, readerForDisk)
+	diskPath, err := s.diskCache.Put(ctx, actionID, outputID, size, bytesReaderForDisk)
 	if err != nil {
 		return "", err
 	}
-
-	client, err := c.client()
+	client, err := s.client()
 	if err != nil {
 		return "", err
 	}
-	av := ActionValue{
-		OutputID: outputID,
-		Size:     size,
-	}
-	avj, err := json.Marshal(av)
-	if err == nil {
-		actionKey := c.actionKey(actionID)
-		_, err = client.PutObject(ctx, &s3.PutObjectInput{
-			Bucket: &c.Bucket,
-			Key:    &actionKey,
-			Body:   bytes.NewReader(avj),
-		})
-	}
-	if size > 0 && err == nil {
-		c.uploadOutput(ctx, outputID, client, readerForS3, size)
-	}
+	s.uploadOutput(ctx, actionID, outputID, client, bytesBufferForS3, size)
 	return
 }
 
-func (c *S3Cache) uploadOutput(ctx context.Context, outputID string, client *s3.Client, readerForS3 bytes.Buffer, size int64) {
-	outputKey := c.outputKey(outputID)
+func (s *S3Cache) uploadOutput(ctx context.Context, actionID, outputID string, client *s3.Client, readerForS3 bytes.Buffer, size int64) {
+	outputKey := s.actionKey(actionID)
 	putObjectFunc := func() error {
 		_, err := client.PutObject(ctx, &s3.PutObjectInput{
-			Bucket: &c.Bucket,
+			Bucket: &s.Bucket,
 			Key:    &outputKey,
 			Body:   &readerForS3,
 			ContentLength: size,
+			Metadata: map[string]string{
+				OutputIDKey: outputID,
+			},
 		})
 		return err
 	}
-	if c.verbose {
+	if s.verbose {
 		speed, err := DoAndMeasureSpeed(size, putObjectFunc)
 		if err == nil {
-			c.uploadStatsChan <- Stats{
+			s.uploadStatsChan <- Stats{
 				Bytes: size,
 				Speed: speed,
 			}
@@ -249,20 +205,20 @@ func (c *S3Cache) uploadOutput(ctx context.Context, outputID string, client *s3.
} } -func (c *S3Cache) BytesDownloaded() int64 { - return c.bytesDownloaded +func (s *S3Cache) BytesDownloaded() int64 { + return s.bytesDownloaded } -func (c *S3Cache) BytesUploaded() int64 { - return c.bytesUploaded +func (s *S3Cache) BytesUploaded() int64 { + return s.bytesUploaded } -func (c *S3Cache) AvgBytesDownloadSpeed() float64 { - return c.avgBytesDownloadSpeed +func (s *S3Cache) AvgBytesDownloadSpeed() float64 { + return s.avgBytesDownloadSpeed } -func (c *S3Cache) AvgBytesUploadSpeed() float64 { - return c.avgBytesUploadSpeed +func (s *S3Cache) AvgBytesUploadSpeed() float64 { + return s.avgBytesUploadSpeed } func newAverage(oldAverage float64, count int64, newValue float64) float64 { @@ -277,28 +233,28 @@ func DoAndMeasureSpeed(dataSize int64, functionOnData func() error) (float64, er return speed, err } -func (c *S3Cache) StartStatsGathering() { +func (s *S3Cache) StartStatsGathering() { go func() { - for s := range c.uploadStatsChan { - c.bytesUploaded += s.Bytes - c.avgBytesUploadSpeed = newAverage(c.avgBytesUploadSpeed, c.uploadCount, s.Speed) - c.uploadCount++ + for stats := range s.uploadStatsChan { + s.bytesUploaded += stats.Bytes + s.avgBytesUploadSpeed = newAverage(s.avgBytesUploadSpeed, s.uploadCount, stats.Speed) + s.uploadCount++ } - c.done <- true + s.done <- true }() go func() { - for s := range c.downloadStatsChan { - c.bytesDownloaded += s.Bytes - c.avgBytesDownloadSpeed = newAverage(c.avgBytesDownloadSpeed, c.downloadCount, s.Speed) - c.downloadCount++ + for stats := range s.downloadStatsChan { + s.bytesDownloaded += stats.Bytes + s.avgBytesDownloadSpeed = newAverage(s.avgBytesDownloadSpeed, s.downloadCount, stats.Speed) + s.downloadCount++ } - c.done <- true + s.done <- true }() } -func (c *S3Cache) Close() { - close(c.downloadStatsChan) - close(c.uploadStatsChan) - <-c.done - <-c.done +func (s *S3Cache) Close() { + close(s.downloadStatsChan) + close(s.uploadStatsChan) + <-s.done + <-s.done }
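
Note on the new layout: the patch drops the JSON ActionValue record (previously
stored under <prefix>/actions/<actionID> and pointing at a second object under
<prefix>/output/<outputID>) and instead stores the output bytes once, under the
action key, with the output ID attached as S3 object metadata. The sketch below
shows that round trip in isolation against aws-sdk-go-v2; the bucket, key, and
ID values are made-up placeholders, and only the "outputid" metadata key comes
from the patch.

package main

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/s3"
)

func main() {
	ctx := context.Background()
	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		log.Fatal(err)
	}
	client := s3.NewFromConfig(cfg)

	bucket := "example-cache-bucket"        // placeholder
	key := "cache-prefix/example-action-id" // placeholder: <prefix>/<actionID>

	// One PUT stores the payload and the output ID together.
	_, err = client.PutObject(ctx, &s3.PutObjectInput{
		Bucket:   &bucket,
		Key:      &key,
		Body:     bytes.NewReader([]byte("compiled output")),
		Metadata: map[string]string{"outputid": "example-output-id"},
	})
	if err != nil {
		log.Fatal(err)
	}

	// One GET returns body and metadata together; no second round trip.
	out, err := client.GetObject(ctx, &s3.GetObjectInput{Bucket: &bucket, Key: &key})
	if err != nil {
		log.Fatal(err)
	}
	defer out.Body.Close()
	payload, _ := io.ReadAll(out.Body)
	fmt.Printf("outputid=%s payload=%q\n", out.Metadata["outputid"], payload)
}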
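
Note on Put: the upload path still tees the incoming stream, so the disk cache
consumes the reader while a bytes.Buffer keeps a copy for the later S3 upload.
A minimal self-contained sketch of that mechanism, with io.ReadAll standing in
for diskCache.Put:

package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

func main() {
	body := strings.NewReader("build output bytes")

	// Everything the disk side reads is mirrored into bufForS3.
	var bufForS3 bytes.Buffer
	forDisk := io.TeeReader(body, &bufForS3)

	onDisk, _ := io.ReadAll(forDisk) // stand-in for diskCache.Put consuming the reader
	fmt.Printf("disk copy: %q\n", onDisk)
	fmt.Printf("s3 copy:   %q\n", bufForS3.String())
}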
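
Note on shutdown: StartStatsGathering runs one goroutine per stats channel, and
each goroutine sends on done exactly once after its channel is closed, which is
why Close receives from done twice. A stripped-down sketch of that pattern,
using simplified stand-in types rather than the Stats bookkeeping from the
patch:

package main

import "fmt"

type stat struct{ bytes int64 }

func main() {
	upload := make(chan stat)
	download := make(chan stat)
	done := make(chan bool)

	drain := func(name string, ch chan stat) {
		var total int64
		for s := range ch { // exits once the channel is closed
			total += s.bytes
		}
		fmt.Println(name, "bytes:", total)
		done <- true // one completion signal per goroutine
	}
	go drain("upload", upload)
	go drain("download", download)

	upload <- stat{bytes: 10}
	download <- stat{bytes: 20}

	// Mirrors S3Cache.Close: close both channels, then wait on done twice.
	close(upload)
	close(download)
	<-done
	<-done
}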