Skip to content

Commit

Permalink
Single call to S3 (#2)
Browse files: browse the repository at this point in the history
  • Loading branch information
or-shachar authored Nov 21, 2023
1 parent ec1eaf7 commit 755bc7a
Showing 1 changed file with 87 additions and 131 deletions.
218 changes: 87 additions & 131 deletions cachers/s3.go
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,6 @@ package cachers
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -18,6 +17,10 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3"
)

const (
// OutputIdKey is the key of the S3 object metadata entry that carries the
// output ID: the upload path sets it in PutObjectInput.Metadata, and Get
// reads it back from the Metadata of the GetObject result.
OutputIdKey = "outputid"
)

type S3Cache struct {
Bucket string
cfg *aws.Config
Expand Down Expand Up @@ -69,12 +72,12 @@ func NewS3Cache(bucketName string, cfg *aws.Config, cacheKey string, disk *DiskC
return cache
}

func (c *S3Cache) client() (*s3.Client, error) {
if c.s3Client != nil {
return c.s3Client, nil
func (s *S3Cache) client() (*s3.Client, error) {
if s.s3Client != nil {
return s.s3Client, nil
}
c.s3Client = s3.NewFromConfig(*c.cfg)
return c.s3Client, nil
s.s3Client = s3.NewFromConfig(*s.cfg)
return s.s3Client, nil
}

func isNotFoundError(err error) bool {
Expand All @@ -88,158 +91,111 @@ func isNotFoundError(err error) bool {
return false
}

func (c *S3Cache) Get(ctx context.Context, actionID string) (outputID, diskPath string, err error) {
outputID, diskPath, err = c.diskCache.Get(ctx, actionID)
func (s *S3Cache) Get(ctx context.Context, actionID string) (outputID, diskPath string, err error) {
outputID, diskPath, err = s.diskCache.Get(ctx, actionID)
if err == nil && outputID != "" {
return outputID, diskPath, nil
}
client, err := c.client()
client, err := s.client()
if err != nil {
if c.verbose {
if s.verbose {
log.Printf("error getting S3 client: %v", err)
}
return "", "", err
}
actionKey := c.actionKey(actionID)
result, err := client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &c.Bucket,
actionKey := s.actionKey(actionID)

outputResult, getOutputErr := client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &s.Bucket,
Key: &actionKey,
})
// handle object not found
if isNotFoundError(err) {
if isNotFoundError(getOutputErr) {
// handle object not found
return "", "", nil
} else if err != nil {
if c.verbose {
log.Printf("error S3 get for %s: %v", actionKey, err)
} else if getOutputErr != nil {
if s.verbose {
log.Printf("error S3 get for %s: %v", actionKey, getOutputErr)
}
return "", "", fmt.Errorf("unexpected S3 get for %s: %v", actionKey, err)
return "", "", fmt.Errorf("unexpected S3 get for %s: %v", actionKey, getOutputErr)
}
defer result.Body.Close()
var av ActionValue
body, err := io.ReadAll(result.Body)
if err != nil {
return "", "", err
contentSize := outputResult.ContentLength
outputID, ok := outputResult.Metadata[OutputIdKey]
if !ok || outputID == "" {
return "", "", fmt.Errorf("outputId not found in metadata")
}
if err := json.Unmarshal(body, &av); err != nil {
if c.verbose {
log.Printf("error unmarshalling JSON for %s: %v", actionKey, err)
content := outputResult.Body
downloadFunc := func() error {
defer outputResult.Body.Close()
diskPath, err = s.diskCache.Put(ctx, actionID, outputID, contentSize, content)
if err != nil {
return err
}
return "", "", err
return nil
}

outputID = av.OutputID

var putBody io.Reader
if av.Size == 0 {
putBody = bytes.NewReader(nil)
diskPath, err = c.diskCache.Put(ctx, actionID, outputID, av.Size, putBody)
} else {
outputKey := c.outputKey(outputID)
outputResult, getOutputErr := client.GetObject(ctx, &s3.GetObjectInput{
Bucket: &c.Bucket,
Key: &outputKey,
})
if isNotFoundError(getOutputErr) {
// handle object not found
return "", "", nil
} else if getOutputErr != nil {
if c.verbose {
log.Printf("error S3 get for %s: %v", outputKey, getOutputErr)
}
return "", "", fmt.Errorf("unexpected S3 get for %s: %v", outputKey, getOutputErr)
}
downloadFunc := func() error {
defer outputResult.Body.Close()
putBody = outputResult.Body
diskPath, err = c.diskCache.Put(ctx, actionID, outputID, av.Size, putBody)
if err != nil {
return err
}
return nil
}
if c.verbose {
speed, err := DoAndMeasureSpeed(av.Size, downloadFunc)
if err == nil {
c.downloadStatsChan <- Stats{
Bytes: outputResult.ContentLength,
Speed: speed,
}
} else {
log.Printf("error downloading %s: %v", outputKey, err)
if s.verbose {
speed, err := DoAndMeasureSpeed(contentSize, downloadFunc)
if err == nil {
s.downloadStatsChan <- Stats{
Bytes: contentSize,
Speed: speed,
}
} else {
err = downloadFunc()
log.Printf("error downloading %s: %v", actionKey, err)
}
} else {
err = downloadFunc()
}
return outputID, diskPath, err
}

func (c *S3Cache) actionKey(actionID string) string {
return fmt.Sprintf("%s/actions/%s", c.prefix, actionID)
return outputID, diskPath, err
}

func (c *S3Cache) outputKey(outputID string) string {
return fmt.Sprintf("%s/output/%s", c.prefix, outputID)
func (s *S3Cache) actionKey(actionID string) string {
return fmt.Sprintf("%s/%s", s.prefix, actionID)
}

func (c *S3Cache) Put(ctx context.Context, actionID, outputID string, size int64, body io.Reader) (diskPath string, _ error) {
func (s *S3Cache) Put(ctx context.Context, actionID, outputID string, size int64, body io.Reader) (diskPath string, _ error) {
// Write to disk locally as we write it remotely, as we need to guarantee
// it's on disk locally for the caller.
var readerForDisk io.Reader
var readerForS3 bytes.Buffer
var bytesReaderForDisk io.Reader
var bytesBufferForS3 bytes.Buffer

if size == 0 {
// Special case the empty file so NewRequest sets "Content-Length: 0",
// as opposed to thinking we didn't set it and not being able to sniff its size
// from the type.
readerForDisk = bytes.NewReader(nil)
bytesReaderForDisk = bytes.NewReader(nil)
bytesBufferForS3 = bytes.Buffer{}
} else {
readerForDisk = io.TeeReader(body, &readerForS3)
bytesReaderForDisk = io.TeeReader(body, &bytesBufferForS3)
}

diskPath, err := c.diskCache.Put(ctx, actionID, outputID, size, readerForDisk)
diskPath, err := s.diskCache.Put(ctx, actionID, outputID, size, bytesReaderForDisk)
if err != nil {
return "", err
}

client, err := c.client()
client, err := s.client()
if err != nil {
return "", err
}
av := ActionValue{
OutputID: outputID,
Size: size,
}
avj, err := json.Marshal(av)
if err == nil {
actionKey := c.actionKey(actionID)
_, err = client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &c.Bucket,
Key: &actionKey,
Body: bytes.NewReader(avj),
})
}
if size > 0 && err == nil {
c.uploadOutput(ctx, outputID, client, readerForS3, size)
}
s.uploadOutput(ctx, actionID, outputID, client, bytesBufferForS3, size)
return
}

func (c *S3Cache) uploadOutput(ctx context.Context, outputID string, client *s3.Client, readerForS3 bytes.Buffer, size int64) {
outputKey := c.outputKey(outputID)
func (s *S3Cache) uploadOutput(ctx context.Context, actionId, outputID string, client *s3.Client, readerForS3 bytes.Buffer, size int64) {
outputKey := s.actionKey(actionId)
putObjectFunc := func() error {
_, err := client.PutObject(ctx, &s3.PutObjectInput{
Bucket: &c.Bucket,
Bucket: &s.Bucket,
Key: &outputKey,
Body: &readerForS3,
ContentLength: size,
Metadata: map[string]string{
OutputIdKey: outputID,
},
})
return err
}
if c.verbose {
if s.verbose {
speed, err := DoAndMeasureSpeed(size, putObjectFunc)
if err == nil {
c.uploadStatsChan <- Stats{
s.uploadStatsChan <- Stats{
Bytes: size,
Speed: speed,
}
Expand All @@ -249,20 +205,20 @@ func (c *S3Cache) uploadOutput(ctx context.Context, outputID string, client *s3.
}
}

func (c *S3Cache) BytesDownloaded() int64 {
return c.bytesDownloaded
func (s *S3Cache) BytesDownloaded() int64 {
return s.bytesDownloaded
}

func (c *S3Cache) BytesUploaded() int64 {
return c.bytesUploaded
func (s *S3Cache) BytesUploaded() int64 {
return s.bytesUploaded
}

func (c *S3Cache) AvgBytesDownloadSpeed() float64 {
return c.avgBytesDownloadSpeed
func (s *S3Cache) AvgBytesDownloadSpeed() float64 {
return s.avgBytesDownloadSpeed
}

func (c *S3Cache) AvgBytesUploadSpeed() float64 {
return c.avgBytesUploadSpeed
func (s *S3Cache) AvgBytesUploadSpeed() float64 {
return s.avgBytesUploadSpeed
}

func newAverage(oldAverage float64, count int64, newValue float64) float64 {
Expand All @@ -277,28 +233,28 @@ func DoAndMeasureSpeed(dataSize int64, functionOnData func() error) (float64, er
return speed, err
}

func (c *S3Cache) StartStatsGathering() {
func (s *S3Cache) StartStatsGathering() {
go func() {
for s := range c.uploadStatsChan {
c.bytesUploaded += s.Bytes
c.avgBytesUploadSpeed = newAverage(c.avgBytesUploadSpeed, c.uploadCount, s.Speed)
c.uploadCount++
for stats := range s.uploadStatsChan {
s.bytesUploaded += stats.Bytes
s.avgBytesUploadSpeed = newAverage(s.avgBytesUploadSpeed, s.uploadCount, stats.Speed)
s.uploadCount++
}
c.done <- true
s.done <- true
}()
go func() {
for s := range c.downloadStatsChan {
c.bytesDownloaded += s.Bytes
c.avgBytesDownloadSpeed = newAverage(c.avgBytesDownloadSpeed, c.downloadCount, s.Speed)
c.downloadCount++
for stats := range s.downloadStatsChan {
s.bytesDownloaded += stats.Bytes
s.avgBytesDownloadSpeed = newAverage(s.avgBytesDownloadSpeed, s.downloadCount, stats.Speed)
s.downloadCount++
}
c.done <- true
s.done <- true
}()
}

func (c *S3Cache) Close() {
close(c.downloadStatsChan)
close(c.uploadStatsChan)
<-c.done
<-c.done
func (s *S3Cache) Close() {
close(s.downloadStatsChan)
close(s.uploadStatsChan)
<-s.done
<-s.done
}

0 comments on commit 755bc7a

Please sign in to comment.