diff --git a/stores/multipart.go b/stores/multipart.go
index 3a5bcd54a..3b2f09c29 100644
--- a/stores/multipart.go
+++ b/stores/multipart.go
@@ -187,17 +187,19 @@ func (s *SQLStore) MultipartUploads(ctx context.Context, bucket, prefix, keyMark
 		limit++
 	}
 
-	prefixExpr := exprTRUE
-	if prefix != "" {
-		prefixExpr = gorm.Expr("SUBSTR(object_id, 1, ?) = ?", utf8.RuneCountInString(prefix), prefix)
+	// both markers must be used together
+	if (keyMarker == "" && uploadIDMarker != "") || (keyMarker != "" && uploadIDMarker == "") {
+		return api.MultipartListUploadsResponse{}, errors.New("both keyMarker and uploadIDMarker must be set or neither")
 	}
-	keyMarkerExpr := exprTRUE
+	markerExpr := exprTRUE
 	if keyMarker != "" {
-		keyMarkerExpr = gorm.Expr("object_id > ?", keyMarker)
+		// parenthesised so the OR keeps its precedence when ANDed with the other WHERE conditions
+		markerExpr = gorm.Expr("(object_id > ? OR (object_id = ? AND upload_id > ?))", keyMarker, keyMarker, uploadIDMarker)
 	}
-	uploadIDMarkerExpr := exprTRUE
-	if uploadIDMarker != "" {
-		uploadIDMarkerExpr = gorm.Expr("upload_id > ?", keyMarker)
+
+	prefixExpr := exprTRUE
+	if prefix != "" {
+		prefixExpr = gorm.Expr("SUBSTR(object_id, 1, ?) = ?", utf8.RuneCountInString(prefix), prefix)
 	}
 
 	err = s.retryTransaction(func(tx *gorm.DB) error {
@@ -205,7 +207,10 @@ func (s *SQLStore) MultipartUploads(ctx context.Context, bucket, prefix, keyMark
 		err := tx.
 			Model(&dbMultipartUpload{}).
 			Joins("DBBucket").
-			Where("? AND ? AND ? AND DBBucket.name = ?", prefixExpr, keyMarkerExpr, uploadIDMarkerExpr, bucket).
+			Where("DBBucket.name = ?", bucket).
+			Where("?", markerExpr).
+			Where("?", prefixExpr).
+			Order("object_id ASC, upload_id ASC").
 			Limit(limit).
 			Find(&dbUploads).
 			Error
diff --git a/stores/multipart_test.go b/stores/multipart_test.go
index eeda43229..37b294418 100644
--- a/stores/multipart_test.go
+++ b/stores/multipart_test.go
@@ -4,6 +4,8 @@ import (
 	"context"
 	"encoding/hex"
 	"reflect"
+	"sort"
+	"strings"
 	"testing"
 	"time"
 
@@ -168,3 +170,99 @@ func TestMultipartUploadWithUploadPackingRegression(t *testing.T) {
 		t.Fatalf("expected object total size to be %v, got %v", totalSize, obj.TotalSize())
 	}
 }
+
+func TestMultipartUploads(t *testing.T) {
+	if testing.Short() {
+		t.SkipNow()
+	}
+	ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
+	defer ss.Close()
+
+	// create 3 multipart uploads, the first 2 have the same path
+	resp1, err := ss.CreateMultipartUpload(context.Background(), api.DefaultBucketName, "/foo", object.NoOpKey, testMimeType, testMetadata)
+	if err != nil {
+		t.Fatal(err)
+	}
+	resp2, err := ss.CreateMultipartUpload(context.Background(), api.DefaultBucketName, "/foo", object.NoOpKey, testMimeType, testMetadata)
+	if err != nil {
+		t.Fatal(err)
+	}
+	resp3, err := ss.CreateMultipartUpload(context.Background(), api.DefaultBucketName, "/foo2", object.NoOpKey, testMimeType, testMetadata)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// prepare the expected order of uploads returned by MultipartUploads
+	orderedUploads := []struct {
+		uploadID string
+		objectID string
+	}{
+		{uploadID: resp1.UploadID, objectID: "/foo"},
+		{uploadID: resp2.UploadID, objectID: "/foo"},
+		{uploadID: resp3.UploadID, objectID: "/foo2"},
+	}
+	sort.Slice(orderedUploads, func(i, j int) bool {
+		if orderedUploads[i].objectID != orderedUploads[j].objectID {
+			return strings.Compare(orderedUploads[i].objectID, orderedUploads[j].objectID) < 0
+		}
+		return strings.Compare(orderedUploads[i].uploadID, orderedUploads[j].uploadID) < 0
+	})
+
+	// fetch uploads
+	mur, err := ss.MultipartUploads(context.Background(), api.DefaultBucketName, "", "", "", 3)
+	if err != nil {
+		t.Fatal(err)
+	} else if len(mur.Uploads) != 3 {
+		t.Fatal("expected 3 uploads")
+	} else if mur.Uploads[0].UploadID != orderedUploads[0].uploadID {
+		t.Fatal("unexpected upload id")
+	} else if mur.Uploads[1].UploadID != orderedUploads[1].uploadID {
+		t.Fatal("unexpected upload id")
+	} else if mur.Uploads[2].UploadID != orderedUploads[2].uploadID {
+		t.Fatal("unexpected upload id")
+	}
+
+	// fetch uploads with prefix
+	mur, err = ss.MultipartUploads(context.Background(), api.DefaultBucketName, "/foo", "", "", 3)
+	if err != nil {
+		t.Fatal(err)
+	} else if len(mur.Uploads) != 3 {
+		t.Fatal("expected 3 uploads")
+	} else if mur.Uploads[0].UploadID != orderedUploads[0].uploadID {
+		t.Fatal("unexpected upload id")
+	} else if mur.Uploads[1].UploadID != orderedUploads[1].uploadID {
+		t.Fatal("unexpected upload id")
+	} else if mur.Uploads[2].UploadID != orderedUploads[2].uploadID {
+		t.Fatal("unexpected upload id")
+	}
+	mur, err = ss.MultipartUploads(context.Background(), api.DefaultBucketName, "/foo2", "", "", 3)
+	if err != nil {
+		t.Fatal(err)
+	} else if len(mur.Uploads) != 1 {
+		t.Fatal("expected 1 upload")
+	} else if mur.Uploads[0].UploadID != orderedUploads[2].uploadID {
+		t.Fatal("unexpected upload id")
+	}
+
+	// paginate through them one-by-one
+	keyMarker := ""
+	uploadIDMarker := ""
+	hasMore := true
+	for hasMore {
+		mur, err = ss.MultipartUploads(context.Background(), api.DefaultBucketName, "", keyMarker, uploadIDMarker, 1)
+		if err != nil {
+			t.Fatal(err)
+		} else if len(mur.Uploads) != 1 {
+			t.Fatal("expected 1 upload")
+		} else if mur.Uploads[0].UploadID != orderedUploads[0].uploadID {
+			t.Fatalf("unexpected upload id: %v != %v", mur.Uploads[0].UploadID, orderedUploads[0].uploadID)
+		}
+		orderedUploads = orderedUploads[1:]
+		keyMarker = mur.NextPathMarker
+		uploadIDMarker = mur.NextUploadIDMarker
+		hasMore = mur.HasMore
+	}
+	if len(orderedUploads) != 0 {
+		t.Fatal("expected 3 iterations")
+	}
+}