diff --git a/.github/workflows/project-add.yml b/.github/workflows/project-add.yml new file mode 100644 index 000000000..3304fc0db --- /dev/null +++ b/.github/workflows/project-add.yml @@ -0,0 +1,21 @@ +name: Add issues and PRs to Sia project + +on: + issues: + types: + - opened + pull_request: + types: + - opened + +jobs: + add-to-project: + name: Add issue to project + runs-on: ubuntu-latest + steps: + - uses: actions/add-to-project@v0.5.0 + with: + # You can target a project in a different organization + # to the issue + project-url: https://github.com/orgs/SiaFoundation/projects/5 + github-token: ${{ secrets.PAT_ADD_TO_PROJECT }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c0b6f866d..73b96965a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -9,7 +9,7 @@ on: - dev tags: - 'v[0-9]+.[0-9]+.[0-9]+' - - 'v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+' + - 'v[0-9]+.[0-9]+.[0-9]+-**' concurrency: group: ${{ github.workflow }} @@ -263,7 +263,7 @@ jobs: steps: - name: Extract Tag Name id: get_tag - run: echo "::set-output name=tag_name::${GITHUB_REF#refs/tags/}" + run: echo "tag_name=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV - name: Repository Dispatch uses: peter-evans/repository-dispatch@v3 @@ -274,7 +274,7 @@ jobs: client-payload: > { "description": "Renterd: The Next-Gen Sia Renter", - "tag": "${{ steps.get_tag.outputs.tag_name }}", + "tag": "${{ env.tag_name }}", "project": "renterd", "workflow_id": "${{ github.run_id }}" } \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index ad04bb78e..ace11db65 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -102,6 +102,8 @@ linters: - typecheck - whitespace - tagliatelle + - unused + - unparam issues: # Maximum issues count per one linter. Set to 0 to disable. Default is 50. 
diff --git a/alerts/alerts_test.go b/alerts/alerts_test.go index 2cc20c57b..ff927ccdc 100644 --- a/alerts/alerts_test.go +++ b/alerts/alerts_test.go @@ -22,27 +22,29 @@ type testWebhookStore struct { listed int } -func (s *testWebhookStore) DeleteWebhook(wb webhooks.Webhook) error { +func (s *testWebhookStore) DeleteWebhook(_ context.Context, wb webhooks.Webhook) error { s.mu.Lock() defer s.mu.Unlock() s.deleted++ return nil } -func (s *testWebhookStore) AddWebhook(wb webhooks.Webhook) error { +func (s *testWebhookStore) AddWebhook(_ context.Context, wb webhooks.Webhook) error { s.mu.Lock() defer s.mu.Unlock() s.added++ return nil } -func (s *testWebhookStore) Webhooks() ([]webhooks.Webhook, error) { +func (s *testWebhookStore) Webhooks(_ context.Context) ([]webhooks.Webhook, error) { s.mu.Lock() defer s.mu.Unlock() s.listed++ return nil, nil } +var _ webhooks.WebhookStore = (*testWebhookStore)(nil) + func TestWebhooks(t *testing.T) { store := &testWebhookStore{} mgr, err := webhooks.NewManager(zap.NewNop().Sugar(), store) @@ -75,7 +77,7 @@ func TestWebhooks(t *testing.T) { if hookID := wh.String(); hookID != fmt.Sprintf("%v.%v.%v", wh.URL, wh.Module, "") { t.Fatalf("wrong result for wh.String(): %v != %v", wh.String(), hookID) } - err = mgr.Register(wh) + err = mgr.Register(context.Background(), wh) if err != nil { t.Fatal(err) } @@ -110,7 +112,7 @@ func TestWebhooks(t *testing.T) { } // unregister hook - if err := mgr.Delete(webhooks.Webhook{ + if err := mgr.Delete(context.Background(), webhooks.Webhook{ Event: hooks[0].Event, Module: hooks[0].Module, URL: hooks[0].URL, diff --git a/api/multipart.go b/api/multipart.go index a191b2b13..ee26567b1 100644 --- a/api/multipart.go +++ b/api/multipart.go @@ -51,6 +51,10 @@ type ( MimeType string Metadata ObjectUserMetadata } + + CompleteMultipartOptions struct { + Metadata ObjectUserMetadata + } ) type ( @@ -75,10 +79,11 @@ type ( } MultipartCompleteRequest struct { - Bucket string `json:"bucket"` - Path string 
`json:"path"` - UploadID string `json:"uploadID"` - Parts []MultipartCompletedPart + Bucket string `json:"bucket"` + Metadata ObjectUserMetadata `json:"metadata"` + Path string `json:"path"` + UploadID string `json:"uploadID"` + Parts []MultipartCompletedPart `json:"parts"` } MultipartCreateRequest struct { diff --git a/api/object.go b/api/object.go index 4b1993341..36cea9db8 100644 --- a/api/object.go +++ b/api/object.go @@ -92,6 +92,7 @@ type ( // HeadObjectResponse is the response type for the HEAD /worker/object endpoint. HeadObjectResponse struct { ContentType string `json:"contentType"` + Etag string `json:"eTag"` LastModified string `json:"lastModified"` Range *DownloadRange `json:"range,omitempty"` Size int64 `json:"size"` @@ -212,7 +213,8 @@ type ( } HeadObjectOptions struct { - Range DownloadRange + IgnoreDelim bool + Range DownloadRange } DownloadObjectOptions struct { @@ -310,6 +312,12 @@ func (opts DeleteObjectOptions) Apply(values url.Values) { } } +func (opts HeadObjectOptions) Apply(values url.Values) { + if opts.IgnoreDelim { + values.Set("ignoreDelim", "true") + } +} + func (opts HeadObjectOptions) ApplyHeaders(h http.Header) { if opts.Range != (DownloadRange{}) { if opts.Range.Length == -1 { diff --git a/autopilot/accounts.go b/autopilot/accounts.go index 416203de5..690c2b35d 100644 --- a/autopilot/accounts.go +++ b/autopilot/accounts.go @@ -136,22 +136,18 @@ func (a *accounts) refillWorkerAccounts(ctx context.Context, w Worker) { // refill accounts in separate goroutines for _, c := range contracts { - // add logging for contracts in the set - _, inSet := inContractSet[c.ID] - // launch refill if not already in progress if a.markRefillInProgress(workerID, c.HostKey) { - go func(contract api.ContractMetadata, inSet bool) { + go func(contract api.ContractMetadata) { rCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) defer cancel() - accountID, refilled, rerr := refillWorkerAccount(rCtx, a.a, w, workerID, contract) + accountID, refilled, rerr 
:= refillWorkerAccount(rCtx, a.a, w, contract) if rerr != nil { - if inSet || rerr.Is(errMaxDriftExceeded) { - // register the alert on failure if the contract is in - // the set or the error is errMaxDriftExceeded + if rerr.Is(errMaxDriftExceeded) { + // register the alert if error is errMaxDriftExceeded a.ap.RegisterAlert(ctx, newAccountRefillAlert(accountID, contract, *rerr)) - a.l.Errorw(rerr.err.Error(), rerr.keysAndValues...) } + a.l.Errorw(rerr.err.Error(), rerr.keysAndValues...) } else { // dismiss alerts on success a.ap.DismissAlert(ctx, alertIDForAccount(alertAccountRefillID, accountID)) @@ -167,7 +163,7 @@ func (a *accounts) refillWorkerAccounts(ctx context.Context, w Worker) { } a.markRefillDone(workerID, contract.HostKey) - }(c, inSet) + }(c) } } } @@ -188,7 +184,7 @@ func (err *refillError) Is(target error) bool { return errors.Is(err.err, target) } -func refillWorkerAccount(ctx context.Context, a AccountStore, w Worker, workerID string, contract api.ContractMetadata) (accountID rhpv3.Account, refilled bool, rerr *refillError) { +func refillWorkerAccount(ctx context.Context, a AccountStore, w Worker, contract api.ContractMetadata) (accountID rhpv3.Account, refilled bool, rerr *refillError) { wrapErr := func(err error, keysAndValues ...interface{}) *refillError { if err == nil { return nil diff --git a/autopilot/autopilot.go b/autopilot/autopilot.go index 7367003e0..c89049286 100644 --- a/autopilot/autopilot.go +++ b/autopilot/autopilot.go @@ -18,6 +18,7 @@ import ( "go.sia.tech/renterd/api" "go.sia.tech/renterd/build" "go.sia.tech/renterd/hostdb" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" "go.sia.tech/renterd/wallet" "go.sia.tech/renterd/webhooks" @@ -299,7 +300,7 @@ func (ap *Autopilot) Run() error { // perform maintenance setChanged, err := ap.c.performContractMaintenance(ap.shutdownCtx, w) - if err != nil && isErr(err, context.Canceled) { + if err != nil && utils.IsErr(err, context.Canceled) { return } else if err != nil { 
ap.logger.Errorf("contract maintenance failed, err: %v", err) @@ -321,11 +322,11 @@ func (ap *Autopilot) Run() error { } // migration - ap.m.tryPerformMigrations(ap.shutdownCtx, ap.workers) + ap.m.tryPerformMigrations(ap.workers) // pruning if ap.state.cfg.Contracts.Prune { - ap.c.tryPerformPruning(ap.shutdownCtx, ap.workers) + ap.c.tryPerformPruning(ap.workers) } else { ap.logger.Debug("pruning disabled") } @@ -405,9 +406,9 @@ func (ap *Autopilot) blockUntilConfigured(interrupt <-chan time.Time) (configure cancel() // if the config was not found, or we were unable to fetch it, keep blocking - if isErr(err, context.Canceled) { + if utils.IsErr(err, context.Canceled) { return - } else if isErr(err, api.ErrAutopilotNotFound) { + } else if utils.IsErr(err, api.ErrAutopilotNotFound) { once.Do(func() { ap.logger.Info("autopilot is waiting to be configured...") }) } else if err != nil { ap.logger.Errorf("autopilot is unable to fetch its configuration from the bus, err: %v", err) @@ -438,7 +439,7 @@ func (ap *Autopilot) blockUntilOnline() (online bool) { online = len(peers) > 0 cancel() - if isErr(err, context.Canceled) { + if utils.IsErr(err, context.Canceled) { return } else if err != nil { ap.logger.Errorf("failed to get peers, err: %v", err) @@ -472,7 +473,7 @@ func (ap *Autopilot) blockUntilSynced(interrupt <-chan time.Time) (synced, block cancel() // if an error occurred, or if we're not synced, we continue - if isErr(err, context.Canceled) { + if utils.IsErr(err, context.Canceled) { return } else if err != nil { ap.logger.Errorf("failed to get consensus state, err: %v", err) @@ -631,7 +632,7 @@ func (ap *Autopilot) isStopped() bool { func (ap *Autopilot) configHandlerGET(jc jape.Context) { autopilot, err := ap.bus.Autopilot(jc.Request.Context(), ap.id) - if err != nil && strings.Contains(err.Error(), api.ErrAutopilotNotFound.Error()) { + if utils.IsErr(err, api.ErrAutopilotNotFound) { jc.Error(errors.New("autopilot is not configured yet"), http.StatusNotFound) 
return } @@ -653,7 +654,7 @@ func (ap *Autopilot) configHandlerPUT(jc jape.Context) { // fetch the autopilot and update its config var contractSetChanged bool autopilot, err := ap.bus.Autopilot(jc.Request.Context(), ap.id) - if err != nil && strings.Contains(err.Error(), api.ErrAutopilotNotFound.Error()) { + if utils.IsErr(err, api.ErrAutopilotNotFound) { autopilot = api.Autopilot{ID: ap.id, Config: cfg} } else { if autopilot.Config.Contracts.Set != cfg.Contracts.Set { diff --git a/autopilot/contract_pruning.go b/autopilot/contract_pruning.go index e32cd3fa0..aa0eb505f 100644 --- a/autopilot/contract_pruning.go +++ b/autopilot/contract_pruning.go @@ -9,6 +9,7 @@ import ( "go.sia.tech/core/types" "go.sia.tech/renterd/alerts" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/siad/build" ) @@ -65,14 +66,14 @@ func (pm pruneMetrics) String() string { func (pr pruneResult) toAlert() (id types.Hash256, alert *alerts.Alert) { id = alertIDForContract(alertPruningID, pr.fcid) - if shouldTrigger := pr.err != nil && !((isErr(pr.err, errInvalidMerkleProof) && build.VersionCmp(pr.version, "1.6.0") < 0) || - isErr(pr.err, api.ErrContractNotFound) || // contract got archived - isErr(pr.err, errConnectionRefused) || - isErr(pr.err, errConnectionTimedOut) || - isErr(pr.err, errConnectionResetByPeer) || - isErr(pr.err, errInvalidHandshakeSignature) || - isErr(pr.err, errNoRouteToHost) || - isErr(pr.err, errNoSuchHost)); shouldTrigger { + if shouldTrigger := pr.err != nil && !((utils.IsErr(pr.err, errInvalidMerkleProof) && build.VersionCmp(pr.version, "1.6.0") < 0) || + utils.IsErr(pr.err, api.ErrContractNotFound) || // contract got archived + utils.IsErr(pr.err, errConnectionRefused) || + utils.IsErr(pr.err, errConnectionTimedOut) || + utils.IsErr(pr.err, errConnectionResetByPeer) || + utils.IsErr(pr.err, errInvalidHandshakeSignature) || + utils.IsErr(pr.err, errNoRouteToHost) || + utils.IsErr(pr.err, errNoSuchHost)); shouldTrigger { alert = 
newContractPruningFailedAlert(pr.hk, pr.version, pr.fcid, pr.err) } return @@ -196,7 +197,7 @@ func (c *contractor) pruneContract(w Worker, fcid types.FileContractID) pruneRes pruned, remaining, err := w.RHPPruneContract(ctx, fcid, timeoutPruneContract) if err != nil && pruned == 0 { return pruneResult{fcid: fcid, hk: host.PublicKey, version: host.Settings.Version, err: err} - } else if err != nil && isErr(err, context.DeadlineExceeded) { + } else if err != nil && utils.IsErr(err, context.DeadlineExceeded) { err = nil } diff --git a/autopilot/contract_spending.go b/autopilot/contract_spending.go index ba144e173..cbd10f86c 100644 --- a/autopilot/contract_spending.go +++ b/autopilot/contract_spending.go @@ -20,7 +20,7 @@ func (c *contractor) contractSpending(ctx context.Context, contract api.Contract return total, nil } -func (c *contractor) currentPeriodSpending(contracts []api.Contract, currentPeriod uint64) (types.Currency, error) { +func (c *contractor) currentPeriodSpending(contracts []api.Contract, currentPeriod uint64) types.Currency { totalCosts := make(map[types.FileContractID]types.Currency) for _, c := range contracts { totalCosts[c.ID] = c.TotalCost @@ -41,22 +41,19 @@ func (c *contractor) currentPeriodSpending(contracts []api.Contract, currentPeri for _, contract := range filtered { totalAllocated = totalAllocated.Add(contract.TotalCost) } - return totalAllocated, nil + return totalAllocated } -func (c *contractor) remainingFunds(contracts []api.Contract) (types.Currency, error) { +func (c *contractor) remainingFunds(contracts []api.Contract) types.Currency { state := c.ap.State() // find out how much we spent in the current period - spent, err := c.currentPeriodSpending(contracts, state.period) - if err != nil { - return types.ZeroCurrency, err - } + spent := c.currentPeriodSpending(contracts, state.period) // figure out remaining funds var remaining types.Currency if state.cfg.Contracts.Allowance.Cmp(spent) > 0 { remaining = 
state.cfg.Contracts.Allowance.Sub(spent) } - return remaining, nil + return remaining } diff --git a/autopilot/contractor.go b/autopilot/contractor.go index 4e5e8c842..83e12a206 100644 --- a/autopilot/contractor.go +++ b/autopilot/contractor.go @@ -16,6 +16,7 @@ import ( "go.sia.tech/core/types" "go.sia.tech/renterd/api" "go.sia.tech/renterd/hostdb" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/wallet" "go.sia.tech/renterd/worker" "go.uber.org/zap" @@ -275,7 +276,7 @@ func (c *contractor) performContractMaintenance(ctx context.Context, w Worker) ( // min score to pass checks var minScore float64 if len(hosts) > 0 { - minScore = c.calculateMinScore(ctx, candidates, state.cfg.Contracts.Amount) + minScore = c.calculateMinScore(candidates, state.cfg.Contracts.Amount) } else { c.logger.Warn("could not calculate min score, no hosts found") } @@ -323,10 +324,7 @@ func (c *contractor) performContractMaintenance(ctx context.Context, w Worker) ( } // calculate remaining funds - remaining, err := c.remainingFunds(contracts) - if err != nil { - return false, err - } + remaining := c.remainingFunds(contracts) // calculate 'limit' amount of contracts we want to renew var limit int @@ -1003,7 +1001,7 @@ func (c *contractor) runRevisionBroadcast(ctx context.Context, w Worker, allCont ctx, cancel := context.WithTimeout(ctx, timeoutBroadcastRevision) err := w.RHPBroadcast(ctx, contract.ID) cancel() - if err != nil && strings.Contains(err.Error(), "transaction has a file contract with an outdated revision number") { + if utils.IsErr(err, errors.New("transaction has a file contract with an outdated revision number")) { continue // don't log - revision was already broadcasted } else if err != nil { c.logger.Warnw(fmt.Sprintf("failed to broadcast contract revision: %v", err), @@ -1139,7 +1137,7 @@ func (c *contractor) initialContractFunding(settings rhpv2.HostSettings, txnFee, return funding } -func (c *contractor) refreshFundingEstimate(ctx context.Context, cfg 
api.AutopilotConfig, ci contractInfo, fee types.Currency) (types.Currency, error) { +func (c *contractor) refreshFundingEstimate(cfg api.AutopilotConfig, ci contractInfo, fee types.Currency) types.Currency { // refresh with 1.2x the funds refreshAmount := ci.contract.TotalCost.Mul64(6).Div64(5) @@ -1158,7 +1156,7 @@ func (c *contractor) refreshFundingEstimate(ctx context.Context, cfg api.Autopil "fcid", ci.contract.ID, "refreshAmount", refreshAmount, "refreshAmountCapped", refreshAmountCapped) - return refreshAmountCapped, nil + return refreshAmountCapped } func (c *contractor) renewFundingEstimate(ctx context.Context, ci contractInfo, fee types.Currency, renewing bool) (types.Currency, error) { @@ -1248,7 +1246,7 @@ func (c *contractor) renewFundingEstimate(ctx context.Context, ci contractInfo, return cappedEstimatedCost, nil } -func (c *contractor) calculateMinScore(ctx context.Context, candidates []scoredHost, numContracts uint64) float64 { +func (c *contractor) calculateMinScore(candidates []scoredHost, numContracts uint64) float64 { // return early if there's no hosts if len(candidates) == 0 { c.logger.Warn("min host score is set to the smallest non-zero float because there are no candidate hosts") @@ -1425,7 +1423,7 @@ func (c *contractor) renewContract(ctx context.Context, w Worker, ci contractInf "renterFunds", renterFunds, "expectedNewStorage", expectedNewStorage, ) - if strings.Contains(err.Error(), wallet.ErrInsufficientBalance.Error()) { + if utils.IsErr(err, wallet.ErrInsufficientBalance) && !worker.IsErrHost(err) { return api.ContractMetadata{}, false, err } return api.ContractMetadata{}, true, err @@ -1474,11 +1472,7 @@ func (c *contractor) refreshContract(ctx context.Context, w Worker, ci contractI // calculate the renter funds var renterFunds types.Currency if isOutOfFunds(state.cfg, ci.priceTable, ci.contract) { - renterFunds, err = c.refreshFundingEstimate(ctx, state.cfg, ci, state.fee) - if err != nil { - c.logger.Errorw(fmt.Sprintf("could not 
get refresh funding estimate, err: %v", err), "hk", hk, "fcid", fcid) - return api.ContractMetadata{}, true, err - } + renterFunds = c.refreshFundingEstimate(state.cfg, ci, state.fee) } else { renterFunds = rev.ValidRenterPayout() // don't increase funds } @@ -1508,7 +1502,7 @@ func (c *contractor) refreshContract(ctx context.Context, w Worker, ci contractI return api.ContractMetadata{}, true, err } c.logger.Errorw("refresh failed", zap.Error(err), "hk", hk, "fcid", fcid) - if strings.Contains(err.Error(), wallet.ErrInsufficientBalance.Error()) { + if utils.IsErr(err, wallet.ErrInsufficientBalance) && !worker.IsErrHost(err) { return api.ContractMetadata{}, false, err } return api.ContractMetadata{}, true, err @@ -1598,7 +1592,7 @@ func (c *contractor) formContract(ctx context.Context, w Worker, host hostdb.Hos return formedContract, true, nil } -func (c *contractor) tryPerformPruning(ctx context.Context, wp *workerPool) { +func (c *contractor) tryPerformPruning(wp *workerPool) { c.mu.Lock() if c.pruning || c.ap.isStopped() { c.mu.Unlock() diff --git a/autopilot/contractor_test.go b/autopilot/contractor_test.go index a0f63425b..575605612 100644 --- a/autopilot/contractor_test.go +++ b/autopilot/contractor_test.go @@ -1,7 +1,6 @@ package autopilot import ( - "context" "math" "testing" @@ -19,19 +18,19 @@ func TestCalculateMinScore(t *testing.T) { } // Test with 100 hosts which makes for a random set size of 250 - minScore := c.calculateMinScore(context.Background(), candidates, 100) + minScore := c.calculateMinScore(candidates, 100) if minScore != 0.002 { t.Fatalf("expected minScore to be 0.002 but was %v", minScore) } // Test with 0 hosts - minScore = c.calculateMinScore(context.Background(), []scoredHost{}, 100) + minScore = c.calculateMinScore([]scoredHost{}, 100) if minScore != math.SmallestNonzeroFloat64 { t.Fatalf("expected minScore to be math.SmallestNonzeroFLoat64 but was %v", minScore) } // Test with 300 hosts which is 50 more than we have - minScore = 
c.calculateMinScore(context.Background(), candidates, 300) + minScore = c.calculateMinScore(candidates, 300) if minScore != math.SmallestNonzeroFloat64 { t.Fatalf("expected minScore to be math.SmallestNonzeroFLoat64 but was %v", minScore) } diff --git a/autopilot/hostscore.go b/autopilot/hostscore.go index b15857d19..e8d9ca9b9 100644 --- a/autopilot/hostscore.go +++ b/autopilot/hostscore.go @@ -36,7 +36,7 @@ func hostScore(cfg api.AutopilotConfig, h hostdb.Host, storedData uint64, expect Collateral: collateralScore(cfg, h.PriceTable.HostPriceTable, uint64(allocationPerHost)), Interactions: interactionScore(h), Prices: priceAdjustmentScore(hostPeriodCost, cfg), - StorageRemaining: storageRemainingScore(cfg, h.Settings, storedData, expectedRedundancy, allocationPerHost), + StorageRemaining: storageRemainingScore(h.Settings, storedData, allocationPerHost), Uptime: uptimeScore(h), Version: versionScore(h.Settings), } @@ -74,7 +74,7 @@ func priceAdjustmentScore(hostCostPerPeriod types.Currency, cfg api.AutopilotCon panic("unreachable") } -func storageRemainingScore(cfg api.AutopilotConfig, h rhpv2.HostSettings, storedData uint64, expectedRedundancy, allocationPerHost float64) float64 { +func storageRemainingScore(h rhpv2.HostSettings, storedData uint64, allocationPerHost float64) float64 { // hostExpectedStorage is the amount of storage that we expect to be able to // store on this host overall, which should include the stored data that is // already on the host. 
@@ -291,7 +291,7 @@ func uploadCostForScore(cfg api.AutopilotConfig, h hostdb.Host, bytes uint64) ty return uploadSectorCostRHPv3.Mul64(numSectors) } -func downloadCostForScore(cfg api.AutopilotConfig, h hostdb.Host, bytes uint64) types.Currency { +func downloadCostForScore(h hostdb.Host, bytes uint64) types.Currency { rsc := h.PriceTable.BaseCost().Add(h.PriceTable.ReadSectorCost(rhpv2.SectorSize)) downloadSectorCostRHPv3, _ := rsc.Total() numSectors := bytesToSectors(bytes) @@ -314,7 +314,7 @@ func hostPeriodCostForScore(h hostdb.Host, cfg api.AutopilotConfig, expectedRedu hostCollateral := rhpv2.ContractFormationCollateral(cfg.Contracts.Period, storagePerHost, h.Settings) hostContractPrice := contractPriceForScore(h) hostUploadCost := uploadCostForScore(cfg, h, uploadPerHost) - hostDownloadCost := downloadCostForScore(cfg, h, downloadPerHost) + hostDownloadCost := downloadCostForScore(h, downloadPerHost) hostStorageCost := storageCostForScore(cfg, h, storagePerHost) siafundFee := hostCollateral. Add(hostContractPrice). 
diff --git a/autopilot/ipfilter.go b/autopilot/ipfilter.go index 6aa244047..0932d7676 100644 --- a/autopilot/ipfilter.go +++ b/autopilot/ipfilter.go @@ -9,6 +9,7 @@ import ( "time" "go.sia.tech/core/types" + "go.sia.tech/renterd/internal/utils" "go.uber.org/zap" ) @@ -137,7 +138,7 @@ func (r *ipResolver) lookup(hostIP string) ([]string, error) { addrs, err := r.resolver.LookupIPAddr(ctx, host) if err != nil { // check the cache if it's an i/o timeout or server misbehaving error - if isErr(err, errIOTimeout) || isErr(err, errServerMisbehaving) { + if utils.IsErr(err, errIOTimeout) || utils.IsErr(err, errServerMisbehaving) { if entry, found := r.cache[hostIP]; found && time.Since(entry.created) < ipCacheEntryValidity { r.logger.Debugf("using cached IP addresses for %v, err: %v", hostIP, err) return entry.subnets, nil @@ -188,10 +189,3 @@ func parseSubnets(addresses []net.IPAddr) []string { return subnets } - -func isErr(err error, target error) bool { - if errors.Is(err, target) { - return true - } - return err != nil && target != nil && strings.Contains(err.Error(), target.Error()) -} diff --git a/autopilot/ipfilter_test.go b/autopilot/ipfilter_test.go new file mode 100644 index 000000000..29fc3c8cf --- /dev/null +++ b/autopilot/ipfilter_test.go @@ -0,0 +1,156 @@ +package autopilot + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/renterd/internal/utils" + "go.uber.org/zap" +) + +var ( + ipv4Localhost = net.IP{127, 0, 0, 1} + ipv6Localhost = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} +) + +type testResolver struct { + addr map[string][]net.IPAddr + err error +} + +func (r *testResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) { + // return error if set + if err := r.err; err != nil { + r.err = nil + return nil, err + } + // return IP addr if set + if addrs, ok := r.addr[host]; ok { + return addrs, nil + } + return nil, nil +} + +func (r *testResolver) setNextErr(err 
error) { r.err = err } +func (r *testResolver) setAddr(host string, addrs []net.IPAddr) { r.addr[host] = addrs } + +func newTestResolver() *testResolver { + return &testResolver{addr: make(map[string][]net.IPAddr)} +} + +func newTestIPResolver(r resolver) *ipResolver { + ipr := newIPResolver(context.Background(), time.Minute, zap.NewNop().Sugar()) + ipr.resolver = r + return ipr +} + +func newTestIPFilter(r resolver) *ipFilter { + return &ipFilter{ + subnetToHostKey: make(map[string]string), + resolver: newTestIPResolver(r), + logger: zap.NewNop().Sugar(), + } +} + +func TestIPResolver(t *testing.T) { + r := newTestResolver() + ipr := newTestIPResolver(r) + + // test lookup error + r.setNextErr(errors.New("unknown error")) + if _, err := ipr.lookup("example.com:1234"); !utils.IsErr(err, errors.New("unknown error")) { + t.Fatal("unexpected error", err) + } + + // test IO timeout - no cache entry + r.setNextErr(errIOTimeout) + if _, err := ipr.lookup("example.com:1234"); !utils.IsErr(err, errIOTimeout) { + t.Fatal("unexpected error", err) + } + + // test IO timeout - expired cache entry + ipr.cache["example.com:1234"] = ipCacheEntry{subnets: []string{"a"}} + r.setNextErr(errIOTimeout) + if _, err := ipr.lookup("example.com:1234"); !utils.IsErr(err, errIOTimeout) { + t.Fatal("unexpected error", err) + } + + // test IO timeout - live cache entry + ipr.cache["example.com:1234"] = ipCacheEntry{created: time.Now(), subnets: []string{"a"}} + r.setNextErr(errIOTimeout) + if subnets, err := ipr.lookup("example.com:1234"); err != nil { + t.Fatal("unexpected error", err) + } else if len(subnets) != 1 || subnets[0] != "a" { + t.Fatal("unexpected subnets", subnets) + } + + // test too many addresses - more than two + r.setAddr("example.com", []net.IPAddr{{}, {}, {}}) + if _, err := ipr.lookup("example.com:1234"); !utils.IsErr(err, errTooManyAddresses) { + t.Fatal("unexpected error", err) + } + + // test too many addresses - two of the same type + r.setAddr("example.com", 
[]net.IPAddr{{IP: net.IPv4(1, 2, 3, 4)}, {IP: net.IPv4(1, 2, 3, 4)}}) + if _, err := ipr.lookup("example.com:1234"); !utils.IsErr(err, errTooManyAddresses) { + t.Fatal("unexpected error", err) + } + + // test invalid addresses + r.setAddr("example.com", []net.IPAddr{{IP: ipv4Localhost}, {IP: net.IP{127, 0, 0, 2}}}) + if _, err := ipr.lookup("example.com:1234"); !utils.IsErr(err, errTooManyAddresses) { + t.Fatal("unexpected error", err) + } + + // test valid addresses + r.setAddr("example.com", []net.IPAddr{{IP: ipv4Localhost}, {IP: ipv6Localhost}}) + if subnets, err := ipr.lookup("example.com:1234"); err != nil { + t.Fatal("unexpected error", err) + } else if len(subnets) != 2 || subnets[0] != "127.0.0.0/24" || subnets[1] != "::/32" { + t.Fatal("unexpected subnets", subnets) + } +} + +func TestIPFilter(t *testing.T) { + r := newTestResolver() + r.setAddr("host1.com", []net.IPAddr{{IP: net.IP{192, 168, 0, 1}}}) + r.setAddr("host2.com", []net.IPAddr{{IP: net.IP{192, 168, 1, 1}}}) + r.setAddr("host3.com", []net.IPAddr{{IP: net.IP{192, 168, 2, 1}}}) + ipf := newTestIPFilter(r) + + // add 3 hosts - unique IPs + r1 := ipf.IsRedundantIP("host1.com:1234", types.PublicKey{1}) + r2 := ipf.IsRedundantIP("host2.com:1234", types.PublicKey{2}) + r3 := ipf.IsRedundantIP("host3.com:1234", types.PublicKey{3}) + if r1 || r2 || r3 { + t.Fatal("unexpected result", r1, r2, r3) + } + + // try add 4th host - redundant IP + r.setAddr("host4.com", []net.IPAddr{{IP: net.IP{192, 168, 0, 12}}}) + if redundant := ipf.IsRedundantIP("host4.com:1234", types.PublicKey{4}); !redundant { + t.Fatal("unexpected result", redundant) + } + + // add 4th host - unique IP - 2 subnets + r.setAddr("host4.com", []net.IPAddr{{IP: net.IP{192, 168, 3, 1}}, {IP: net.ParseIP("2001:0db8:85a3::8a2e:0370:7334")}}) + if redundant := ipf.IsRedundantIP("host4.com:1234", types.PublicKey{4}); redundant { + t.Fatal("unexpected result", redundant) + } + + // try add 5th host - redundant IP based on the IPv6 subnet from host4 
+ r.setAddr("host5.com", []net.IPAddr{{IP: net.ParseIP("2001:0db8:85b3::8a2e:0370:7335")}}) + if redundant := ipf.IsRedundantIP("host5.com:1234", types.PublicKey{5}); !redundant { + t.Fatal("unexpected result", redundant) + } + + // add 5th host - unique IP + r.setAddr("host5.com", []net.IPAddr{{IP: net.ParseIP("2001:0db9:85b3::8a2e:0370:7335")}}) + if redundant := ipf.IsRedundantIP("host5.com:1234", types.PublicKey{5}); redundant { + t.Fatal("unexpected result", redundant) + } +} diff --git a/autopilot/migrator.go b/autopilot/migrator.go index 4a4e31de6..89ab16a28 100644 --- a/autopilot/migrator.go +++ b/autopilot/migrator.go @@ -10,6 +10,7 @@ import ( "time" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" "go.sia.tech/renterd/stats" "go.uber.org/zap" @@ -97,7 +98,7 @@ func (m *migrator) slabMigrationEstimate(remaining int) time.Duration { return time.Duration(totalNumMS) * time.Millisecond } -func (m *migrator) tryPerformMigrations(ctx context.Context, wp *workerPool) { +func (m *migrator) tryPerformMigrations(wp *workerPool) { m.mu.Lock() if m.migrating || m.ap.isStopped() { m.mu.Unlock() @@ -156,7 +157,7 @@ func (m *migrator) performMigrations(p *workerPool) { if err != nil { m.logger.Errorf("%v: migration %d/%d failed, key: %v, health: %v, overpaid: %v, err: %v", id, j.slabIdx+1, j.batchSize, j.Key, j.Health, res.SurchargeApplied, err) - skipAlert := isErr(err, api.ErrSlabNotFound) + skipAlert := utils.IsErr(err, api.ErrSlabNotFound) if !skipAlert { if res.SurchargeApplied { m.ap.RegisterAlert(ctx, newCriticalMigrationFailedAlert(j.Key, j.Health, err)) diff --git a/autopilot/scanner.go b/autopilot/scanner.go index e512d1f87..bb21e5022 100644 --- a/autopilot/scanner.go +++ b/autopilot/scanner.go @@ -12,6 +12,7 @@ import ( "go.sia.tech/core/types" "go.sia.tech/renterd/api" "go.sia.tech/renterd/hostdb" + "go.sia.tech/renterd/internal/utils" "go.uber.org/zap" ) @@ -162,9 +163,9 @@ func (s *scanner) isInterrupted() bool 
{ } } -func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bool) bool { +func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bool) { if s.ap.isStopped() { - return false + return } scanType := "host scan" @@ -184,7 +185,7 @@ func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bo s.interruptScanChan = make(chan struct{}) } else if s.scanning || !s.isScanRequired() { s.mu.Unlock() - return false + return } s.scanningLastStart = time.Now() s.scanning = true @@ -228,7 +229,7 @@ func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bo s.logger.Debugf("%s finished after %v", st, time.Since(s.scanningLastStart)) s.mu.Unlock() }(scanType) - return true + return } func (s *scanner) tryUpdateTimeout() { @@ -314,7 +315,7 @@ func (s *scanner) launchScanWorkers(ctx context.Context, w scanWorker, reqs chan scan, err := w.RHPScan(ctx, req.hostKey, req.hostIP, s.currentTimeout()) if err != nil { break // abort - } else if !isErr(errors.New(scan.ScanError), errIOTimeout) && scan.Ping > 0 { + } else if !utils.IsErr(errors.New(scan.ScanError), errIOTimeout) && scan.Ping > 0 { s.tracker.addDataPoint(time.Duration(scan.Ping)) } diff --git a/autopilot/scanner_test.go b/autopilot/scanner_test.go index d5833d1fb..6214ec4a1 100644 --- a/autopilot/scanner_test.go +++ b/autopilot/scanner_test.go @@ -87,7 +87,7 @@ func TestScanner(t *testing.T) { // init new scanner b := &mockBus{hosts: hosts} w := &mockWorker{blockChan: make(chan struct{})} - s := newTestScanner(b, w) + s := newTestScanner(b) // assert it started a host scan s.tryPerformHostScan(context.Background(), w, false) @@ -139,7 +139,7 @@ func (s *scanner) isScanning() bool { return s.scanning } -func newTestScanner(b *mockBus, w *mockWorker) *scanner { +func newTestScanner(b *mockBus) *scanner { ap := &Autopilot{} ap.shutdownCtx, ap.shutdownCtxCancel = context.WithCancel(context.Background()) return &scanner{ diff --git a/bus/bus.go 
b/bus/bus.go index 045b8e82a..d8b3fdfc5 100644 --- a/bus/bus.go +++ b/bus/bus.go @@ -150,7 +150,7 @@ type ( AbortMultipartUpload(ctx context.Context, bucketName, path string, uploadID string) (err error) AddMultipartPart(ctx context.Context, bucketName, path, contractSet, eTag, uploadID string, partNumber int, slices []object.SlabSlice) (err error) - CompleteMultipartUpload(ctx context.Context, bucketName, path, uploadID string, parts []api.MultipartCompletedPart) (_ api.MultipartCompleteResponse, err error) + CompleteMultipartUpload(ctx context.Context, bucketName, path, uploadID string, parts []api.MultipartCompletedPart, opts api.CompleteMultipartOptions) (_ api.MultipartCompleteResponse, err error) CreateMultipartUpload(ctx context.Context, bucketName, path string, ec object.EncryptionKey, mimeType string, metadata api.ObjectUserMetadata) (api.MultipartCreateResponse, error) MultipartUpload(ctx context.Context, uploadID string) (resp api.MultipartUpload, _ error) MultipartUploads(ctx context.Context, bucketName, prefix, keyMarker, uploadIDMarker string, maxUploads int) (resp api.MultipartListUploadsResponse, _ error) @@ -189,7 +189,7 @@ type ( EphemeralAccountStore interface { Accounts(context.Context) ([]api.Account, error) SaveAccounts(context.Context, []api.Account) error - SetUncleanShutdown() error + SetUncleanShutdown(context.Context) error } MetricsStore interface { @@ -2022,7 +2022,7 @@ func (b *bus) webhookHandlerDelete(jc jape.Context) { if jc.Decode(&wh) != nil { return } - err := b.hooks.Delete(wh) + err := b.hooks.Delete(jc.Request.Context(), wh) if errors.Is(err, webhooks.ErrWebhookNotFound) { jc.Error(fmt.Errorf("webhook for URL %v and event %v.%v not found", wh.URL, wh.Module, wh.Event), http.StatusNotFound) return @@ -2044,7 +2044,7 @@ func (b *bus) webhookHandlerPost(jc jape.Context) { if jc.Decode(&req) != nil { return } - err := b.hooks.Register(webhooks.Webhook{ + err := b.hooks.Register(jc.Request.Context(), webhooks.Webhook{ Event: 
req.Event, Module: req.Module, URL: req.URL, @@ -2244,7 +2244,9 @@ func (b *bus) multipartHandlerCompletePOST(jc jape.Context) { if jc.Decode(&req) != nil { return } - resp, err := b.ms.CompleteMultipartUpload(jc.Request.Context(), req.Bucket, req.Path, req.UploadID, req.Parts) + resp, err := b.ms.CompleteMultipartUpload(jc.Request.Context(), req.Bucket, req.Path, req.UploadID, req.Parts, api.CompleteMultipartOptions{ + Metadata: req.Metadata, + }) if jc.Check("failed to complete multipart upload", err) != nil { return } @@ -2410,7 +2412,7 @@ func New(s Syncer, am *alerts.Manager, hm *webhooks.Manager, cm ChainManager, tp // mark the shutdown as unclean, this will be overwritten when/if the // accounts are saved on shutdown - if err := eas.SetUncleanShutdown(); err != nil { + if err := eas.SetUncleanShutdown(ctx); err != nil { return nil, fmt.Errorf("failed to mark account shutdown as unclean: %w", err) } return b, nil diff --git a/bus/client/multipart-upload.go b/bus/client/multipart-upload.go index 281019487..6fd06204c 100644 --- a/bus/client/multipart-upload.go +++ b/bus/client/multipart-upload.go @@ -33,10 +33,11 @@ func (c *Client) AddMultipartPart(ctx context.Context, bucket, path, contractSet } // CompleteMultipartUpload completes a multipart upload. 
-func (c *Client) CompleteMultipartUpload(ctx context.Context, bucket, path, uploadID string, parts []api.MultipartCompletedPart) (resp api.MultipartCompleteResponse, err error) { +func (c *Client) CompleteMultipartUpload(ctx context.Context, bucket, path, uploadID string, parts []api.MultipartCompletedPart, opts api.CompleteMultipartOptions) (resp api.MultipartCompleteResponse, err error) { err = c.c.WithContext(ctx).POST("/multipart/complete", api.MultipartCompleteRequest{ Bucket: bucket, Path: path, + Metadata: opts.Metadata, UploadID: uploadID, Parts: parts, }, &resp) diff --git a/bus/client/objects.go b/bus/client/objects.go index 23011a9ba..6a17691e2 100644 --- a/bus/client/objects.go +++ b/bus/client/objects.go @@ -112,7 +112,7 @@ func (c *Client) SearchObjects(ctx context.Context, bucket string, opts api.Sear } func (c *Client) renameObjects(ctx context.Context, bucket, from, to, mode string, force bool) (err error) { - err = c.c.POST("/objects/rename", api.ObjectsRenameRequest{ + err = c.c.WithContext(ctx).POST("/objects/rename", api.ObjectsRenameRequest{ Bucket: bucket, Force: force, From: from, diff --git a/bus/contractlocking_test.go b/bus/contractlocking_test.go index a00198cc9..120ca9ca2 100644 --- a/bus/contractlocking_test.go +++ b/bus/contractlocking_test.go @@ -154,7 +154,7 @@ func TestContractKeepalive(t *testing.T) { func TestContractRelease(t *testing.T) { locks := newContractLocks() - verify := func(fcid types.FileContractID, lockID uint64, lockedUntil time.Time, delta time.Duration) { + verify := func(fcid types.FileContractID, lockID uint64) { t.Helper() lock := locks.lockForContractID(fcid, false) if lock.heldByID != lockID { @@ -168,7 +168,7 @@ func TestContractRelease(t *testing.T) { if err != nil { t.Fatal(err) } - verify(fcid, lockID, time.Now().Add(time.Minute), 3*time.Second) + verify(fcid, lockID) // Acquire it again but release the contract within a second. 
var wg sync.WaitGroup @@ -185,14 +185,14 @@ func TestContractRelease(t *testing.T) { if err != nil { t.Fatal(err) } - verify(fcid, lockID, time.Now().Add(time.Minute), 3*time.Second) + verify(fcid, lockID) // Release one more time. Should decrease the references to 0 and reset // fields. if err := locks.Release(fcid, lockID); err != nil { t.Error(err) } - verify(fcid, 0, time.Time{}, 0) + verify(fcid, 0) // Try to release lock again. Is a no-op. if err := locks.Release(fcid, lockID); err != nil { diff --git a/cmd/renterd/config.go b/cmd/renterd/config.go index 47668ff94..f9008a4d5 100644 --- a/cmd/renterd/config.go +++ b/cmd/renterd/config.go @@ -41,6 +41,7 @@ func readInput(context string) string { } // wrapANSI wraps the output in ANSI escape codes if enabled. +// nolint: unparam func wrapANSI(prefix, output, suffix string) string { if enableANSI { return prefix + output + suffix diff --git a/cmd/renterd/main.go b/cmd/renterd/main.go index 98e075d92..093747796 100644 --- a/cmd/renterd/main.go +++ b/cmd/renterd/main.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/json" + "errors" "flag" "fmt" "log" @@ -24,6 +25,7 @@ import ( "go.sia.tech/renterd/bus" "go.sia.tech/renterd/config" "go.sia.tech/renterd/internal/node" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/s3" "go.sia.tech/renterd/stores" "go.sia.tech/renterd/worker" @@ -224,7 +226,7 @@ func parseEnvVar(s string, v interface{}) { func listenTCP(logger *zap.Logger, addr string) (net.Listener, error) { l, err := net.Listen("tcp", addr) - if err != nil && strings.Contains(err.Error(), "no such host") && strings.Contains(addr, "localhost") { + if utils.IsErr(err, errors.New("no such host")) && strings.Contains(addr, "localhost") { // fall back to 127.0.0.1 if 'localhost' doesn't work _, port, err := net.SplitHostPort(addr) if err != nil { diff --git a/go.mod b/go.mod index 43208cfc1..c317ab683 100644 --- a/go.mod +++ b/go.mod @@ -8,20 +8,20 @@ require ( github.com/google/go-cmp v0.6.0 
github.com/gotd/contrib v0.19.0 github.com/klauspost/reedsolomon v1.12.1 - github.com/minio/minio-go/v7 v7.0.68 + github.com/minio/minio-go/v7 v7.0.69 github.com/montanaflynn/stats v0.7.1 gitlab.com/NebulousLabs/encoding v0.0.0-20200604091946-456c3dc907fe go.sia.tech/core v0.2.1 go.sia.tech/coreutils v0.0.3 - go.sia.tech/gofakes3 v0.0.0-20231109151325-e0d47c10dce2 - go.sia.tech/hostd v1.0.2 + go.sia.tech/gofakes3 v0.0.1 + go.sia.tech/hostd v1.0.3 go.sia.tech/jape v0.11.2-0.20240124024603-93559895d640 go.sia.tech/mux v1.2.0 go.sia.tech/siad v1.5.10-0.20230228235644-3059c0b930ca go.sia.tech/web/renterd v0.49.0 go.uber.org/zap v1.27.0 - golang.org/x/crypto v0.20.0 - golang.org/x/term v0.17.0 + golang.org/x/crypto v0.21.0 + golang.org/x/term v0.18.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.4 gorm.io/driver/sqlite v1.5.5 @@ -75,7 +75,7 @@ require ( go.sia.tech/web v0.0.0-20231213145933-3f175a86abff // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/net v0.21.0 // indirect - golang.org/x/sys v0.17.0 // indirect + golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.16.1 // indirect diff --git a/go.sum b/go.sum index 9b7f15042..519a9325e 100644 --- a/go.sum +++ b/go.sum @@ -135,8 +135,8 @@ github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34= github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM= -github.com/minio/minio-go/v7 v7.0.68 h1:hTqSIfLlpXaKuNy4baAp4Jjy2sqZEN9hRxD0M4aOfrQ= -github.com/minio/minio-go/v7 v7.0.68/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ= +github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0= +github.com/minio/minio-go/v7 v7.0.69/go.mod 
h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ= github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM= github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= @@ -240,10 +240,10 @@ go.sia.tech/core v0.2.1 h1:CqmMd+T5rAhC+Py3NxfvGtvsj/GgwIqQHHVrdts/LqY= go.sia.tech/core v0.2.1/go.mod h1:3EoY+rR78w1/uGoXXVqcYdwSjSJKuEMI5bL7WROA27Q= go.sia.tech/coreutils v0.0.3 h1:ZxuzovRpQMvfy/pCOV4om1cPF6sE15GyJyK36kIrF1Y= go.sia.tech/coreutils v0.0.3/go.mod h1:UBFc77wXiE//eyilO5HLOncIEj7F69j0Nv2OkFujtP0= -go.sia.tech/gofakes3 v0.0.0-20231109151325-e0d47c10dce2 h1:ulzfJNjxN5DjXHClkW2pTiDk+eJ+0NQhX87lFDZ03t0= -go.sia.tech/gofakes3 v0.0.0-20231109151325-e0d47c10dce2/go.mod h1:PlsiVCn6+wssrR7bsOIlZm0DahsVrDydrlbjY4F14sg= -go.sia.tech/hostd v1.0.2 h1:GjzNIAlwg3/dViF6258Xn5DI3+otQLRqmkoPDugP+9Y= -go.sia.tech/hostd v1.0.2/go.mod h1:zGw+AGVmazAp4ydvo7bZLNKTy1J51RI6Mp/oxRtYT6c= +go.sia.tech/gofakes3 v0.0.1 h1:8vtYH/B17NJ4GXLWiONfhwBrrmtJtYiofnO3PfjU298= +go.sia.tech/gofakes3 v0.0.1/go.mod h1:PlsiVCn6+wssrR7bsOIlZm0DahsVrDydrlbjY4F14sg= +go.sia.tech/hostd v1.0.3 h1:BCaFg6DGf33JEH/5DqFj6cnaz3EbiyjpbhfSj/Lo6e8= +go.sia.tech/hostd v1.0.3/go.mod h1:R+01UddrgmAUcdBkEO8VcnYqPX/mod45DC5m/v/crzE= go.sia.tech/jape v0.11.2-0.20240124024603-93559895d640 h1:mSaJ622P7T/M97dAK8iPV+IRIC9M5vV28NHeceoWO3M= go.sia.tech/jape v0.11.2-0.20240124024603-93559895d640/go.mod h1:4QqmBB+t3W7cNplXPj++ZqpoUb2PeiS66RLpXmEGap4= go.sia.tech/mux v1.2.0 h1:ofa1Us9mdymBbGMY2XH/lSpY8itFsKIo/Aq8zwe+GHU= @@ -274,8 +274,8 @@ golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto 
v0.0.0-20220507011949-2cf3adece122/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.20.0 h1:jmAMJJZXr5KiCw05dfYK9QnqaqKLYXijU23lsEdcQqg= -golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -328,16 +328,16 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210421210424-b80969c67360/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= -golang.org/x/term v0.17.0 
h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= +golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/internal/node/node.go b/internal/node/node.go index e94cfbb4d..d105cbfb2 100644 --- a/internal/node/node.go +++ b/internal/node/node.go @@ -137,11 +137,14 @@ func NewBus(cfg BusConfig, dir string, seed types.PrivateKey, l *zap.Logger) (ht cancelSubscribe := make(chan struct{}) go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + subscribeErr := cs.ConsensusSetSubscribe(sqlStore, ccid, cancelSubscribe) if errors.Is(subscribeErr, modules.ErrInvalidConsensusChangeID) { l.Warn("Invalid consensus change ID detected - resyncing consensus") // Reset the consensus state within the database and rescan. 
- if err := sqlStore.ResetConsensusSubscription(); err != nil { + if err := sqlStore.ResetConsensusSubscription(ctx); err != nil { l.Fatal(fmt.Sprintf("Failed to reset consensus subscription of SQLStore: %v", err)) return } @@ -177,11 +180,8 @@ func NewBus(cfg BusConfig, dir string, seed types.PrivateKey, l *zap.Logger) (ht } shutdownFn := func(ctx context.Context) error { + close(cancelSubscribe) return errors.Join( - func() error { - close(cancelSubscribe) - return nil - }(), g.Close(), cs.Close(), tp.Close(), diff --git a/internal/test/e2e/cluster_test.go b/internal/test/e2e/cluster_test.go index 5ca7141d5..2346f7019 100644 --- a/internal/test/e2e/cluster_test.go +++ b/internal/test/e2e/cluster_test.go @@ -1098,7 +1098,7 @@ func TestParallelUpload(t *testing.T) { w := cluster.Worker tt := cluster.tt - upload := func() error { + upload := func() { t.Helper() // prepare some data - make sure it's more than one sector data := make([]byte, rhpv2.SectorSize) @@ -1107,7 +1107,6 @@ func TestParallelUpload(t *testing.T) { // upload the data path := fmt.Sprintf("/dir/data_%v", hex.EncodeToString(data[:16])) tt.OKAll(w.UploadObject(context.Background(), bytes.NewReader(data), api.DefaultBucketName, path, api.UploadObjectOptions{})) - return nil } // Upload in parallel @@ -1116,10 +1115,7 @@ func TestParallelUpload(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - if err := upload(); err != nil { - t.Error(err) - return - } + upload() }() } wg.Wait() @@ -2107,7 +2103,7 @@ func TestMultipartUploads(t *testing.T) { PartNumber: 3, ETag: etag3, }, - }) + }, api.CompleteMultipartOptions{}) tt.OK(err) if ui.ETag == "" { t.Fatal("unexpected response:", ui) @@ -2435,7 +2431,7 @@ func TestMultipartUploadWrappedByPartialSlabs(t *testing.T) { PartNumber: 3, ETag: resp3.ETag, }, - })) + }, api.CompleteMultipartOptions{})) // download the object and verify its integrity dst := new(bytes.Buffer) diff --git a/internal/test/e2e/metadata_test.go b/internal/test/e2e/metadata_test.go 
index d11f6ba4e..af924f847 100644 --- a/internal/test/e2e/metadata_test.go +++ b/internal/test/e2e/metadata_test.go @@ -55,6 +55,8 @@ func TestObjectMetadata(t *testing.T) { } if !reflect.DeepEqual(gor.Metadata, opts.Metadata) { t.Fatal("metadata mismatch", gor.Metadata) + } else if gor.Etag == "" { + t.Fatal("missing etag") } // perform a HEAD request and assert the headers are all present @@ -63,6 +65,7 @@ func TestObjectMetadata(t *testing.T) { t.Fatal(err) } else if !reflect.DeepEqual(hor, &api.HeadObjectResponse{ ContentType: or.Object.ContentType(), + Etag: gor.Etag, LastModified: or.Object.LastModified(), Range: &api.DownloadRange{Offset: 1, Length: 1, Size: int64(len(data))}, Size: int64(len(data)), diff --git a/internal/test/e2e/s3_test.go b/internal/test/e2e/s3_test.go index b25e11871..6c13e8426 100644 --- a/internal/test/e2e/s3_test.go +++ b/internal/test/e2e/s3_test.go @@ -3,6 +3,8 @@ package e2e import ( "bytes" "context" + "crypto/md5" + "encoding/hex" "errors" "fmt" "io" @@ -72,8 +74,12 @@ func TestS3Basic(t *testing.T) { // add object to the bucket data := frand.Bytes(10) + etag := md5.Sum(data) uploadInfo, err := s3.PutObject(context.Background(), bucket, objPath, bytes.NewReader(data), int64(len(data)), minio.PutObjectOptions{}) tt.OK(err) + if uploadInfo.ETag != hex.EncodeToString(etag[:]) { + t.Fatalf("expected ETag %v, got %v", hex.EncodeToString(etag[:]), uploadInfo.ETag) + } busObject, err := cluster.Bus.Object(context.Background(), bucket, objPath, api.GetObjectOptions{}) tt.OK(err) if busObject.Object == nil { @@ -92,6 +98,10 @@ func TestS3Basic(t *testing.T) { t.Fatal(err) } else if !bytes.Equal(b, data) { t.Fatal("data mismatch") + } else if info, err := obj.Stat(); err != nil { + t.Fatal(err) + } else if info.ETag != uploadInfo.ETag { + t.Fatal("unexpected ETag:", info.ETag, uploadInfo.ETag) } // stat object @@ -99,6 +109,8 @@ func TestS3Basic(t *testing.T) { tt.OK(err) if info.Size != int64(len(data)) { t.Fatal("size mismatch") + } else 
if info.ETag != uploadInfo.ETag { + t.Fatal("unexpected ETag:", info.ETag) } // add another bucket @@ -254,6 +266,38 @@ func TestS3ObjectMetadata(t *testing.T) { head, err = s3.StatObject(context.Background(), api.DefaultBucketName, t.Name(), minio.StatObjectOptions{}) tt.OK(err) assertMetadata(metadata, head.UserMetadata) + + // upload a file using multipart upload + core := cluster.S3Core + uid, err := core.NewMultipartUpload(context.Background(), api.DefaultBucketName, "multi", minio.PutObjectOptions{ + UserMetadata: map[string]string{ + "New": "1", + }, + }) + tt.OK(err) + data := frand.Bytes(3) + + part, err := core.PutObjectPart(context.Background(), api.DefaultBucketName, "foo", uid, 1, bytes.NewReader(data), int64(len(data)), minio.PutObjectPartOptions{}) + tt.OK(err) + _, err = core.CompleteMultipartUpload(context.Background(), api.DefaultBucketName, "multi", uid, []minio.CompletePart{ + { + PartNumber: part.PartNumber, + ETag: part.ETag, + }, + }, minio.PutObjectOptions{ + UserMetadata: map[string]string{ + "Complete": "2", + }, + }) + tt.OK(err) + + // check metadata + head, err = s3.StatObject(context.Background(), api.DefaultBucketName, "multi", minio.StatObjectOptions{}) + tt.OK(err) + assertMetadata(map[string]string{ + "New": "1", + "Complete": "2", + }, head.UserMetadata) } func TestS3Authentication(t *testing.T) { @@ -455,6 +499,30 @@ func TestS3List(t *testing.T) { if !cmp.Equal(test.want, got) { t.Errorf("test %d: unexpected response, want %v got %v", i, test.want, got) } + for _, obj := range result.Contents { + if obj.ETag == "" { + t.Fatal("expected non-empty ETag") + } else if obj.LastModified.IsZero() { + t.Fatal("expected non-zero LastModified") + } + } + } + + // use pagination to loop over objects one-by-one + marker := "" + expectedOrder := []string{"a/", "a/a/a", "a/b", "ab", "b", "c/a", "d", "y/", "y/y/y/y"} + hasMore := true + for i := 0; hasMore; i++ { + result, err := core.ListObjectsV2("bucket", "", "", marker, "", 1) + if err != 
nil { + t.Fatal(err) + } else if len(result.Contents) != 1 { + t.Fatalf("unexpected number of objects, %d != 1", len(result.Contents)) + } else if result.Contents[0].Key != expectedOrder[i] { + t.Errorf("unexpected object, %s != %s", result.Contents[0].Key, expectedOrder[i]) + } + marker = result.NextContinuationToken + hasMore = result.IsTruncated } } @@ -548,12 +616,28 @@ func TestS3MultipartUploads(t *testing.T) { } // Download object + expectedData := []byte("helloworld!") downloadedObj, err := s3.GetObject(context.Background(), "multipart", "foo", minio.GetObjectOptions{}) tt.OK(err) if data, err := io.ReadAll(downloadedObj); err != nil { t.Fatal(err) - } else if !bytes.Equal(data, []byte("helloworld!")) { + } else if !bytes.Equal(data, expectedData) { t.Fatal("unexpected data:", string(data)) + } else if info, err := downloadedObj.Stat(); err != nil { + t.Fatal(err) + } else if info.ETag != ui.ETag { + t.Fatal("unexpected ETag:", info.ETag) + } else if info.Size != int64(len(expectedData)) { + t.Fatal("unexpected size:", info.Size) + } + + // Stat object + if info, err := s3.StatObject(context.Background(), "multipart", "foo", minio.StatObjectOptions{}); err != nil { + t.Fatal(err) + } else if info.ETag != ui.ETag { + t.Fatal("unexpected ETag:", info.ETag) + } else if info.Size != int64(len(expectedData)) { + t.Fatal("unexpected size:", info.Size) } // Download again with range request. diff --git a/internal/utils/errors.go b/internal/utils/errors.go new file mode 100644 index 000000000..b884cde70 --- /dev/null +++ b/internal/utils/errors.go @@ -0,0 +1,20 @@ +package utils + +import ( + "errors" + "strings" +) + +// IsErr can be used to compare an error to a target and also works when used on +// errors that haven't been wrapped since it will fall back to a string +// comparison. Useful to check errors returned over the network. 
+func IsErr(err error, target error) bool { + if (err == nil) != (target == nil) { + return false + } else if errors.Is(err, target) { + return true + } + // TODO: we can get rid of the lower casing once siad is gone and + // renterd/hostd use the same error messages + return strings.Contains(strings.ToLower(err.Error()), strings.ToLower(target.Error())) +} diff --git a/object/object.go b/object/object.go index 965ebce2a..e8243fac1 100644 --- a/object/object.go +++ b/object/object.go @@ -3,7 +3,6 @@ package object import ( "bytes" "crypto/cipher" - "crypto/md5" "encoding/binary" "encoding/hex" "fmt" @@ -146,22 +145,6 @@ func (o Object) Contracts() map[types.PublicKey]map[types.FileContractID]struct{ return usedContracts } -func (o *Object) ComputeETag() string { - // calculate the eTag using the precomputed sector roots to avoid having to - // hash the entire object again. - h := md5.New() - b := make([]byte, 8) - for _, slab := range o.Slabs { - binary.LittleEndian.PutUint32(b[:4], slab.Offset) - binary.LittleEndian.PutUint32(b[4:], slab.Length) - h.Write(b) - for _, shard := range slab.Shards { - h.Write(shard.Root[:]) - } - } - return string(hex.EncodeToString(h.Sum(nil))) -} - // TotalSize returns the total size of the object. 
func (o Object) TotalSize() int64 { var n int64 diff --git a/s3/authentication.go b/s3/authentication.go index 9d5da4f1a..67017356b 100644 --- a/s3/authentication.go +++ b/s3/authentication.go @@ -5,11 +5,11 @@ import ( "fmt" "io" "net/http" - "strings" "go.sia.tech/gofakes3" "go.sia.tech/gofakes3/signature" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" ) var ( @@ -89,7 +89,7 @@ func newAuthenticatedBackend(b *s3) *authenticatedBackend { func (b *authenticatedBackend) applyBucketPolicy(ctx context.Context, bucketName string, p *permissions) error { bucket, err := b.backend.b.Bucket(ctx, bucketName) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { return gofakes3.BucketNotFound(bucketName) } else if err != nil { return gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -264,9 +264,9 @@ func (b *authenticatedBackend) AbortMultipartUpload(ctx context.Context, bucket, return b.backend.AbortMultipartUpload(ctx, bucket, object, id) } -func (b *authenticatedBackend) CompleteMultipartUpload(ctx context.Context, bucket, object string, id gofakes3.UploadID, input *gofakes3.CompleteMultipartUploadRequest) (resp *gofakes3.CompleteMultipartUploadResult, err error) { +func (b *authenticatedBackend) CompleteMultipartUpload(ctx context.Context, bucket, object string, id gofakes3.UploadID, meta map[string]string, input *gofakes3.CompleteMultipartUploadRequest) (resp *gofakes3.CompleteMultipartUploadResult, err error) { if !b.permsFromCtx(ctx, bucket).CompleteMultipartUpload { return nil, gofakes3.ErrAccessDenied } - return b.backend.CompleteMultipartUpload(ctx, bucket, object, id, input) + return b.backend.CompleteMultipartUpload(ctx, bucket, object, id, meta, input) } diff --git a/s3/backend.go b/s3/backend.go index c05a3ec98..bb6e3ff7c 100644 --- a/s3/backend.go +++ b/s3/backend.go @@ -3,12 +3,14 @@ package s3 import ( "bytes" "context" + "encoding/hex" "fmt" "io" "strings" 
"go.sia.tech/gofakes3" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" "go.uber.org/zap" ) @@ -87,6 +89,11 @@ func (s *s3) ListBucket(ctx context.Context, bucketName string, prefix *gofakes3 page.MaxKeys = maxKeysDefault } + // Adjust marker + if page.HasMarker { + page.Marker = "/" + page.Marker + } + var objects []api.ObjectMetadata var err error response := gofakes3.NewObjectList() @@ -108,7 +115,7 @@ func (s *s3) ListBucket(ctx context.Context, bucketName string, prefix *gofakes3 } var res api.ObjectsResponse res, err = s.b.Object(ctx, bucketName, path, opts) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { return nil, gofakes3.BucketNotFound(bucketName) } else if err != nil { return nil, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -128,7 +135,7 @@ func (s *s3) ListBucket(ctx context.Context, bucketName string, prefix *gofakes3 var res api.ObjectsListResponse res, err = s.b.ListObjects(ctx, bucketName, opts) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { return nil, gofakes3.BucketNotFound(bucketName) } else if err != nil { return nil, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -141,6 +148,10 @@ func (s *s3) ListBucket(ctx context.Context, bucketName string, prefix *gofakes3 return nil, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) } + // Remove the leading slash from the marker since we also do that for the + // name of each object + response.NextMarker = strings.TrimPrefix(response.NextMarker, "/") + // Loop over the entries and add them to the response. for _, object := range objects { key := strings.TrimPrefix(object.Name, "/") @@ -168,7 +179,7 @@ func (s *s3) ListBucket(ctx context.Context, bucketName string, prefix *gofakes3 // gofakes3.ErrBucketAlreadyExists MUST be returned. 
func (s *s3) CreateBucket(ctx context.Context, name string) error { err := s.b.CreateBucket(ctx, name, api.CreateBucketOptions{}) - if err != nil && strings.Contains(err.Error(), api.ErrBucketExists.Error()) { + if utils.IsErr(err, api.ErrBucketExists) { return gofakes3.ErrBucketAlreadyExists } else if err != nil { return gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -182,7 +193,7 @@ func (s *s3) CreateBucket(ctx context.Context, name string) error { // TODO: backend could be improved to allow for checking specific dir in root. func (s *s3) BucketExists(ctx context.Context, name string) (bool, error) { _, err := s.b.Bucket(ctx, name) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { return false, nil } else if err != nil { return false, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -202,9 +213,9 @@ func (s *s3) BucketExists(ctx context.Context, name string) (bool, error) { // atomically checking whether a bucket is empty. 
func (s *s3) DeleteBucket(ctx context.Context, name string) error { err := s.b.DeleteBucket(ctx, name) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotEmpty.Error()) { + if utils.IsErr(err, api.ErrBucketNotEmpty) { return gofakes3.ErrBucketNotEmpty - } else if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + } else if utils.IsErr(err, api.ErrBucketNotFound) { return gofakes3.BucketNotFound(name) } else if err != nil { return gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -242,9 +253,9 @@ func (s *s3) GetObject(ctx context.Context, bucketName, objectName string, range } res, err := s.w.GetObject(ctx, bucketName, objectName, opts) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { return nil, gofakes3.BucketNotFound(bucketName) - } else if err != nil && strings.Contains(err.Error(), api.ErrObjectNotFound.Error()) { + } else if utils.IsErr(err, api.ErrObjectNotFound) { return nil, gofakes3.KeyNotFound(objectName) } else if err != nil { return nil, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -268,7 +279,14 @@ func (s *s3) GetObject(ctx context.Context, bucketName, objectName string, range res.Metadata["Content-Type"] = res.ContentType res.Metadata["Last-Modified"] = res.LastModified + // etag to bytes + etag, err := hex.DecodeString(res.Etag) + if err != nil { + return nil, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) + } + return &gofakes3.Object{ + Hash: etag, Name: gofakes3.URLEncode(objectName), Metadata: res.Metadata, Size: res.Size, @@ -287,11 +305,10 @@ func (s *s3) GetObject(ctx context.Context, bucketName, objectName string, range // HeadObject should return a NotFound() error if the object does not // exist. 
func (s *s3) HeadObject(ctx context.Context, bucketName, objectName string) (*gofakes3.Object, error) { - res, err := s.b.Object(ctx, bucketName, objectName, api.GetObjectOptions{ - IgnoreDelim: true, - OnlyMetadata: true, + res, err := s.w.HeadObject(ctx, bucketName, objectName, api.HeadObjectOptions{ + IgnoreDelim: true, }) - if err != nil && strings.Contains(err.Error(), api.ErrObjectNotFound.Error()) { + if utils.IsErr(err, api.ErrObjectNotFound) { return nil, gofakes3.KeyNotFound(objectName) } else if err != nil { return nil, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -299,18 +316,25 @@ func (s *s3) HeadObject(ctx context.Context, bucketName, objectName string) (*go // set user metadata metadata := make(map[string]string) - for k, v := range res.Object.Metadata { + for k, v := range res.Metadata { metadata[amazonMetadataPrefix+k] = v } // decorate metadata - metadata["Content-Type"] = res.Object.MimeType - metadata["Last-Modified"] = res.Object.LastModified() + metadata["Content-Type"] = res.ContentType + metadata["Last-Modified"] = res.LastModified + + // etag to bytes + hash, err := hex.DecodeString(res.Etag) + if err != nil { + return nil, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) + } return &gofakes3.Object{ + Hash: hash, Name: gofakes3.URLEncode(objectName), Metadata: metadata, - Size: res.Object.Size, + Size: res.Size, Contents: io.NopCloser(bytes.NewReader(nil)), }, nil } @@ -332,9 +356,9 @@ func (s *s3) HeadObject(ctx context.Context, bucketName, objectName string) (*go // isn't a null version, Amazon S3 does not remove any objects. 
func (s *s3) DeleteObject(ctx context.Context, bucketName, objectName string) (gofakes3.ObjectDeleteResult, error) { err := s.b.DeleteObject(ctx, bucketName, objectName, api.DeleteObjectOptions{}) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { return gofakes3.ObjectDeleteResult{}, gofakes3.BucketNotFound(bucketName) - } else if err != nil && !strings.Contains(err.Error(), api.ErrObjectNotFound.Error()) { + } else if err != nil && !utils.IsErr(err, api.ErrObjectNotFound) { return gofakes3.ObjectDeleteResult{}, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) } @@ -357,7 +381,7 @@ func (s *s3) PutObject(ctx context.Context, bucketName, key string, meta map[str } ur, err := s.w.UploadObject(ctx, input, bucketName, key, opts) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { return gofakes3.PutObjectResult{}, gofakes3.BucketNotFound(bucketName) } else if err != nil { return gofakes3.PutObjectResult{}, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) @@ -373,7 +397,7 @@ func (s *s3) DeleteMulti(ctx context.Context, bucketName string, objects ...stri var res gofakes3.MultiDeleteResult for _, objectName := range objects { err := s.b.DeleteObject(ctx, bucketName, objectName, api.DeleteObjectOptions{}) - if err != nil && !strings.Contains(err.Error(), api.ErrObjectNotFound.Error()) { + if err != nil && !utils.IsErr(err, api.ErrObjectNotFound) { res.Error = append(res.Error, gofakes3.ErrorResult{ Key: objectName, Code: gofakes3.ErrInternal, @@ -502,7 +526,8 @@ func (s *s3) AbortMultipartUpload(ctx context.Context, bucket, object string, id return nil } -func (s *s3) CompleteMultipartUpload(ctx context.Context, bucket, object string, id gofakes3.UploadID, input *gofakes3.CompleteMultipartUploadRequest) (*gofakes3.CompleteMultipartUploadResult, error) { +func (s *s3) CompleteMultipartUpload(ctx context.Context,
bucket, object string, id gofakes3.UploadID, meta map[string]string, input *gofakes3.CompleteMultipartUploadRequest) (*gofakes3.CompleteMultipartUploadResult, error) { + convertToSiaMetadataHeaders(meta) var parts []api.MultipartCompletedPart for _, part := range input.Parts { parts = append(parts, api.MultipartCompletedPart{ @@ -510,7 +535,9 @@ func (s *s3) CompleteMultipartUpload(ctx context.Context, bucket, object string, PartNumber: part.PartNumber, }) } - resp, err := s.b.CompleteMultipartUpload(ctx, bucket, "/"+object, string(id), parts) + resp, err := s.b.CompleteMultipartUpload(ctx, bucket, "/"+object, string(id), parts, api.CompleteMultipartOptions{ + Metadata: api.ExtractObjectUserMetadataFrom(meta), + }) if err != nil { return nil, gofakes3.ErrorMessage(gofakes3.ErrInternal, err.Error()) } diff --git a/s3/s3.go b/s3/s3.go index dc7ac664b..0ac1dbd49 100644 --- a/s3/s3.go +++ b/s3/s3.go @@ -36,7 +36,7 @@ type bus interface { Object(ctx context.Context, bucket, path string, opts api.GetObjectOptions) (res api.ObjectsResponse, err error) AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) (err error) - CompleteMultipartUpload(ctx context.Context, bucket, path, uploadID string, parts []api.MultipartCompletedPart) (_ api.MultipartCompleteResponse, err error) + CompleteMultipartUpload(ctx context.Context, bucket, path, uploadID string, parts []api.MultipartCompletedPart, opts api.CompleteMultipartOptions) (_ api.MultipartCompleteResponse, err error) CreateMultipartUpload(ctx context.Context, bucket, path string, opts api.CreateMultipartOptions) (api.MultipartCreateResponse, error) MultipartUploads(ctx context.Context, bucket, prefix, keyMarker, uploadIDMarker string, maxUploads int) (resp api.MultipartListUploadsResponse, _ error) MultipartUploadParts(ctx context.Context, bucket, object string, uploadID string, marker int, limit int64) (resp api.MultipartListPartsResponse, _ error) @@ -48,6 +48,7 @@ type bus interface { type worker 
interface { GetObject(ctx context.Context, bucket, path string, opts api.DownloadObjectOptions) (*api.GetObjectResponse, error) + HeadObject(ctx context.Context, bucket, path string, opts api.HeadObjectOptions) (*api.HeadObjectResponse, error) UploadObject(ctx context.Context, r io.Reader, bucket, path string, opts api.UploadObjectOptions) (*api.UploadObjectResponse, error) UploadMultipartUploadPart(ctx context.Context, r io.Reader, bucket, path, uploadID string, partNumber int, opts api.UploadMultipartUploadPartOptions) (*api.UploadMultipartUploadPartResponse, error) } diff --git a/stores/accounts.go b/stores/accounts.go index d519df9dd..69f4aeff8 100644 --- a/stores/accounts.go +++ b/stores/accounts.go @@ -55,7 +55,7 @@ func (a dbAccount) convert() api.Account { // Accounts returns all accounts from the db. func (s *SQLStore) Accounts(ctx context.Context) ([]api.Account, error) { var dbAccounts []dbAccount - if err := s.db.Find(&dbAccounts).Error; err != nil { + if err := s.db.WithContext(ctx).Find(&dbAccounts).Error; err != nil { return nil, err } accounts := make([]api.Account, len(dbAccounts)) @@ -69,8 +69,10 @@ func (s *SQLStore) Accounts(ctx context.Context) ([]api.Account, error) { // also sets the 'requires_sync' flag. That way, the autopilot will know to sync // all accounts after an unclean shutdown and the bus will know not to apply // drift. -func (s *SQLStore) SetUncleanShutdown() error { - return s.db.Model(&dbAccount{}). +func (s *SQLStore) SetUncleanShutdown(ctx context.Context) error { + return s.db. + WithContext(ctx). + Model(&dbAccount{}). Where("TRUE"). 
Updates(map[string]interface{}{ "clean_shutdown": false, @@ -95,7 +97,7 @@ func (s *SQLStore) SaveAccounts(ctx context.Context, accounts []api.Account) err RequiresSync: acc.RequiresSync, } } - return s.db.Clauses(clause.OnConflict{ + return s.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "account_id"}}, UpdateAll: true, }).Create(&dbAccounts).Error diff --git a/stores/autopilot.go b/stores/autopilot.go index 6dc88a692..5a5c5ed2d 100644 --- a/stores/autopilot.go +++ b/stores/autopilot.go @@ -34,6 +34,7 @@ func (c dbAutopilot) convert() api.Autopilot { func (s *SQLStore) Autopilots(ctx context.Context) ([]api.Autopilot, error) { var entities []dbAutopilot err := s.db. + WithContext(ctx). Model(&dbAutopilot{}). Find(&entities). Error @@ -51,6 +52,7 @@ func (s *SQLStore) Autopilots(ctx context.Context) ([]api.Autopilot, error) { func (s *SQLStore) Autopilot(ctx context.Context, id string) (api.Autopilot, error) { var entity dbAutopilot err := s.db. + WithContext(ctx). Model(&dbAutopilot{}). Where("identifier = ?", id). First(&entity). @@ -73,10 +75,12 @@ func (s *SQLStore) UpdateAutopilot(ctx context.Context, ap api.Autopilot) error } // upsert - return s.db.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "identifier"}}, - UpdateAll: true, - }).Create(&dbAutopilot{ + return s.db. + WithContext(ctx). + Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "identifier"}}, + UpdateAll: true, + }).Create(&dbAutopilot{ Identifier: ap.ID, Config: ap.Config, CurrentPeriod: ap.CurrentPeriod, diff --git a/stores/hostdb.go b/stores/hostdb.go index 37aa18ee8..fd23abf4a 100644 --- a/stores/hostdb.go +++ b/stores/hostdb.go @@ -87,12 +87,6 @@ type ( Hosts []dbHost `gorm:"many2many:host_allowlist_entry_hosts;constraint:OnDelete:CASCADE"` } - // dbHostAllowlistEntryHost is a join table between dbAllowlistEntry and dbHost. 
- dbHostAllowlistEntryHost struct { - DBAllowlistEntryID uint `gorm:"primaryKey"` - DBHostID uint `gorm:"primaryKey;index"` - } - // dbBlocklistEntry defines a table that stores the host blocklist. dbBlocklistEntry struct { Model @@ -100,12 +94,6 @@ type ( Hosts []dbHost `gorm:"many2many:host_blocklist_entry_hosts;constraint:OnDelete:CASCADE"` } - // dbHostBlocklistEntryHost is a join table between dbBlocklistEntry and dbHost. - dbHostBlocklistEntryHost struct { - DBBlocklistEntryID uint `gorm:"primaryKey"` - DBHostID uint `gorm:"primaryKey;index"` - } - dbConsensusInfo struct { Model CCID []byte @@ -278,15 +266,9 @@ func (dbHost) TableName() string { return "hosts" } // TableName implements the gorm.Tabler interface. func (dbAllowlistEntry) TableName() string { return "host_allowlist_entries" } -// TableName implements the gorm.Tabler interface. -func (dbHostAllowlistEntryHost) TableName() string { return "host_allowlist_entry_hosts" } - // TableName implements the gorm.Tabler interface. func (dbBlocklistEntry) TableName() string { return "host_blocklist_entries" } -// TableName implements the gorm.Tabler interface. -func (dbHostBlocklistEntryHost) TableName() string { return "host_blocklist_entry_hosts" } - // convert converts a host into a hostdb.Host. func (h dbHost) convert() hostdb.Host { var lastScan time.Time @@ -427,6 +409,7 @@ func (ss *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (hostdb.H var h dbHost tx := ss.db. + WithContext(ctx). Where(&dbHost{PublicKey: publicKey(hostKey)}). Preload("Allowlist"). Preload("Blocklist"). @@ -456,6 +439,7 @@ func (ss *SQLStore) HostsForScanning(ctx context.Context, maxLastScan time.Time, var hostAddresses []hostdb.HostAddress err := ss.db. + WithContext(ctx). Model(&dbHost{}). Where("last_scan < ?", maxLastScan.UnixNano()). Offset(offset). 
@@ -546,6 +530,7 @@ func (ss *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures ui // fetch all hosts outside of the transaction var hosts []dbHost if err := ss.db. + WithContext(ctx). Model(&dbHost{}). Where("recent_downtime >= ? AND recent_scan_failures >= ?", maxDowntime, minRecentFailures). Find(&hosts). @@ -561,7 +546,7 @@ func (ss *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures ui // remove every host one by one var errs []error for _, h := range hosts { - if err := ss.retryTransaction(func(tx *gorm.DB) error { + if err := ss.retryTransaction(ctx, func(tx *gorm.DB) error { // fetch host contracts hcs, err := contractsForHost(tx, h) if err != nil { @@ -575,7 +560,7 @@ func (ss *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures ui } // archive host contracts - if err := archiveContracts(ctx, tx, hcs, toArchive); err != nil { + if err := archiveContracts(tx, hcs, toArchive); err != nil { return err } @@ -609,7 +594,7 @@ func (ss *SQLStore) UpdateHostAllowlistEntries(ctx context.Context, add, remove // clear allowlist if clear { - return ss.retryTransaction(func(tx *gorm.DB) error { + return ss.retryTransaction(ctx, func(tx *gorm.DB) error { return tx.Where("TRUE").Delete(&dbAllowlistEntry{}).Error }) } @@ -624,7 +609,7 @@ func (ss *SQLStore) UpdateHostAllowlistEntries(ctx context.Context, add, remove toDelete[i] = publicKey(entry) } - return ss.retryTransaction(func(tx *gorm.DB) error { + return ss.retryTransaction(ctx, func(tx *gorm.DB) error { if len(toInsert) > 0 { if err := tx.Create(&toInsert).Error; err != nil { return err @@ -648,7 +633,7 @@ func (ss *SQLStore) UpdateHostBlocklistEntries(ctx context.Context, add, remove // clear blocklist if clear { - return ss.retryTransaction(func(tx *gorm.DB) error { + return ss.retryTransaction(ctx, func(tx *gorm.DB) error { return tx.Where("TRUE").Delete(&dbBlocklistEntry{}).Error }) } @@ -658,7 +643,7 @@ func (ss *SQLStore) UpdateHostBlocklistEntries(ctx 
context.Context, add, remove toInsert = append(toInsert, dbBlocklistEntry{Entry: entry}) } - return ss.retryTransaction(func(tx *gorm.DB) error { + return ss.retryTransaction(ctx, func(tx *gorm.DB) error { if len(toInsert) > 0 { if err := tx.Create(&toInsert).Error; err != nil { return err @@ -676,6 +661,7 @@ func (ss *SQLStore) UpdateHostBlocklistEntries(ctx context.Context, add, remove func (ss *SQLStore) HostAllowlist(ctx context.Context) (allowlist []types.PublicKey, err error) { var pubkeys []publicKey err = ss.db. + WithContext(ctx). Model(&dbAllowlistEntry{}). Pluck("entry", &pubkeys). Error @@ -688,6 +674,7 @@ func (ss *SQLStore) HostAllowlist(ctx context.Context) (allowlist []types.Public func (ss *SQLStore) HostBlocklist(ctx context.Context) (blocklist []string, err error) { err = ss.db. + WithContext(ctx). Model(&dbBlocklistEntry{}). Pluck("entry", &blocklist). Error @@ -719,7 +706,7 @@ func (ss *SQLStore) RecordHostScans(ctx context.Context, scans []hostdb.HostScan end = len(hks) } var batchHosts []dbHost - if err := ss.db.Where("public_key IN (?)", hks[i:end]). + if err := ss.db.WithContext(ctx).Where("public_key IN (?)", hks[i:end]). Find(&batchHosts).Error; err != nil { return err } @@ -732,7 +719,7 @@ func (ss *SQLStore) RecordHostScans(ctx context.Context, scans []hostdb.HostScan // Write the interactions and update to the hosts atomically within a single // transaction. - return ss.retryTransaction(func(tx *gorm.DB) error { + return ss.retryTransaction(ctx, func(tx *gorm.DB) error { // Handle scans for _, scan := range scans { host, exists := hostMap[publicKey(scan.HostKey)] @@ -841,7 +828,7 @@ func (ss *SQLStore) RecordPriceTables(ctx context.Context, priceTableUpdate []ho end = len(hks) } var batchHosts []dbHost - if err := ss.db.Where("public_key IN (?)", hks[i:end]). + if err := ss.db.WithContext(ctx).Where("public_key IN (?)", hks[i:end]). 
Find(&batchHosts).Error; err != nil { return err } @@ -854,7 +841,7 @@ func (ss *SQLStore) RecordPriceTables(ctx context.Context, priceTableUpdate []ho // Write the interactions and update to the hosts atomically within a single // transaction. - return ss.retryTransaction(func(tx *gorm.DB) error { + return ss.retryTransaction(ctx, func(tx *gorm.DB) error { // Handle price table updates for _, ptu := range priceTableUpdate { host, exists := hostMap[publicKey(ptu.HostKey)] @@ -1086,7 +1073,7 @@ func updateBlocklist(tx *gorm.DB, hk types.PublicKey, allowlist []dbAllowlistEnt } func (s *SQLStore) ResetLostSectors(ctx context.Context, hk types.PublicKey) error { - return s.retryTransaction(func(tx *gorm.DB) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { return tx.Model(&dbHost{}). Where("public_key", publicKey(hk)). Update("lost_sectors", 0). diff --git a/stores/metadata.go b/stores/metadata.go index 529d7ec89..c543695bd 100644 --- a/stores/metadata.go +++ b/stores/metadata.go @@ -32,8 +32,9 @@ const ( // 10/30 erasure coding and takes <1s to execute on an SSD in SQLite. refreshHealthBatchSize = 10000 + // sectorInsertionBatchSize is the number of sectors per batch when we + // upsert sectors. sectorInsertionBatchSize = 500 - sectorQueryBatchSize = 100 refreshHealthMinHealthValidity = 12 * time.Hour refreshHealthMaxHealthValidity = 72 * time.Hour @@ -480,6 +481,7 @@ func (raw rawObject) toSlabSlice() (slice object.SlabSlice, _ error) { func (s *SQLStore) Bucket(ctx context.Context, bucket string) (api.Bucket, error) { var b dbBucket err := s.db. + WithContext(ctx). Model(&dbBucket{}). Where("name = ?", bucket). Take(&b). @@ -498,7 +500,7 @@ func (s *SQLStore) Bucket(ctx context.Context, bucket string) (api.Bucket, error func (s *SQLStore) CreateBucket(ctx context.Context, bucket string, policy api.BucketPolicy) error { // Create bucket. 
- return s.retryTransaction(func(tx *gorm.DB) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { res := tx.Clauses(clause.OnConflict{ DoNothing: true, }). @@ -520,7 +522,7 @@ func (s *SQLStore) UpdateBucketPolicy(ctx context.Context, bucket string, policy if err != nil { return err } - return s.retryTransaction(func(tx *gorm.DB) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { return tx. Model(&dbBucket{}). Where("name", bucket). @@ -534,7 +536,7 @@ func (s *SQLStore) UpdateBucketPolicy(ctx context.Context, bucket string, policy func (s *SQLStore) DeleteBucket(ctx context.Context, bucket string) error { // Delete bucket. - return s.retryTransaction(func(tx *gorm.DB) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { var b dbBucket if err := tx.Take(&b, "name = ?", bucket).Error; errors.Is(err, gorm.ErrRecordNotFound) { return api.ErrBucketNotFound @@ -561,6 +563,7 @@ func (s *SQLStore) DeleteBucket(ctx context.Context, bucket string) error { func (s *SQLStore) ListBuckets(ctx context.Context) ([]api.Bucket, error) { var buckets []dbBucket err := s.db. + WithContext(ctx). Model(&dbBucket{}). Find(&buckets). Error @@ -583,10 +586,12 @@ func (s *SQLStore) ListBuckets(ctx context.Context) ([]api.Bucket, error) { // reduce locking and make sure all results are consistent, everything is done // within a single transaction. 
func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) (api.ObjectsStatsResponse, error) { + db := s.db.WithContext(ctx) + // fetch bucket id if a bucket was specified var bucketID uint if opts.Bucket != "" { - err := s.db.Model(&dbBucket{}).Select("id").Where("name = ?", opts.Bucket).Take(&bucketID).Error + err := db.Model(&dbBucket{}).Select("id").Where("name = ?", opts.Bucket).Take(&bucketID).Error if err != nil { return api.ObjectsStatsResponse{}, err } @@ -598,7 +603,7 @@ func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) MinHealth float64 TotalObjectsSize uint64 } - objInfoQuery := s.db. + objInfoQuery := db. Model(&dbObject{}). Select("COUNT(*) AS NumObjects, COALESCE(MIN(health), 1) as MinHealth, SUM(size) AS TotalObjectsSize") if opts.Bucket != "" { @@ -611,7 +616,7 @@ func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) // number of unfinished objects var unfinishedObjects uint64 - unfinishedObjectsQuery := s.db. + unfinishedObjectsQuery := db. Model(&dbMultipartUpload{}). Select("COUNT(*)") if opts.Bucket != "" { @@ -624,7 +629,7 @@ func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) // size of unfinished objects var totalUnfinishedObjectsSize uint64 - totalUnfinishedObjectsSizeQuery := s.db. + totalUnfinishedObjectsSizeQuery := db. Model(&dbMultipartPart{}). Joins("INNER JOIN multipart_uploads mu ON multipart_parts.db_multipart_upload_id = mu.id"). Select("COALESCE(SUM(size), 0)") @@ -637,7 +642,7 @@ func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) } var totalSectors int64 - totalSectorsQuery := s.db. + totalSectorsQuery := db. Table("slabs sla"). Select("COALESCE(SUM(total_shards), 0)"). Where("db_buffered_slab_id IS NULL") @@ -657,7 +662,7 @@ func (s *SQLStore) ObjectsStats(ctx context.Context, opts api.ObjectsStatsOpts) } var totalUploaded int64 - err = s.db. + err = db. Model(&dbContract{}). 
Select("COALESCE(SUM(size), 0)"). Scan(&totalUploaded). @@ -707,7 +712,7 @@ func (s *SQLStore) AddContract(ctx context.Context, c rhpv2.ContractRevision, co return api.ContractMetadata{}, err } var added dbContract - if err = s.retryTransaction(func(tx *gorm.DB) error { + if err = s.retryTransaction(ctx, func(tx *gorm.DB) error { added, err = addContract(tx, c, contractPrice, totalCost, startHeight, types.FileContractID{}, cs) return err }); err != nil { @@ -719,12 +724,14 @@ func (s *SQLStore) AddContract(ctx context.Context, c rhpv2.ContractRevision, co } func (s *SQLStore) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api.ContractMetadata, error) { + db := s.db.WithContext(ctx) + // helper to check whether a contract set exists hasContractSet := func() error { if opts.ContractSet == "" { return nil } - err := s.db.Where("name", opts.ContractSet).Take(&dbContractSet{}).Error + err := db.Where("name", opts.ContractSet).Take(&dbContractSet{}).Error if errors.Is(err, gorm.ErrRecordNotFound) { return api.ErrContractSetNotFound } @@ -737,13 +744,13 @@ func (s *SQLStore) Contracts(ctx context.Context, opts api.ContractsOpts) ([]api Host dbHost `gorm:"embedded"` Name string } - tx := s.db + tx := db if opts.ContractSet == "" { // no filter, use all contracts tx = tx.Table("contracts") } else { // filter contracts by contract set first - tx = tx.Table("(?) contracts", s.db.Model(&dbContract{}). + tx = tx.Table("(?) contracts", db.Model(&dbContract{}). Select("contracts.*"). Joins("INNER JOIN hosts h ON h.id = contracts.host_id"). Joins("INNER JOIN contract_set_contracts csc ON csc.db_contract_id = contracts.id"). @@ -806,7 +813,7 @@ func (s *SQLStore) AddRenewedContract(ctx context.Context, c rhpv2.ContractRevis return api.ContractMetadata{}, err } var renewed dbContract - if err := s.retryTransaction(func(tx *gorm.DB) error { + if err := s.retryTransaction(ctx, func(tx *gorm.DB) error { // Fetch contract we renew from. 
oldContract, err := contract(tx, fileContractID(renewedFrom)) if err != nil { @@ -846,7 +853,7 @@ func (s *SQLStore) AddRenewedContract(ctx context.Context, c rhpv2.ContractRevis func (s *SQLStore) AncestorContracts(ctx context.Context, id types.FileContractID, startHeight uint64) ([]api.ArchivedContract, error) { var ancestors []dbArchivedContract - err := s.db.Raw("WITH RECURSIVE ancestors AS (SELECT * FROM archived_contracts WHERE renewed_to = ? UNION ALL SELECT archived_contracts.* FROM ancestors, archived_contracts WHERE archived_contracts.renewed_to = ancestors.fcid) SELECT * FROM ancestors WHERE start_height >= ?", fileContractID(id), startHeight). + err := s.db.WithContext(ctx).Raw("WITH RECURSIVE ancestors AS (SELECT * FROM archived_contracts WHERE renewed_to = ? UNION ALL SELECT archived_contracts.* FROM ancestors, archived_contracts WHERE archived_contracts.renewed_to = ancestors.fcid) SELECT * FROM ancestors WHERE start_height >= ?", fileContractID(id), startHeight). Scan(&ancestors). Error if err != nil { @@ -877,8 +884,8 @@ func (s *SQLStore) ArchiveContracts(ctx context.Context, toArchive map[types.Fil } // archive them - if err := s.retryTransaction(func(tx *gorm.DB) error { - return archiveContracts(ctx, tx, cs, toArchive) + if err := s.retryTransaction(ctx, func(tx *gorm.DB) error { + return archiveContracts(tx, cs, toArchive) }); err != nil { return err } @@ -890,6 +897,7 @@ func (s *SQLStore) ArchiveAllContracts(ctx context.Context, reason string) error // fetch contract ids var fcids []fileContractID if err := s.db. + WithContext(ctx). Model(&dbContract{}). Pluck("fcid", &fcids). Error; err != nil { @@ -920,6 +928,7 @@ func (s *SQLStore) ContractRoots(ctx context.Context, id types.FileContractID) ( var dbRoots []hash256 if err = s.db. + WithContext(ctx). Raw(` SELECT sec.root FROM contracts c @@ -938,7 +947,7 @@ WHERE c.fcid = ? 
func (s *SQLStore) ContractSets(ctx context.Context) ([]string, error) { var sets []string - err := s.db.Raw("SELECT name FROM contract_sets"). + err := s.db.WithContext(ctx).Raw("SELECT name FROM contract_sets"). Scan(&sets). Error return sets, err @@ -953,7 +962,7 @@ func (s *SQLStore) ContractSizes(ctx context.Context) (map[types.FileContractID] var nullContracts []size var dataContracts []size - if err := s.retryTransaction(func(tx *gorm.DB) error { + if err := s.retryTransaction(ctx, func(tx *gorm.DB) error { // first, we fetch all contracts without sectors and consider their // entire size as prunable if err := tx. @@ -1003,6 +1012,7 @@ func (s *SQLStore) ContractSize(ctx context.Context, id types.FileContractID) (a } if err := s.db. + WithContext(ctx). Raw(` SELECT contract_size as size, CASE WHEN contract_size > sector_size THEN contract_size - sector_size ELSE 0 END as prunable FROM ( SELECT MAX(c.size) as contract_size, COUNT(cs.db_sector_id) * ? as sector_size FROM contracts c LEFT JOIN contract_sectors cs ON cs.db_contract_id = c.id WHERE c.fcid = ? @@ -1029,7 +1039,7 @@ func (s *SQLStore) SetContractSet(ctx context.Context, name string, contractIds var diff []fileContractID var nContractsAfter int - err := s.retryTransaction(func(tx *gorm.DB) error { + err := s.retryTransaction(ctx, func(tx *gorm.DB) error { // fetch contract set var cs dbContractSet err := tx. @@ -1097,6 +1107,7 @@ func (s *SQLStore) SetContractSet(ctx context.Context, name string, contractIds func (s *SQLStore) RemoveContractSet(ctx context.Context, name string) error { return s.db. + WithContext(ctx). Where(dbContractSet{Name: name}). Delete(&dbContractSet{}). Error @@ -1106,6 +1117,7 @@ func (s *SQLStore) RenewedContract(ctx context.Context, renewedFrom types.FileCo var contract dbContract err = s.db. + WithContext(ctx). Where(&dbContract{ContractCommon: ContractCommon{RenewedFrom: fileContractID(renewedFrom)}}). Joins("Host"). Take(&contract). 
@@ -1126,6 +1138,7 @@ func (s *SQLStore) SearchObjects(ctx context.Context, bucket, substring string, var objects []api.ObjectMetadata err := s.db. + WithContext(ctx). Select("o.object_id as Name, o.size as Size, o.health as Health, o.mime_type as MimeType, o.etag as ETag, o.created_at as ModTime"). Model(&dbObject{}). Table("objects o"). @@ -1259,6 +1272,7 @@ FROM ( case api.ObjectSortByHealth: var markerHealth float64 if err = s.db. + WithContext(ctx). Raw(fmt.Sprintf(`SELECT Health FROM (%s WHERE oname >= ? ORDER BY oname LIMIT 1) as n`, objectsQuery), append(objectsQueryParams, marker)...). Scan(&markerHealth). Error; err != nil { @@ -1275,6 +1289,7 @@ FROM ( case api.ObjectSortBySize: var markerSize float64 if err = s.db. + WithContext(ctx). Raw(fmt.Sprintf(`SELECT Size FROM (%s WHERE oname >= ? ORDER BY oname LIMIT 1) as n`, objectsQuery), append(objectsQueryParams, marker)...). Scan(&markerSize). Error; err != nil { @@ -1315,6 +1330,7 @@ FROM ( parameters := append(append(objectsQueryParams, markerParams...), limit, offset) if err = s.db. + WithContext(ctx). Raw(query, parameters...). Scan(&rows). Error; err != nil { @@ -1335,8 +1351,8 @@ FROM ( } func (s *SQLStore) Object(ctx context.Context, bucket, path string) (obj api.Object, err error) { - err = s.db.Transaction(func(tx *gorm.DB) error { - obj, err = s.object(ctx, tx, bucket, path) + err = s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + obj, err = s.object(tx, bucket, path) return err }) return @@ -1367,7 +1383,7 @@ func (s *SQLStore) RecordContractSpending(ctx context.Context, records []api.Con } metrics := make([]api.ContractMetric, 0, len(squashedRecords)) for fcid, newSpending := range squashedRecords { - err := s.retryTransaction(func(tx *gorm.DB) error { + err := s.retryTransaction(ctx, func(tx *gorm.DB) error { var contract dbContract err := tx.Model(&dbContract{}). Where("fcid = ?", fileContractID(fcid)). 
@@ -1470,11 +1486,11 @@ func fetchUsedContracts(tx *gorm.DB, usedContracts map[types.PublicKey]map[types } func (s *SQLStore) RenameObject(ctx context.Context, bucket, keyOld, keyNew string, force bool) error { - return s.retryTransaction(func(tx *gorm.DB) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { if force { // delete potentially existing object at destination if _, err := s.deleteObject(tx, bucket, keyNew); err != nil { - return err + return fmt.Errorf("RenameObject: failed to delete object: %w", err) } } tx = tx.Exec(`UPDATE objects SET object_id = ? WHERE object_id = ? AND ?`, keyNew, keyOld, sqlWhereBucket("objects", bucket)) @@ -1492,7 +1508,7 @@ func (s *SQLStore) RenameObject(ctx context.Context, bucket, keyOld, keyNew stri } func (s *SQLStore) RenameObjects(ctx context.Context, bucket, prefixOld, prefixNew string, force bool) error { - return s.retryTransaction(func(tx *gorm.DB) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { if force { // delete potentially existing objects at destination inner := tx.Raw("SELECT ? FROM objects WHERE object_id LIKE ? AND SUBSTR(object_id, 1, ?) = ? AND ?", @@ -1538,7 +1554,14 @@ func (s *SQLStore) AddPartialSlab(ctx context.Context, data []byte, minShards, t } func (s *SQLStore) CopyObject(ctx context.Context, srcBucket, dstBucket, srcPath, dstPath, mimeType string, metadata api.ObjectUserMetadata) (om api.ObjectMetadata, err error) { - err = s.retryTransaction(func(tx *gorm.DB) error { + err = s.retryTransaction(ctx, func(tx *gorm.DB) error { + if srcBucket != dstBucket || srcPath != dstPath { + _, err = s.deleteObject(tx, dstBucket, dstPath) + if err != nil { + return fmt.Errorf("CopyObject: failed to delete object: %w", err) + } + } + var srcObj dbObject err = tx.Where("objects.object_id = ? AND DBBucket.name = ?", srcPath, srcBucket). Joins("DBBucket"). 
@@ -1565,10 +1588,6 @@ func (s *SQLStore) CopyObject(ctx context.Context, srcBucket, dstBucket, srcPath } return tx.Save(&srcObj).Error } - _, err = s.deleteObject(tx, dstBucket, dstPath) - if err != nil { - return fmt.Errorf("failed to delete object: %w", err) - } var srcSlices []dbSlice err = tx.Where("db_object_id = ?", srcObj.ID). @@ -1622,7 +1641,7 @@ func (s *SQLStore) CopyObject(ctx context.Context, srcBucket, dstBucket, srcPath func (s *SQLStore) DeleteHostSector(ctx context.Context, hk types.PublicKey, root types.Hash256) (int, error) { var deletedSectors int - err := s.retryTransaction(func(tx *gorm.DB) error { + err := s.retryTransaction(ctx, func(tx *gorm.DB) error { // Fetch contract_sectors to delete. var sectors []dbContractSector err := tx.Raw(` @@ -1707,13 +1726,7 @@ func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, usedContracts := o.Contracts() // UpdateObject is ACID. - return s.retryTransaction(func(tx *gorm.DB) error { - // Fetch contract set. - var cs dbContractSet - if err := tx.Take(&cs, "name = ?", contractSet).Error; err != nil { - return fmt.Errorf("contract set %v not found: %w", contractSet, err) - } - + return s.retryTransaction(ctx, func(tx *gorm.DB) error { // Try to delete. We want to get rid of the object and its slices if it // exists. // @@ -1726,7 +1739,7 @@ func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, // object's metadata before trying to recreate it _, err := s.deleteObject(tx, bucket, path) if err != nil { - return fmt.Errorf("failed to delete object: %w", err) + return fmt.Errorf("UpdateObject: failed to delete object: %w", err) } // Insert a new object. @@ -1734,14 +1747,16 @@ func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, if err != nil { return fmt.Errorf("failed to marshal object key: %w", err) } + // fetch bucket id var bucketID uint - err = tx.Table("(SELECT id from buckets WHERE buckets.name = ?) bucket_id", bucket). 
+ err = tx.Table("(SELECT id from buckets WHERE buckets.name = ?) bucket_id", bucket). Take(&bucketID).Error if errors.Is(err, gorm.ErrRecordNotFound) { return fmt.Errorf("bucket %v not found: %w", bucket, api.ErrBucketNotFound) } else if err != nil { return fmt.Errorf("failed to fetch bucket id: %w", err) } + obj := dbObject{ DBBucketID: bucketID, ObjectID: path, @@ -1755,6 +1770,12 @@ func (s *SQLStore) UpdateObject(ctx context.Context, bucket, path, contractSet, return fmt.Errorf("failed to create object: %w", err) } + // Fetch contract set. + var cs dbContractSet + if err := tx.Take(&cs, "name = ?", contractSet).Error; err != nil { + return fmt.Errorf("contract set %v not found: %w", contractSet, err) + } + // Fetch the used contracts. contracts, err := fetchUsedContracts(tx, usedContracts) if err != nil { @@ -1778,9 +1799,12 @@ func (s *SQLStore) RemoveObject(ctx context.Context, bucket, key string) error { var rowsAffected int64 var err error - err = s.retryTransaction(func(tx *gorm.DB) error { + err = s.retryTransaction(ctx, func(tx *gorm.DB) error { rowsAffected, err = s.deleteObject(tx, bucket, key) - return err + if err != nil { + return fmt.Errorf("RemoveObject: failed to delete object: %w", err) + } + return nil }) if err != nil { return err @@ -1794,7 +1818,7 @@ func (s *SQLStore) RemoveObjects(ctx context.Context, bucket, prefix string) error { var rowsAffected int64 var err error - rowsAffected, err = s.deleteObjects(bucket, prefix) + rowsAffected, err = s.deleteObjects(ctx, bucket, prefix) if err != nil { return err } @@ -1844,7 +1868,7 @@ func (ss *SQLStore) UpdateSlab(ctx context.Context, s object.Slab, contractSet s usedContracts := s.Contracts() // Update slab.
- return ss.retryTransaction(func(tx *gorm.DB) (err error) { + return ss.retryTransaction(ctx, func(tx *gorm.DB) (err error) { // update slab if err := tx.Model(&dbSlab{}). Where("key", key). @@ -1978,7 +2002,7 @@ LIMIT ? for { var rowsAffected int64 - err := s.retryTransaction(func(tx *gorm.DB) error { + err := s.retryTransaction(ctx, func(tx *gorm.DB) error { var res *gorm.DB if isSQLite(s.db) { res = tx.Exec("UPDATE slabs SET health = inner.health, health_valid_until = (?) FROM (?) AS inner WHERE slabs.id=inner.id", sqlRandomTimestamp(s.db, now, refreshHealthMinHealthValidity, refreshHealthMaxHealthValidity), healthQuery) @@ -2028,7 +2052,7 @@ func (s *SQLStore) UnhealthySlabs(ctx context.Context, healthCutoff float64, set Health float64 } - if err := s.retryTransaction(func(tx *gorm.DB) error { + if err := s.retryTransaction(ctx, func(tx *gorm.DB) error { return tx.Select("slabs.key, slabs.health"). Joins("INNER JOIN contract_sets cs ON slabs.db_contract_set_id = cs.id"). Model(&dbSlab{}). @@ -2221,19 +2245,19 @@ func (s *SQLStore) createSlices(tx *gorm.DB, objID, multiPartID *uint, contractS } // object retrieves an object from the store. -func (s *SQLStore) object(ctx context.Context, tx *gorm.DB, bucket, path string) (api.Object, error) { +func (s *SQLStore) object(tx *gorm.DB, bucket, path string) (api.Object, error) { // fetch raw object data - raw, err := s.objectRaw(ctx, tx, bucket, path) + raw, err := s.objectRaw(tx, bucket, path) if errors.Is(err, gorm.ErrRecordNotFound) || len(raw) == 0 { return api.Object{}, api.ErrObjectNotFound } // hydrate raw object data - return s.objectHydrate(ctx, tx, bucket, path, raw) + return s.objectHydrate(tx, bucket, path, raw) } // objectHydrate hydrates a raw object and returns an api.Object. 
-func (s *SQLStore) objectHydrate(ctx context.Context, tx *gorm.DB, bucket, path string, obj rawObject) (api.Object, error) { +func (s *SQLStore) objectHydrate(tx *gorm.DB, bucket, path string, obj rawObject) (api.Object, error) { // parse object key var key object.EncryptionKey if err := key.UnmarshalBinary(obj[0].ObjectKey); err != nil { @@ -2288,7 +2312,7 @@ func (s *SQLStore) objectHydrate(ctx context.Context, tx *gorm.DB, bucket, path } // fetch object metadata - metadata, err := s.objectMetadata(ctx, tx, bucket, path) + metadata, err := s.objectMetadata(tx, bucket, path) if err != nil { return api.Object{}, err } @@ -2314,7 +2338,7 @@ func (s *SQLStore) objectHydrate(ctx context.Context, tx *gorm.DB, bucket, path // ObjectMetadata returns an object's metadata func (s *SQLStore) ObjectMetadata(ctx context.Context, bucket, path string) (api.Object, error) { var resp api.Object - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var obj dbObject err := tx.Model(&dbObject{}). Joins("INNER JOIN buckets b ON objects.db_bucket_id = b.id"). @@ -2327,7 +2351,7 @@ func (s *SQLStore) ObjectMetadata(ctx context.Context, bucket, path string) (api } else if err != nil { return err } - oum, err := s.objectMetadata(ctx, tx, bucket, path) + oum, err := s.objectMetadata(tx, bucket, path) if err != nil { return err } @@ -2347,7 +2371,7 @@ func (s *SQLStore) ObjectMetadata(ctx context.Context, bucket, path string) (api return resp, err } -func (s *SQLStore) objectMetadata(ctx context.Context, tx *gorm.DB, bucket, path string) (api.ObjectUserMetadata, error) { +func (s *SQLStore) objectMetadata(tx *gorm.DB, bucket, path string) (api.ObjectUserMetadata, error) { var rows []dbObjectUserMetadata err := tx. Model(&dbObjectUserMetadata{}). 
@@ -2378,12 +2402,12 @@ func newObjectMetadata(name, etag, mimeType string, health float64, modTime time } } -func (s *SQLStore) objectRaw(ctx context.Context, txn *gorm.DB, bucket string, path string) (rows rawObject, err error) { +func (s *SQLStore) objectRaw(txn *gorm.DB, bucket string, path string) (rows rawObject, err error) { // NOTE: we LEFT JOIN here because empty objects are valid and need to be // included in the result set, when we convert the rawObject before // returning it we'll check for SlabID and/or SectorID being 0 and act // accordingly - err = s.db. + err = txn. Select("o.id as ObjectID, o.health as ObjectHealth, sli.object_index as ObjectIndex, o.key as ObjectKey, o.object_id as ObjectName, o.size as ObjectSize, o.mime_type as ObjectMimeType, o.created_at as ObjectModTime, o.etag as ObjectETag, sli.object_index, sli.offset as SliceOffset, sli.length as SliceLength, sla.id as SlabID, sla.health as SlabHealth, sla.key as SlabKey, sla.min_shards as SlabMinShards, bs.id IS NOT NULL AS SlabBuffered, sec.slab_index as SectorIndex, sec.root as SectorRoot, sec.latest_host as LatestHost, c.fcid as FCID, h.public_key as HostKey"). Model(&dbObject{}). Table("objects o"). @@ -2403,40 +2427,9 @@ func (s *SQLStore) objectRaw(ctx context.Context, txn *gorm.DB, bucket string, p return } -func (s *SQLStore) objectHealth(ctx context.Context, tx *gorm.DB, objectID uint) (health float64, err error) { - if err = tx. - Select("objects.health"). - Model(&dbObject{}). - Table("objects"). - Where("id", objectID). - Scan(&health). - Error; errors.Is(err, gorm.ErrRecordNotFound) { - err = api.ErrObjectNotFound - } - return -} - // contract retrieves a contract from the store. func (s *SQLStore) contract(ctx context.Context, id fileContractID) (dbContract, error) { - return contract(s.db, id) -} - -// contracts retrieves all contracts in the given set. 
-func (s *SQLStore) contracts(ctx context.Context, set string) ([]dbContract, error) { - var cs dbContractSet - err := s.db. - Where(&dbContractSet{Name: set}). - Preload("Contracts.Host"). - Take(&cs). - Error - - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, fmt.Errorf("%w '%s'", api.ErrContractSetNotFound, set) - } else if err != nil { - return nil, err - } - - return cs.Contracts, nil + return contract(s.db.WithContext(ctx), id) } // PackedSlabsForUpload returns up to 'limit' packed slabs that are ready for @@ -2444,7 +2437,7 @@ func (s *SQLStore) contracts(ctx context.Context, set string) ([]dbContract, err // again. func (s *SQLStore) PackedSlabsForUpload(ctx context.Context, lockingDuration time.Duration, minShards, totalShards uint8, set string, limit int) ([]api.PackedSlab, error) { var contractSetID uint - if err := s.db.Raw("SELECT id FROM contract_sets WHERE name = ?", set). + if err := s.db.WithContext(ctx).Raw("SELECT id FROM contract_sets WHERE name = ?", set). 
Scan(&contractSetID).Error; err != nil { return nil, err } @@ -2458,7 +2451,7 @@ func (s *SQLStore) ObjectsBySlabKey(ctx context.Context, bucket string, slabKey return nil, err } - err = s.retryTransaction(func(tx *gorm.DB) error { + err = s.retryTransaction(ctx, func(tx *gorm.DB) error { return tx.Raw(` SELECT DISTINCT obj.object_id as Name, obj.size as Size, obj.mime_type as MimeType, sla.health as Health FROM slabs sla @@ -2494,7 +2487,7 @@ func (s *SQLStore) MarkPackedSlabsUploaded(ctx context.Context, slabs []api.Uplo } } var fileName string - err := s.retryTransaction(func(tx *gorm.DB) error { + err := s.retryTransaction(ctx, func(tx *gorm.DB) error { for _, slab := range slabs { var err error fileName, err = s.markPackedSlabUploaded(tx, slab) @@ -2670,14 +2663,14 @@ func addContract(tx *gorm.DB, c rhpv2.ContractRevision, contractPrice, totalCost // archival reason // // NOTE: this function archives the contracts without setting a renewed ID -func archiveContracts(ctx context.Context, tx *gorm.DB, contracts []dbContract, toArchive map[types.FileContractID]string) error { +func archiveContracts(tx *gorm.DB, contracts []dbContract, toArchive map[types.FileContractID]string) error { var toInvalidate []fileContractID for _, contract := range contracts { toInvalidate = append(toInvalidate, contract.FCID) } // Invalidate the health on the slabs before deleting the contracts to avoid // breaking the relations beforehand. - if err := invalidateSlabHealthByFCID(ctx, tx, toInvalidate); err != nil { + if err := invalidateSlabHealthByFCID(tx, toInvalidate); err != nil { return fmt.Errorf("invalidating slab health failed: %w", err) } for _, contract := range contracts { @@ -2722,7 +2715,21 @@ AND slabs.db_buffered_slab_id IS NULL // without an obect after the deletion. That means in case of packed uploads, // the slab is only deleted when no more objects point to it. 
func (s *SQLStore) deleteObject(tx *gorm.DB, bucket string, path string) (int64, error) { - tx = tx.Where("object_id = ? AND ?", path, sqlWhereBucket("objects", bucket)). + // check if the object exists first to avoid unnecessary locking for the + // common case + var objID uint + resp := tx.Model(&dbObject{}). + Where("object_id = ? AND ?", path, sqlWhereBucket("objects", bucket)). + Select("id"). + Limit(1). + Scan(&objID) + if err := resp.Error; err != nil { + return 0, err + } else if resp.RowsAffected == 0 { + return 0, nil + } + + tx = tx.Where("id", objID). Delete(&dbObject{}) if tx.Error != nil { return 0, tx.Error @@ -2740,12 +2747,12 @@ func (s *SQLStore) deleteObject(tx *gorm.DB, bucket string, path string) (int64, // deletion goes from largest to smallest. That's because the batch size is // dynamically increased and the smaller objects get the faster we can delete // them meaning it makes sense to increase the batch size over time. -func (s *SQLStore) deleteObjects(bucket string, path string) (numDeleted int64, _ error) { +func (s *SQLStore) deleteObjects(ctx context.Context, bucket string, path string) (numDeleted int64, _ error) { batchSizeIdx := 0 for { var duration time.Duration var rowsAffected int64 - if err := s.retryTransaction(func(tx *gorm.DB) error { + if err := s.retryTransaction(ctx, func(tx *gorm.DB) error { start := time.Now() res := tx.Exec(` DELETE FROM objects @@ -2787,7 +2794,7 @@ func (s *SQLStore) deleteObjects(bucket string, path string) (numDeleted int64, return numDeleted, nil } -func invalidateSlabHealthByFCID(ctx context.Context, tx *gorm.DB, fcids []fileContractID) error { +func invalidateSlabHealthByFCID(tx *gorm.DB, fcids []fileContractID) error { if len(fcids) == 0 { return nil } @@ -2811,22 +2818,18 @@ func invalidateSlabHealthByFCID(ctx context.Context, tx *gorm.DB, fcids []fileCo } else if resp.RowsAffected < refreshHealthBatchSize { break // done } - - select { - case <-ctx.Done(): - return ctx.Err() - case 
<-time.After(time.Second): - } + time.Sleep(time.Second) } return nil } func (s *SQLStore) invalidateSlabHealthByFCID(ctx context.Context, fcids []fileContractID) error { - return s.retryTransaction(func(tx *gorm.DB) error { - return invalidateSlabHealthByFCID(ctx, tx, fcids) + return s.retryTransaction(ctx, func(tx *gorm.DB) error { + return invalidateSlabHealthByFCID(tx, fcids) }) } +// nolint:unparam func sqlConcat(db *gorm.DB, a, b string) string { if isSQLite(db) { return fmt.Sprintf("%s || %s", a, b) @@ -2841,6 +2844,7 @@ func sqlRandomTimestamp(db *gorm.DB, now time.Time, min, max time.Duration) clau return gorm.Expr("FLOOR(? + RAND() * (? - ?))", now.Add(min).Unix(), int(max.Seconds()), int(min.Seconds())) } +// nolint:unparam func sqlWhereBucket(objTable string, bucket string) clause.Expr { return gorm.Expr(fmt.Sprintf("%s.db_bucket_id = (SELECT id FROM buckets WHERE buckets.name = ?)", objTable), bucket) } @@ -2872,7 +2876,7 @@ func (s *SQLStore) ListObjects(ctx context.Context, bucket, prefix, sortBy, sort } var rows []rawObjectMetadata if err := s.db. - Select("o.object_id as Name, o.size as Size, o.health as Health, o.mime_type as mimeType, o.created_at as ModTime"). + Select("o.object_id as Name, o.size as Size, o.health as Health, o.mime_type as MimeType, o.created_at as ModTime, o.etag as ETag"). Model(&dbObject{}). Table("objects o"). Joins("INNER JOIN buckets b ON o.db_bucket_id = b.id"). 
diff --git a/stores/metadata_test.go b/stores/metadata_test.go index f5461147c..c6ac1cd52 100644 --- a/stores/metadata_test.go +++ b/stores/metadata_test.go @@ -10,6 +10,7 @@ import ( "reflect" "sort" "strings" + "sync" "testing" "time" @@ -23,10 +24,10 @@ import ( "lukechampine.com/frand" ) -func generateMultisigUC(m, n uint64, salt string) types.UnlockConditions { +func randomMultisigUC() types.UnlockConditions { uc := types.UnlockConditions{ - PublicKeys: make([]types.UnlockKey, n), - SignaturesRequired: uint64(m), + PublicKeys: make([]types.UnlockKey, 2), + SignaturesRequired: 1, } for i := range uc.PublicKeys { uc.PublicKeys[i].Algorithm = types.SpecifierEd25519 @@ -224,7 +225,7 @@ func TestSQLContractStore(t *testing.T) { } // Create random unlock conditions for the host. - uc := generateMultisigUC(1, 2, "salt") + uc := randomMultisigUC() uc.PublicKeys[1].Key = hk[:] uc.Timelock = 192837 @@ -519,11 +520,11 @@ func TestRenewedContract(t *testing.T) { } // Create random unlock conditions for the hosts. - uc := generateMultisigUC(1, 2, "salt") + uc := randomMultisigUC() uc.PublicKeys[1].Key = hk[:] uc.Timelock = 192837 - uc2 := generateMultisigUC(1, 2, "salt") + uc2 := randomMultisigUC() uc2.PublicKeys[1].Key = hk2[:] uc2.Timelock = 192837 @@ -873,7 +874,7 @@ func TestArchiveContracts(t *testing.T) { } func testContractRevision(fcid types.FileContractID, hk types.PublicKey) rhpv2.ContractRevision { - uc := generateMultisigUC(1, 2, "salt") + uc := randomMultisigUC() uc.PublicKeys[1].Key = hk[:] uc.Timelock = 192837 return rhpv2.ContractRevision{ @@ -1867,7 +1868,7 @@ func TestUnhealthySlabsNoContracts(t *testing.T) { // delete the sector - we manually invalidate the slabs for the contract // before deletion. 
- err = invalidateSlabHealthByFCID(context.Background(), ss.db, []fileContractID{fileContractID(fcid1)}) + err = invalidateSlabHealthByFCID(ss.db, []fileContractID{fileContractID(fcid1)}) if err != nil { t.Fatal(err) } @@ -3285,7 +3286,7 @@ func TestBucketObjects(t *testing.T) { // See if we can fetch the object by slab. var ec object.EncryptionKey - if obj, err := ss.objectRaw(context.Background(), ss.db, b1, "/bar"); err != nil { + if obj, err := ss.objectRaw(ss.db, b1, "/bar"); err != nil { t.Fatal(err) } else if err := ec.UnmarshalBinary(obj[0].SlabKey); err != nil { t.Fatal(err) @@ -3390,7 +3391,7 @@ func TestMarkSlabUploadedAfterRenew(t *testing.T) { // renew the contract. fcidRenewed := types.FileContractID{2, 2, 2, 2, 2} - uc := generateMultisigUC(1, 2, "salt") + uc := randomMultisigUC() rev := rhpv2.ContractRevision{ Revision: types.FileContractRevision{ ParentID: fcidRenewed, @@ -3491,6 +3492,13 @@ func TestListObjects(t *testing.T) { {"/foo", "size", "ASC", "", []api.ObjectMetadata{{Name: "/foo/bar", Size: 1, Health: 1}, {Name: "/foo/bat", Size: 2, Health: 1}, {Name: "/foo/baz/quux", Size: 3, Health: .75}, {Name: "/foo/baz/quuz", Size: 4, Health: .5}}}, {"/foo", "size", "DESC", "", []api.ObjectMetadata{{Name: "/foo/baz/quuz", Size: 4, Health: .5}, {Name: "/foo/baz/quux", Size: 3, Health: .75}, {Name: "/foo/bat", Size: 2, Health: 1}, {Name: "/foo/bar", Size: 1, Health: 1}}}, } + // set common fields + for i := range tests { + for j := range tests[i].want { + tests[i].want[j].ETag = testETag + tests[i].want[j].MimeType = testMimeType + } + } for _, test := range tests { res, err := ss.ListObjects(ctx, api.DefaultBucketName, test.prefix, test.sortBy, test.sortDir, "", -1) if err != nil { @@ -4545,3 +4553,104 @@ func TestTypeCurrency(t *testing.T) { } } } + +// TestUpdateObjectParallel calls UpdateObject from multiple threads in parallel +// while retries are disabled to make sure calling the same method from multiple +// threads won't cause deadlocks. 
+// +// NOTE: This test only covers the optimistic case of inserting objects without +// overwriting them. As soon as combining deletions and insertions within the +// same transaction, deadlocks become more likely due to the gap locks MySQL +// uses. +func TestUpdateObjectParallel(t *testing.T) { + cfg := defaultTestSQLStoreConfig + + dbURI, _, _, _ := DBConfigFromEnv() + if dbURI == "" { + // it's pretty much impossile to optimise for both sqlite and mysql at + // the same time so we skip this test for SQLite for now + // TODO: once we moved away from gorm and implement separate interfaces + // for SQLite and MySQL, we have more control over the used queries and + // can revisit this + t.SkipNow() + } + ss := newTestSQLStore(t, cfg) + ss.retryTransactionIntervals = []time.Duration{0} // don't retry + defer ss.Close() + + // create 2 hosts + hks, err := ss.addTestHosts(2) + if err != nil { + t.Fatal(err) + } + hk1, hk2 := hks[0], hks[1] + + // create 2 contracts + fcids, _, err := ss.addTestContracts(hks) + if err != nil { + t.Fatal(err) + } + fcid1, fcid2 := fcids[0], fcids[1] + + c := make(chan string) + ctx, cancel := context.WithCancel(context.Background()) + work := func() { + t.Helper() + defer cancel() + for name := range c { + // create an object + obj := object.Object{ + Key: object.GenerateEncryptionKey(), + Slabs: []object.SlabSlice{ + { + Slab: object.Slab{ + Health: 1.0, + Key: object.GenerateEncryptionKey(), + MinShards: 1, + Shards: newTestShards(hk1, fcid1, frand.Entropy256()), + }, + Offset: 10, + Length: 100, + }, + { + Slab: object.Slab{ + Health: 1.0, + Key: object.GenerateEncryptionKey(), + MinShards: 2, + Shards: newTestShards(hk2, fcid2, frand.Entropy256()), + }, + Offset: 20, + Length: 200, + }, + }, + } + + // update the object + if err := ss.UpdateObject(context.Background(), api.DefaultBucketName, name, testContractSet, testETag, testMimeType, testMetadata, obj); err != nil { + t.Error(err) + return + } + } + } + + var wg sync.WaitGroup 
+ for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + work() + wg.Done() + }() + } + + // create 1000 objects and then overwrite them + for i := 0; i < 1000; i++ { + select { + case c <- fmt.Sprintf("object-%d", i): + case <-ctx.Done(): + return + } + } + + close(c) + wg.Wait() +} diff --git a/stores/metrics.go b/stores/metrics.go index 333ed8a42..c8369f630 100644 --- a/stores/metrics.go +++ b/stores/metrics.go @@ -450,11 +450,11 @@ func (s *SQLStore) contractMetrics(ctx context.Context, start time.Time, n uint6 if opts.ContractID == (types.FileContractID{}) && opts.HostKey == (types.PublicKey{}) { // if neither contract nor host filters were set, we return the // aggregate spending for each period - metrics, err = s.findAggregatedContractPeriods(start, n, interval) + metrics, err = s.findAggregatedContractPeriods(ctx, start, n, interval) } else { // otherwise we return the first metric for each period like we usually // do - err = s.findPeriods(dbContractMetric{}.TableName(), &metrics, start, n, interval, whereExpr) + err = s.findPeriods(ctx, dbContractMetric{}.TableName(), &metrics, start, n, interval, whereExpr) } if err != nil { return nil, fmt.Errorf("failed to fetch contract metrics: %w", err) @@ -478,7 +478,7 @@ func (s *SQLStore) contractPruneMetrics(ctx context.Context, start time.Time, n } var metrics []dbContractPruneMetric - err := s.findPeriods(dbContractPruneMetric{}.TableName(), &metrics, start, n, interval, whereExpr) + err := s.findPeriods(ctx, dbContractPruneMetric{}.TableName(), &metrics, start, n, interval, whereExpr) if err != nil { return nil, fmt.Errorf("failed to fetch contract metrics: %w", err) } @@ -498,7 +498,7 @@ func (s *SQLStore) contractSetChurnMetrics(ctx context.Context, start time.Time, whereExpr = gorm.Expr("? 
AND reason = ?", whereExpr, opts.Reason) } var metrics []dbContractSetChurnMetric - err := s.findPeriods(dbContractSetChurnMetric{}.TableName(), &metrics, start, n, interval, whereExpr) + err := s.findPeriods(ctx, dbContractSetChurnMetric{}.TableName(), &metrics, start, n, interval, whereExpr) if err != nil { return nil, fmt.Errorf("failed to fetch contract set churn metrics: %w", err) } @@ -515,7 +515,7 @@ func (s *SQLStore) contractSetMetrics(ctx context.Context, start time.Time, n ui } var metrics []dbContractSetMetric - err := s.findPeriods(dbContractSetMetric{}.TableName(), &metrics, start, n, interval, whereExpr) + err := s.findPeriods(ctx, dbContractSetMetric{}.TableName(), &metrics, start, n, interval, whereExpr) if err != nil { return nil, fmt.Errorf("failed to fetch contract set metrics: %w", err) } @@ -536,7 +536,7 @@ func normaliseTimestamp(start time.Time, interval time.Duration, t unixTimeMS) u return unixTimeMS(time.UnixMilli(normalizedMS)) } -func (s *SQLStore) findAggregatedContractPeriods(start time.Time, n uint64, interval time.Duration) ([]dbContractMetric, error) { +func (s *SQLStore) findAggregatedContractPeriods(ctx context.Context, start time.Time, n uint64, interval time.Duration) ([]dbContractMetric, error) { if n > api.MetricMaxIntervals { return nil, api.ErrMaxIntervalsExceeded } @@ -548,7 +548,7 @@ func (s *SQLStore) findAggregatedContractPeriods(start time.Time, n uint64, inte } var metricsWithPeriod []metricWithPeriod - err := s.dbMetrics.Transaction(func(tx *gorm.DB) error { + err := s.dbMetrics.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var fcids []fileContractID if err := tx.Raw("SELECT DISTINCT fcid FROM contracts WHERE contracts.timestamp >= ? AND contracts.timestamp < ?", unixTimeMS(start), unixTimeMS(end)). 
Scan(&fcids).Error; err != nil { @@ -599,12 +599,12 @@ func (s *SQLStore) findAggregatedContractPeriods(start time.Time, n uint64, inte // split into intervals and the row with the lowest timestamp for each interval // is returned. The result is then joined with the original table to retrieve // only the metrics we want. -func (s *SQLStore) findPeriods(table string, dst interface{}, start time.Time, n uint64, interval time.Duration, whereExpr clause.Expr) error { +func (s *SQLStore) findPeriods(ctx context.Context, table string, dst interface{}, start time.Time, n uint64, interval time.Duration, whereExpr clause.Expr) error { if n > api.MetricMaxIntervals { return api.ErrMaxIntervalsExceeded } end := start.Add(time.Duration(n) * interval) - return s.dbMetrics.Raw(fmt.Sprintf(` + return s.dbMetrics.WithContext(ctx).Raw(fmt.Sprintf(` WITH RECURSIVE periods AS ( SELECT ? AS period_start UNION ALL @@ -637,7 +637,7 @@ func (s *SQLStore) findPeriods(table string, dst interface{}, start time.Time, n } func (s *SQLStore) walletMetrics(ctx context.Context, start time.Time, n uint64, interval time.Duration, opts api.WalletMetricsQueryOpts) (metrics []dbWalletMetric, err error) { - err = s.findPeriods(dbWalletMetric{}.TableName(), &metrics, start, n, interval, gorm.Expr("TRUE")) + err = s.findPeriods(ctx, dbWalletMetric{}.TableName(), &metrics, start, n, interval, gorm.Expr("TRUE")) if err != nil { return nil, fmt.Errorf("failed to fetch wallet metrics: %w", err) } @@ -660,7 +660,7 @@ func (s *SQLStore) performanceMetrics(ctx context.Context, start time.Time, n ui } var metrics []dbPerformanceMetric - err := s.findPeriods(dbPerformanceMetric{}.TableName(), &metrics, start, n, interval, whereExpr) + err := s.findPeriods(ctx, dbPerformanceMetric{}.TableName(), &metrics, start, n, interval, whereExpr) if err != nil { return nil, fmt.Errorf("failed to fetch performance metrics: %w", err) } diff --git a/stores/migrations.go b/stores/migrations.go index cb0a38b18..6ccc75964 100644 
--- a/stores/migrations.go +++ b/stores/migrations.go @@ -3,9 +3,9 @@ package stores import ( "errors" "fmt" - "strings" "github.com/go-gormigrate/gormigrate/v2" + "go.sia.tech/renterd/internal/utils" "go.uber.org/zap" "gorm.io/gorm" ) @@ -32,7 +32,7 @@ func performMigrations(db *gorm.DB, logger *zap.SugaredLogger) error { ID: "00002_prune_slabs_trigger", Migrate: func(tx *gorm.DB) error { err := performMigration(tx, dbIdentifier, "00002_prune_slabs_trigger", logger) - if err != nil && strings.Contains(err.Error(), errMySQLNoSuperPrivilege.Error()) { + if utils.IsErr(err, errMySQLNoSuperPrivilege) { logger.Warn("migration 00002_prune_slabs_trigger requires the user to have the SUPER privilege to register triggers") } return err @@ -56,13 +56,19 @@ func performMigrations(db *gorm.DB, logger *zap.SugaredLogger) error { return performMigration(tx, dbIdentifier, "00005_zero_size_object_health", logger) }, }, + { + ID: "00006_idx_objects_created_at", + Migrate: func(tx *gorm.DB) error { + return performMigration(tx, dbIdentifier, "00006_idx_objects_created_at", logger) + }, + }, } // Create migrator. m := gormigrate.New(db, gormigrate.DefaultOptions, migrations) // Set init function. - m.InitSchema(initSchema(db, dbIdentifier, logger)) + m.InitSchema(initSchema(dbIdentifier, logger)) // Perform migrations. 
if err := m.Migrate(); err != nil { diff --git a/stores/migrations/mysql/main/migration_00006_idx_objects_created_at.sql b/stores/migrations/mysql/main/migration_00006_idx_objects_created_at.sql new file mode 100644 index 000000000..310c9a1c3 --- /dev/null +++ b/stores/migrations/mysql/main/migration_00006_idx_objects_created_at.sql @@ -0,0 +1 @@ +CREATE INDEX `idx_objects_created_at` ON `objects`(`created_at`); diff --git a/stores/migrations/mysql/main/schema.sql b/stores/migrations/mysql/main/schema.sql index a5ed86807..68b42ae47 100644 --- a/stores/migrations/mysql/main/schema.sql +++ b/stores/migrations/mysql/main/schema.sql @@ -331,6 +331,7 @@ CREATE TABLE `objects` ( KEY `idx_objects_health` (`health`), KEY `idx_objects_etag` (`etag`), KEY `idx_objects_size` (`size`), + KEY `idx_objects_created_at` (`created_at`), CONSTRAINT `fk_objects_db_bucket` FOREIGN KEY (`db_bucket_id`) REFERENCES `buckets` (`id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci; diff --git a/stores/migrations/sqlite/main/migration_00006_idx_objects_created_at.sql b/stores/migrations/sqlite/main/migration_00006_idx_objects_created_at.sql new file mode 100644 index 000000000..310c9a1c3 --- /dev/null +++ b/stores/migrations/sqlite/main/migration_00006_idx_objects_created_at.sql @@ -0,0 +1 @@ +CREATE INDEX `idx_objects_created_at` ON `objects`(`created_at`); diff --git a/stores/migrations/sqlite/main/schema.sql b/stores/migrations/sqlite/main/schema.sql index 8d7afeaa1..9875e81e3 100644 --- a/stores/migrations/sqlite/main/schema.sql +++ b/stores/migrations/sqlite/main/schema.sql @@ -52,6 +52,7 @@ CREATE INDEX `idx_objects_health` ON `objects`(`health`); CREATE INDEX `idx_objects_object_id` ON `objects`(`object_id`); CREATE INDEX `idx_objects_size` ON `objects`(`size`); CREATE UNIQUE INDEX `idx_object_bucket` ON `objects`(`db_bucket_id`,`object_id`); +CREATE INDEX `idx_objects_created_at` ON `objects`(`created_at`); -- dbMultipartUpload CREATE TABLE `multipart_uploads` 
(`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`key` blob,`upload_id` text NOT NULL,`object_id` text NOT NULL,`db_bucket_id` integer NOT NULL,`mime_type` text,CONSTRAINT `fk_multipart_uploads_db_bucket` FOREIGN KEY (`db_bucket_id`) REFERENCES `buckets`(`id`) ON DELETE CASCADE); diff --git a/stores/migrations_metrics.go b/stores/migrations_metrics.go index fc3164bee..25895c4f2 100644 --- a/stores/migrations_metrics.go +++ b/stores/migrations_metrics.go @@ -27,7 +27,7 @@ func performMetricsMigrations(db *gorm.DB, logger *zap.SugaredLogger) error { m := gormigrate.New(db, gormigrate.DefaultOptions, migrations) // Set init function. - m.InitSchema(initSchema(db, dbIdentifier, logger)) + m.InitSchema(initSchema(dbIdentifier, logger)) // Perform migrations. if err := m.Migrate(); err != nil { diff --git a/stores/migrations_utils.go b/stores/migrations_utils.go index 46d7f3dc4..0692b367f 100644 --- a/stores/migrations_utils.go +++ b/stores/migrations_utils.go @@ -10,7 +10,7 @@ import ( // initSchema is executed only on a clean database. Otherwise the individual // migrations are executed. -func initSchema(db *gorm.DB, name string, logger *zap.SugaredLogger) gormigrate.InitSchemaFunc { +func initSchema(name string, logger *zap.SugaredLogger) gormigrate.InitSchemaFunc { return func(tx *gorm.DB) error { logger.Infof("initializing '%s' schema", name) diff --git a/stores/multipart.go b/stores/multipart.go index 864503455..3da5f7992 100644 --- a/stores/multipart.go +++ b/stores/multipart.go @@ -56,7 +56,7 @@ func (s *SQLStore) CreateMultipartUpload(ctx context.Context, bucket, path strin return api.MultipartCreateResponse{}, err } var uploadID string - err = s.retryTransaction(func(tx *gorm.DB) error { + err = s.retryTransaction(ctx, func(tx *gorm.DB) error { // Get bucket id. var bucketID uint err := tx.Table("(SELECT id from buckets WHERE buckets.name = ?) bucket_id", bucket). 
@@ -108,7 +108,7 @@ func (s *SQLStore) AddMultipartPart(ctx context.Context, bucket, path, contractS } } } - return s.retryTransaction(func(tx *gorm.DB) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { // Fetch contract set. var cs dbContractSet if err := tx.Take(&cs, "name = ?", contractSet).Error; err != nil { @@ -160,7 +160,7 @@ func (s *SQLStore) AddMultipartPart(ctx context.Context, bucket, path, contractS } func (s *SQLStore) MultipartUpload(ctx context.Context, uploadID string) (resp api.MultipartUpload, err error) { - err = s.retryTransaction(func(tx *gorm.DB) error { + err = s.retryTransaction(ctx, func(tx *gorm.DB) error { var dbUpload dbMultipartUpload err := tx. Model(&dbMultipartUpload{}). @@ -201,7 +201,7 @@ func (s *SQLStore) MultipartUploads(ctx context.Context, bucket, prefix, keyMark prefixExpr = gorm.Expr("SUBSTR(object_id, 1, ?) = ?", utf8.RuneCountInString(prefix), prefix) } - err = s.retryTransaction(func(tx *gorm.DB) error { + err = s.retryTransaction(ctx, func(tx *gorm.DB) error { var dbUploads []dbMultipartUpload err := tx. Model(&dbMultipartUpload{}). @@ -243,7 +243,7 @@ func (s *SQLStore) MultipartUploadParts(ctx context.Context, bucket, object stri limit++ } - err := s.retryTransaction(func(tx *gorm.DB) error { + err := s.retryTransaction(ctx, func(tx *gorm.DB) error { var dbParts []dbMultipartPart err := tx. Model(&dbMultipartPart{}). @@ -277,7 +277,7 @@ func (s *SQLStore) MultipartUploadParts(ctx context.Context, bucket, object stri } func (s *SQLStore) AbortMultipartUpload(ctx context.Context, bucket, path string, uploadID string) error { - return s.retryTransaction(func(tx *gorm.DB) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { // delete multipart upload optimistically res := tx. Where("upload_id", uploadID). 
@@ -313,7 +313,7 @@ func (s *SQLStore) AbortMultipartUpload(ctx context.Context, bucket, path string }) } -func (s *SQLStore) CompleteMultipartUpload(ctx context.Context, bucket, path string, uploadID string, parts []api.MultipartCompletedPart) (_ api.MultipartCompleteResponse, err error) { +func (s *SQLStore) CompleteMultipartUpload(ctx context.Context, bucket, path string, uploadID string, parts []api.MultipartCompletedPart, opts api.CompleteMultipartOptions) (_ api.MultipartCompleteResponse, err error) { // Sanity check input parts. if !sort.SliceIsSorted(parts, func(i, j int) bool { return parts[i].PartNumber < parts[j].PartNumber @@ -326,7 +326,13 @@ func (s *SQLStore) CompleteMultipartUpload(ctx context.Context, bucket, path str } } var eTag string - err = s.retryTransaction(func(tx *gorm.DB) error { + err = s.retryTransaction(ctx, func(tx *gorm.DB) error { + // Delete potentially existing object. + _, err := s.deleteObject(tx, bucket, path) + if err != nil { + return fmt.Errorf("failed to delete object: %w", err) + } + // Find multipart upload. var mu dbMultipartUpload err = tx.Where("upload_id = ?", uploadID). @@ -347,12 +353,6 @@ func (s *SQLStore) CompleteMultipartUpload(ctx context.Context, bucket, path str return fmt.Errorf("bucket name mismatch: %v != %v: %w", mu.DBBucket.Name, bucket, api.ErrBucketNotFound) } - // Delete potentially existing object. - _, err := s.deleteObject(tx, bucket, path) - if err != nil { - return fmt.Errorf("failed to delete object: %w", err) - } - // Sort the parts. sort.Slice(mu.Parts, func(i, j int) bool { return mu.Parts[i].PartNumber < mu.Parts[j].PartNumber @@ -434,6 +434,14 @@ func (s *SQLStore) CompleteMultipartUpload(ctx context.Context, bucket, path str } } + // Create new metadata. + if len(opts.Metadata) > 0 { + err = s.createUserMetadata(tx, obj.ID, opts.Metadata) + if err != nil { + return fmt.Errorf("failed to create metadata: %w", err) + } + } + // Update user metadata. if err := tx. 
Model(&dbObjectUserMetadata{}). diff --git a/stores/multipart_test.go b/stores/multipart_test.go index 37b294418..50272fcda 100644 --- a/stores/multipart_test.go +++ b/stores/multipart_test.go @@ -91,7 +91,7 @@ func TestMultipartUploadWithUploadPackingRegression(t *testing.T) { t.Fatal(err) } else if nSlicesBefore == 0 { t.Fatal("expected some slices") - } else if _, err = ss.CompleteMultipartUpload(ctx, api.DefaultBucketName, objName, resp.UploadID, parts); err != nil { + } else if _, err = ss.CompleteMultipartUpload(ctx, api.DefaultBucketName, objName, resp.UploadID, parts, api.CompleteMultipartOptions{}); err != nil { t.Fatal(err) } else if err := ss.db.Model(&dbSlice{}).Count(&nSlicesAfter).Error; err != nil { t.Fatal(err) diff --git a/stores/slabbuffer.go b/stores/slabbuffer.go index e1c7290ea..2d16c8e33 100644 --- a/stores/slabbuffer.go +++ b/stores/slabbuffer.go @@ -204,7 +204,7 @@ func (mgr *SlabBufferManager) AddPartialSlab(ctx context.Context, data []byte, m // If there is still data left, create a new buffer. 
if len(data) > 0 { var sb *SlabBuffer - err = mgr.s.retryTransaction(func(tx *gorm.DB) error { + err = mgr.s.retryTransaction(ctx, func(tx *gorm.DB) error { sb, err = createSlabBuffer(tx, contractSet, mgr.dir, minShards, totalShards) return err }) diff --git a/stores/sql.go b/stores/sql.go index 5d9d9cea8..f62dba97f 100644 --- a/stores/sql.go +++ b/stores/sql.go @@ -2,7 +2,6 @@ package stores import ( "context" - "database/sql" "embed" "errors" "fmt" @@ -446,7 +445,7 @@ func (ss *SQLStore) applyUpdates(force bool) error { ss.logger.Error(fmt.Sprintf("failed to fetch blocklist, err: %v", err)) } - err := ss.retryTransaction(func(tx *gorm.DB) (err error) { + err := ss.retryTransaction(context.Background(), func(tx *gorm.DB) (err error) { if len(ss.unappliedAnnouncements) > 0 { if err = insertAnnouncements(tx, ss.unappliedAnnouncements); err != nil { return fmt.Errorf("%w; failed to insert %d announcements", err, len(ss.unappliedAnnouncements)) @@ -514,9 +513,10 @@ func (ss *SQLStore) applyUpdates(force bool) error { return nil } -func (s *SQLStore) retryTransaction(fc func(tx *gorm.DB) error, opts ...*sql.TxOptions) error { +func (s *SQLStore) retryTransaction(ctx context.Context, fc func(tx *gorm.DB) error) error { abortRetry := func(err error) bool { if err == nil || + errors.Is(err, context.Canceled) || errors.Is(err, gorm.ErrRecordNotFound) || errors.Is(err, errInvalidNumberOfShards) || errors.Is(err, errShardRootChanged) || @@ -539,7 +539,7 @@ func (s *SQLStore) retryTransaction(fc func(tx *gorm.DB) error, opts ...*sql.TxO } var err error for i := 0; i < len(s.retryTransactionIntervals); i++ { - err = s.db.Transaction(fc, opts...) 
+ err = s.db.WithContext(ctx).Transaction(fc) if abortRetry(err) { return err } @@ -566,10 +566,10 @@ func initConsensusInfo(db *gorm.DB) (dbConsensusInfo, modules.ConsensusChangeID, return ci, ccid, nil } -func (s *SQLStore) ResetConsensusSubscription() error { +func (s *SQLStore) ResetConsensusSubscription(ctx context.Context) error { // empty tables and reinit consensus_infos var ci dbConsensusInfo - err := s.retryTransaction(func(tx *gorm.DB) error { + err := s.retryTransaction(ctx, func(tx *gorm.DB) error { if err := s.db.Exec("DELETE FROM consensus_infos").Error; err != nil { return err } else if err := s.db.Exec("DELETE FROM siacoin_elements").Error; err != nil { diff --git a/stores/sql_test.go b/stores/sql_test.go index 776e3e10e..842f3c9df 100644 --- a/stores/sql_test.go +++ b/stores/sql_test.go @@ -107,8 +107,8 @@ func newTestSQLStore(t *testing.T, cfg testSQLStoreConfig) *testSQLStore { conn = NewMySQLConnection(dbUser, dbPassword, dbURI, dbName) connMetrics = NewMySQLConnection(dbUser, dbPassword, dbURI, dbMetricsName) } else if cfg.persistent { - conn = NewSQLiteConnection(filepath.Join(cfg.dir, "db.sqlite")) - connMetrics = NewSQLiteConnection(filepath.Join(cfg.dir, "metrics.sqlite")) + conn = NewSQLiteConnection(filepath.Join(dir, "db.sqlite")) + connMetrics = NewSQLiteConnection(filepath.Join(dir, "metrics.sqlite")) } else { conn = NewEphemeralSQLiteConnection(dbName) connMetrics = NewEphemeralSQLiteConnection(dbMetricsName) @@ -292,7 +292,7 @@ func TestConsensusReset(t *testing.T) { }) // Reset the consensus. 
- if err := ss.ResetConsensusSubscription(); err != nil { + if err := ss.ResetConsensusSubscription(context.Background()); err != nil { t.Fatal(err) } diff --git a/stores/webhooks.go b/stores/webhooks.go index f3fc26057..4db325698 100644 --- a/stores/webhooks.go +++ b/stores/webhooks.go @@ -1,6 +1,8 @@ package stores import ( + "context" + "go.sia.tech/renterd/webhooks" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -20,8 +22,8 @@ func (dbWebhook) TableName() string { return "webhooks" } -func (s *SQLStore) DeleteWebhook(wb webhooks.Webhook) error { - return s.retryTransaction(func(tx *gorm.DB) error { +func (s *SQLStore) DeleteWebhook(ctx context.Context, wb webhooks.Webhook) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { res := tx.Exec("DELETE FROM webhooks WHERE module = ? AND event = ? AND url = ?", wb.Module, wb.Event, wb.URL) if res.Error != nil { @@ -33,8 +35,8 @@ func (s *SQLStore) DeleteWebhook(wb webhooks.Webhook) error { }) } -func (s *SQLStore) AddWebhook(wb webhooks.Webhook) error { - return s.retryTransaction(func(tx *gorm.DB) error { +func (s *SQLStore) AddWebhook(ctx context.Context, wb webhooks.Webhook) error { + return s.retryTransaction(ctx, func(tx *gorm.DB) error { return tx.Clauses(clause.OnConflict{ DoNothing: true, }).Create(&dbWebhook{ @@ -45,9 +47,9 @@ func (s *SQLStore) AddWebhook(wb webhooks.Webhook) error { }) } -func (s *SQLStore) Webhooks() ([]webhooks.Webhook, error) { +func (s *SQLStore) Webhooks(ctx context.Context) ([]webhooks.Webhook, error) { var dbWebhooks []dbWebhook - if err := s.db.Find(&dbWebhooks).Error; err != nil { + if err := s.db.WithContext(ctx).Find(&dbWebhooks).Error; err != nil { return nil, err } var whs []webhooks.Webhook diff --git a/stores/webhooks_test.go b/stores/webhooks_test.go index ad1973125..b306eef2c 100644 --- a/stores/webhooks_test.go +++ b/stores/webhooks_test.go @@ -1,6 +1,7 @@ package stores import ( + "context" "testing" "github.com/google/go-cmp/cmp" @@ -23,10 +24,10 @@ func 
TestWebhooks(t *testing.T) { } // Add hook. - if err := ss.AddWebhook(wh1); err != nil { + if err := ss.AddWebhook(context.Background(), wh1); err != nil { t.Fatal(err) } - whs, err := ss.Webhooks() + whs, err := ss.Webhooks(context.Background()) if err != nil { t.Fatal(err) } else if len(whs) != 1 { @@ -36,10 +37,10 @@ func TestWebhooks(t *testing.T) { } // Add it again. Should be a no-op. - if err := ss.AddWebhook(wh1); err != nil { + if err := ss.AddWebhook(context.Background(), wh1); err != nil { t.Fatal(err) } - whs, err = ss.Webhooks() + whs, err = ss.Webhooks(context.Background()) if err != nil { t.Fatal(err) } else if len(whs) != 1 { @@ -49,10 +50,10 @@ func TestWebhooks(t *testing.T) { } // Add another. - if err := ss.AddWebhook(wh2); err != nil { + if err := ss.AddWebhook(context.Background(), wh2); err != nil { t.Fatal(err) } - whs, err = ss.Webhooks() + whs, err = ss.Webhooks(context.Background()) if err != nil { t.Fatal(err) } else if len(whs) != 2 { @@ -64,10 +65,10 @@ func TestWebhooks(t *testing.T) { } // Remove one. 
- if err := ss.DeleteWebhook(wh1); err != nil { + if err := ss.DeleteWebhook(context.Background(), wh1); err != nil { t.Fatal(err) } - whs, err = ss.Webhooks() + whs, err = ss.Webhooks(context.Background()) if err != nil { t.Fatal(err) } else if len(whs) != 1 { diff --git a/webhooks/webhooks.go b/webhooks/webhooks.go index e3d388de8..665f8a2c4 100644 --- a/webhooks/webhooks.go +++ b/webhooks/webhooks.go @@ -19,9 +19,9 @@ var ErrWebhookNotFound = errors.New("Webhook not found") type ( WebhookStore interface { - DeleteWebhook(wh Webhook) error - AddWebhook(wh Webhook) error - Webhooks() ([]Webhook, error) + DeleteWebhook(ctx context.Context, wh Webhook) error + AddWebhook(ctx context.Context, wh Webhook) error + Webhooks(ctx context.Context) ([]Webhook, error) } Broadcaster interface { @@ -122,10 +122,10 @@ func (m *Manager) Close() error { return nil } -func (m *Manager) Delete(wh Webhook) error { +func (m *Manager) Delete(ctx context.Context, wh Webhook) error { m.mu.Lock() defer m.mu.Unlock() - if err := m.store.DeleteWebhook(wh); errors.Is(err, gorm.ErrRecordNotFound) { + if err := m.store.DeleteWebhook(ctx, wh); errors.Is(err, gorm.ErrRecordNotFound) { return ErrWebhookNotFound } else if err != nil { return err @@ -157,7 +157,7 @@ func (m *Manager) Info() ([]Webhook, []WebhookQueueInfo) { return hooks, queueInfos } -func (m *Manager) Register(wh Webhook) error { +func (m *Manager) Register(ctx context.Context, wh Webhook) error { ctx, cancel := context.WithTimeout(m.shutdownCtx, webhookTimeout) defer cancel() @@ -170,7 +170,7 @@ func (m *Manager) Register(wh Webhook) error { } // Add Webhook. 
- if err := m.store.AddWebhook(wh); err != nil { + if err := m.store.AddWebhook(ctx, wh); err != nil { return err } m.mu.Lock() @@ -214,11 +214,6 @@ func (w Webhook) String() string { } func NewManager(logger *zap.SugaredLogger, store WebhookStore) (*Manager, error) { - hooks, err := store.Webhooks() - if err != nil { - return nil, err - } - shutdownCtx, shutdownCtxCancel := context.WithCancel(context.Background()) m := &Manager{ logger: logger.Named("webhooks"), @@ -230,7 +225,10 @@ func NewManager(logger *zap.SugaredLogger, store WebhookStore) (*Manager, error) queues: make(map[string]*eventQueue), webhooks: make(map[string]Webhook), } - + hooks, err := store.Webhooks(shutdownCtx) + if err != nil { + return nil, err + } for _, hook := range hooks { m.webhooks[hook.String()] = hook } diff --git a/worker/client/client.go b/worker/client/client.go index 410e4c66e..6ef70f338 100644 --- a/worker/client/client.go +++ b/worker/client/client.go @@ -81,12 +81,10 @@ func (c *Client) DownloadStats() (resp api.DownloadStatsResponse, err error) { func (c *Client) HeadObject(ctx context.Context, bucket, path string, opts api.HeadObjectOptions) (*api.HeadObjectResponse, error) { c.c.Custom("HEAD", fmt.Sprintf("/objects/%s", path), nil, nil) - if strings.HasSuffix(path, "/") { - return nil, errors.New("the given path is a directory, HEAD can only be performed on objects") - } - values := url.Values{} values.Set("bucket", url.QueryEscape(bucket)) + opts.Apply(values) + path = api.ObjectPathEscape(path) path += "?" 
+ values.Encode() // TODO: support HEAD in jape client @@ -325,6 +323,7 @@ func parseObjectResponseHeaders(header http.Header) (api.HeadObjectResponse, err return api.HeadObjectResponse{ ContentType: header.Get("Content-Type"), + Etag: trimEtag(header.Get("ETag")), LastModified: header.Get("Last-Modified"), Range: r, Size: size, @@ -347,3 +346,8 @@ func sizeFromSeeker(r io.Reader) (int64, error) { } return size, nil } + +func trimEtag(etag string) string { + etag = strings.TrimPrefix(etag, "\"") + return strings.TrimSuffix(etag, "\"") +} diff --git a/worker/download.go b/worker/download.go index 3a58bbc98..83d4bec3e 100644 --- a/worker/download.go +++ b/worker/download.go @@ -7,13 +7,13 @@ import ( "fmt" "io" "math" - "strings" "sync" "time" rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" "go.sia.tech/renterd/stats" "go.uber.org/zap" @@ -454,7 +454,7 @@ func (mgr *downloadManager) numDownloaders() int { // in the partial slab buffer. func (mgr *downloadManager) fetchPartialSlab(ctx context.Context, key object.EncryptionKey, offset, length uint32) ([]byte, *object.Slab, error) { data, err := mgr.os.FetchPartialSlab(ctx, key, offset, length) - if err != nil && strings.Contains(err.Error(), api.ErrObjectNotFound.Error()) { + if utils.IsErr(err, api.ErrObjectNotFound) { // Check if slab was already uploaded. 
slab, err := mgr.os.Slab(ctx, key) if err != nil { @@ -495,11 +495,11 @@ func (mgr *downloadManager) refreshDownloaders(contracts []api.ContractMetadata) host := mgr.hm.Host(c.HostKey, c.ID, c.SiamuxAddr) downloader := newDownloader(mgr.shutdownCtx, host) mgr.downloaders[c.HostKey] = downloader - go downloader.processQueue(mgr.hm) + go downloader.processQueue() } } -func (mgr *downloadManager) newSlabDownload(ctx context.Context, slice object.SlabSlice, migration bool) *slabDownload { +func (mgr *downloadManager) newSlabDownload(slice object.SlabSlice, migration bool) *slabDownload { // calculate the offset and length offset, length := slice.SectorRegion() @@ -529,7 +529,7 @@ func (mgr *downloadManager) newSlabDownload(ctx context.Context, slice object.Sl func (mgr *downloadManager) downloadSlab(ctx context.Context, slice object.SlabSlice, migration bool) ([][]byte, bool, error) { // prepare new download - slab := mgr.newSlabDownload(ctx, slice, migration) + slab := mgr.newSlabDownload(slice, migration) // execute download return slab.download(ctx) diff --git a/worker/downloader.go b/worker/downloader.go index 24be245fc..46dac61e3 100644 --- a/worker/downloader.go +++ b/worker/downloader.go @@ -245,7 +245,7 @@ func (d *downloader) processBatch(batch []*sectorDownloadReq) chan struct{} { return doneChan } -func (d *downloader) processQueue(hp HostManager) { +func (d *downloader) processQueue() { outer: for { // wait for work diff --git a/worker/rhpv2.go b/worker/rhpv2.go index 02cdce4ff..9f05904a4 100644 --- a/worker/rhpv2.go +++ b/worker/rhpv2.go @@ -277,7 +277,7 @@ func (w *worker) FetchSignedRevision(ctx context.Context, hostIP string, hostKey func (w *worker) PruneContract(ctx context.Context, hostIP string, hostKey types.PublicKey, fcid types.FileContractID, lastKnownRevisionNumber uint64) (deleted, remaining uint64, err error) { err = w.withContractLock(ctx, fcid, lockingPriorityPruning, func() error { return w.withTransportV2(ctx, hostKey, hostIP, func(t 
*rhpv2.Transport) error { - return w.withRevisionV2(ctx, defaultLockTimeout, t, hostKey, fcid, lastKnownRevisionNumber, func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) (err error) { + return w.withRevisionV2(defaultLockTimeout, t, hostKey, fcid, lastKnownRevisionNumber, func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) (err error) { // perform gouging checks gc, err := GougingCheckerFromContext(ctx, false) if err != nil { @@ -510,7 +510,7 @@ func (w *worker) deleteContractRoots(t *rhpv2.Transport, rev *rhpv2.ContractRevi func (w *worker) FetchContractRoots(ctx context.Context, hostIP string, hostKey types.PublicKey, fcid types.FileContractID, lastKnownRevisionNumber uint64) (roots []types.Hash256, err error) { err = w.withTransportV2(ctx, hostKey, hostIP, func(t *rhpv2.Transport) error { - return w.withRevisionV2(ctx, defaultLockTimeout, t, hostKey, fcid, lastKnownRevisionNumber, func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) (err error) { + return w.withRevisionV2(defaultLockTimeout, t, hostKey, fcid, lastKnownRevisionNumber, func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) (err error) { gc, err := GougingCheckerFromContext(ctx, false) if err != nil { return err @@ -641,7 +641,7 @@ func (w *worker) withTransportV2(ctx context.Context, hostKey types.PublicKey, h return fn(t) } -func (w *worker) withRevisionV2(ctx context.Context, lockTimeout time.Duration, t *rhpv2.Transport, hk types.PublicKey, fcid types.FileContractID, lastKnownRevisionNumber uint64, fn func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) error) error { +func (w *worker) withRevisionV2(lockTimeout time.Duration, t *rhpv2.Transport, hk types.PublicKey, fcid types.FileContractID, lastKnownRevisionNumber uint64, fn func(t *rhpv2.Transport, rev rhpv2.ContractRevision, settings rhpv2.HostSettings) error) error { renterKey := 
w.deriveRenterKey(hk) // execute lock RPC diff --git a/worker/rhpv3.go b/worker/rhpv3.go index 9c280f2bd..8db6dc9d5 100644 --- a/worker/rhpv3.go +++ b/worker/rhpv3.go @@ -10,7 +10,6 @@ import ( "math" "math/big" "net" - "strings" "sync" "time" @@ -20,6 +19,7 @@ import ( "go.sia.tech/mux/v1" "go.sia.tech/renterd/api" "go.sia.tech/renterd/hostdb" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/siad/crypto" "go.uber.org/zap" ) @@ -47,6 +47,12 @@ const ( ) var ( + // errHost is used to wrap rpc errors returned by the host. + errHost = errors.New("host responded with error") + + // errTransport is used to wrap rpc errors caused by the transport. + errTransport = errors.New("transport error") + // errBalanceInsufficient occurs when a withdrawal failed because the // account balance was insufficient. errBalanceInsufficient = errors.New("ephemeral account balance was insufficient") @@ -83,31 +89,42 @@ var ( errWithdrawalExpired = errors.New("withdrawal request expired") ) -func isBalanceInsufficient(err error) bool { return isError(err, errBalanceInsufficient) } -func isBalanceMaxExceeded(err error) bool { return isError(err, errBalanceMaxExceeded) } +// IsErrHost indicates whether an error was returned by a host as part of an RPC. 
+func IsErrHost(err error) bool { + return utils.IsErr(err, errHost) +} + +func isBalanceInsufficient(err error) bool { return utils.IsErr(err, errBalanceInsufficient) } +func isBalanceMaxExceeded(err error) bool { return utils.IsErr(err, errBalanceMaxExceeded) } func isClosedStream(err error) bool { - return isError(err, mux.ErrClosedStream) || isError(err, net.ErrClosed) + return utils.IsErr(err, mux.ErrClosedStream) || utils.IsErr(err, net.ErrClosed) } -func isInsufficientFunds(err error) bool { return isError(err, ErrInsufficientFunds) } -func isPriceTableExpired(err error) bool { return isError(err, errPriceTableExpired) } -func isPriceTableGouging(err error) bool { return isError(err, errPriceTableGouging) } -func isPriceTableNotFound(err error) bool { return isError(err, errPriceTableNotFound) } +func isInsufficientFunds(err error) bool { return utils.IsErr(err, ErrInsufficientFunds) } +func isPriceTableExpired(err error) bool { return utils.IsErr(err, errPriceTableExpired) } +func isPriceTableGouging(err error) bool { return utils.IsErr(err, errPriceTableGouging) } +func isPriceTableNotFound(err error) bool { return utils.IsErr(err, errPriceTableNotFound) } func isSectorNotFound(err error) bool { - return isError(err, errSectorNotFound) || isError(err, errSectorNotFoundOld) + return utils.IsErr(err, errSectorNotFound) || utils.IsErr(err, errSectorNotFoundOld) } -func isWithdrawalsInactive(err error) bool { return isError(err, errWithdrawalsInactive) } -func isWithdrawalExpired(err error) bool { return isError(err, errWithdrawalExpired) } +func isWithdrawalsInactive(err error) bool { return utils.IsErr(err, errWithdrawalsInactive) } +func isWithdrawalExpired(err error) bool { return utils.IsErr(err, errWithdrawalExpired) } -func isError(err error, target error) bool { - if err == nil { - return err == target +// wrapRPCErr extracts the innermost error, wraps it in either a errHost or +// errTransport and finally wraps it using the provided fnName. 
+func wrapRPCErr(err *error, fnName string) { + if *err == nil { + return + } + innerErr := *err + for errors.Unwrap(innerErr) != nil { + innerErr = errors.Unwrap(innerErr) } - // compare error first - if errors.Is(err, target) { - return true + if errors.As(*err, new(*rhpv3.RPCError)) { + *err = fmt.Errorf("%w: '%w'", errHost, innerErr) + } else { + *err = fmt.Errorf("%w: '%w'", errTransport, innerErr) } - // then compare the string in case the error was returned by a host - return strings.Contains(strings.ToLower(err.Error()), strings.ToLower(target.Error())) + *err = fmt.Errorf("%s: %w", fnName, *err) } // transportV3 is a reference-counted wrapper for rhpv3.Transport. @@ -125,6 +142,26 @@ type streamV3 struct { *rhpv3.Stream } +func (s *streamV3) ReadResponse(resp rhpv3.ProtocolObject, maxLen uint64) (err error) { + defer wrapRPCErr(&err, "ReadResponse") + return s.Stream.ReadResponse(resp, maxLen) +} + +func (s *streamV3) WriteResponse(resp rhpv3.ProtocolObject) (err error) { + defer wrapRPCErr(&err, "WriteResponse") + return s.Stream.WriteResponse(resp) +} + +func (s *streamV3) ReadRequest(req rhpv3.ProtocolObject, maxLen uint64) (err error) { + defer wrapRPCErr(&err, "ReadRequest") + return s.Stream.ReadRequest(req, maxLen) +} + +func (s *streamV3) WriteRequest(rpcID types.Specifier, req rhpv3.ProtocolObject) (err error) { + defer wrapRPCErr(&err, "WriteRequest") + return s.Stream.WriteRequest(rpcID, req) +} + // Close closes the stream and cancels the goroutine launched by DialStream. 
func (s *streamV3) Close() error { s.cancel() @@ -177,7 +214,7 @@ type transportPoolV3 struct { pool map[string]*transportV3 } -func newTransportPoolV3(w *worker) *transportPoolV3 { +func newTransportPoolV3() *transportPoolV3 { return &transportPoolV3{ pool: make(map[string]*transportV3), } @@ -365,7 +402,7 @@ func (w *worker) initTransportPool() { if w.transportPoolV3 != nil { panic("transport pool already initialized") // developer error } - w.transportPoolV3 = newTransportPoolV3(w) + w.transportPoolV3 = newTransportPoolV3() } // ForHost returns an account to use for a given host. If the account diff --git a/worker/rhpv3_test.go b/worker/rhpv3_test.go new file mode 100644 index 000000000..83f605807 --- /dev/null +++ b/worker/rhpv3_test.go @@ -0,0 +1,34 @@ +package worker + +import ( + "errors" + "fmt" + "testing" + + rhpv3 "go.sia.tech/core/rhp/v3" +) + +func TestWrapRPCErr(t *testing.T) { + // host error + err := fmt.Errorf("ReadResponse: %w", &rhpv3.RPCError{ + Description: "some host error", + }) + if err.Error() != "ReadResponse: some host error" { + t.Fatal("unexpected error:", err) + } + wrapRPCErr(&err, "ReadResponse") + if err.Error() != "ReadResponse: host responded with error: 'some host error'" { + t.Fatal("unexpected error:", err) + } else if !errors.Is(err, errHost) { + t.Fatalf("expected error to be wrapped with %v, got %v", errHost, err) + } + + // transport error + err = fmt.Errorf("ReadResponse: %w", errors.New("some transport error")) + wrapRPCErr(&err, "ReadResponse") + if err.Error() != "ReadResponse: transport error: 'some transport error'" { + t.Fatal("unexpected error:", err) + } else if !errors.Is(err, errTransport) { + t.Fatalf("expected error to be wrapped with %v, got %v", errTransport, err) + } +} diff --git a/worker/upload.go b/worker/upload.go index c5e86a166..d146b920e 100644 --- a/worker/upload.go +++ b/worker/upload.go @@ -2,6 +2,8 @@ package worker import ( "context" + "crypto/md5" + "encoding/hex" "errors" "fmt" "io" @@ -390,6
+392,11 @@ func (mgr *uploadManager) Upload(ctx context.Context, r io.Reader, contracts []a // create the object o := object.NewObject(up.ec) + // create the md5 hasher for the etag + // NOTE: we use md5 since it's s3 compatible and clients expect it to be md5 + hasher := md5.New() + r = io.TeeReader(r, hasher) + // create the cipher reader cr, err := o.Encrypt(r, up.encryptionOffset) if err != nil { @@ -397,7 +404,7 @@ func (mgr *uploadManager) Upload(ctx context.Context, r io.Reader, contracts []a } // create the upload - upload, err := mgr.newUpload(ctx, up.rs.TotalShards, contracts, up.bh, lockPriority) + upload, err := mgr.newUpload(up.rs.TotalShards, contracts, up.bh, lockPriority) if err != nil { return false, "", err } @@ -520,7 +527,7 @@ func (mgr *uploadManager) Upload(ctx context.Context, r io.Reader, contracts []a } // compute etag - eTag = o.ComputeETag() + eTag = hex.EncodeToString(hasher.Sum(nil)) // add partial slabs if len(partialSlab) > 0 { @@ -558,7 +565,7 @@ func (mgr *uploadManager) UploadPackedSlab(ctx context.Context, rs api.Redundanc shards := encryptPartialSlab(ps.Data, ps.Key, uint8(rs.MinShards), uint8(rs.TotalShards)) // create the upload - upload, err := mgr.newUpload(ctx, len(shards), contracts, bh, lockPriority) + upload, err := mgr.newUpload(len(shards), contracts, bh, lockPriority) if err != nil { return err } @@ -603,7 +610,7 @@ func (mgr *uploadManager) UploadShards(ctx context.Context, s *object.Slab, shar defer cancel() // create the upload - upload, err := mgr.newUpload(ctx, len(shards), contracts, bh, lockPriority) + upload, err := mgr.newUpload(len(shards), contracts, bh, lockPriority) if err != nil { return err } @@ -675,7 +682,7 @@ func (mgr *uploadManager) candidates(allowed map[types.PublicKey]struct{}) (cand return } -func (mgr *uploadManager) newUpload(ctx context.Context, totalShards int, contracts []api.ContractMetadata, bh uint64, lockPriority int) (*upload, error) { +func (mgr *uploadManager) newUpload(totalShards 
int, contracts []api.ContractMetadata, bh uint64, lockPriority int) (*upload, error) { mgr.mu.Lock() defer mgr.mu.Unlock() diff --git a/worker/uploader.go b/worker/uploader.go index 28b04033d..403accbc8 100644 --- a/worker/uploader.go +++ b/worker/uploader.go @@ -11,6 +11,7 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/stats" "go.uber.org/zap" ) @@ -298,7 +299,7 @@ func (u *uploader) tryRecomputeStats() { func (u *uploader) tryRefresh(ctx context.Context) bool { // fetch the renewed contract renewed, err := u.cs.RenewedContract(ctx, u.ContractID()) - if isError(err, api.ErrContractNotFound) || isError(err, context.Canceled) { + if utils.IsErr(err, api.ErrContractNotFound) || utils.IsErr(err, context.Canceled) { return false } else if err != nil { u.logger.Errorf("failed to fetch renewed contract %v, err: %v", u.ContractID(), err) diff --git a/worker/worker.go b/worker/worker.go index 9e4dacdd2..0868c347c 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -25,6 +25,7 @@ import ( "go.sia.tech/renterd/api" "go.sia.tech/renterd/build" "go.sia.tech/renterd/hostdb" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/renterd/object" "go.sia.tech/renterd/webhooks" "go.sia.tech/renterd/worker/client" @@ -42,7 +43,6 @@ const ( lockingPriorityActiveContractRevision = 100 lockingPriorityRenew = 80 - lockingPriorityPriceTable = 60 lockingPriorityFunding = 40 lockingPrioritySyncing = 30 lockingPriorityPruning = 20 @@ -860,19 +860,24 @@ func (w *worker) objectsHandlerHEAD(jc jape.Context) { if jc.DecodeForm("bucket", &bucket) != nil { return } + var ignoreDelim bool + if jc.DecodeForm("ignoreDelim", &ignoreDelim) != nil { + return + } // parse path path := jc.PathParam("path") - if path == "" || strings.HasSuffix(path, "/") { + if !ignoreDelim && (path == "" || strings.HasSuffix(path, "/")) { jc.Error(errors.New("HEAD requests can only be performed on objects, not 
directories"), http.StatusBadRequest) return } // fetch object metadata res, err := w.bus.Object(jc.Request.Context(), bucket, path, api.GetObjectOptions{ + IgnoreDelim: ignoreDelim, OnlyMetadata: true, }) - if errors.Is(err, api.ErrObjectNotFound) { + if utils.IsErr(err, api.ErrObjectNotFound) { jc.Error(err, http.StatusNotFound) return } else if err != nil { @@ -945,7 +950,7 @@ func (w *worker) objectsHandlerGET(jc jape.Context) { path := jc.PathParam("path") res, err := w.bus.Object(ctx, bucket, path, opts) - if err != nil && strings.Contains(err.Error(), api.ErrObjectNotFound.Error()) { + if utils.IsErr(err, api.ErrObjectNotFound) { jc.Error(err, http.StatusNotFound) return } else if jc.Check("couldn't get object or entries", err) != nil { @@ -1035,7 +1040,7 @@ func (w *worker) objectsHandlerPUT(jc jape.Context) { // return early if the bucket does not exist _, err = w.bus.Bucket(ctx, bucket) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { jc.Error(fmt.Errorf("bucket '%s' not found; %w", bucket, err), http.StatusNotFound) return } @@ -1098,7 +1103,7 @@ func (w *worker) objectsHandlerPUT(jc jape.Context) { if err := jc.Check("couldn't upload object", err); err != nil { if err != nil { w.logger.Error(err) - if !errors.Is(err, ErrShuttingDown) && !errors.Is(err, errUploadInterrupted) { + if !errors.Is(err, ErrShuttingDown) && !errors.Is(err, errUploadInterrupted) && !errors.Is(err, context.Canceled) { w.registerAlert(newUploadFailedAlert(bucket, path, up.ContractSet, mimeType, rs.MinShards, rs.TotalShards, len(contracts), up.UploadPacking, false, err)) } } @@ -1154,7 +1159,7 @@ func (w *worker) multipartUploadHandlerPUT(jc jape.Context) { // return early if the bucket does not exist _, err = w.bus.Bucket(ctx, bucket) - if err != nil && strings.Contains(err.Error(), api.ErrBucketNotFound.Error()) { + if utils.IsErr(err, api.ErrBucketNotFound) { jc.Error(fmt.Errorf("bucket '%s' not 
found; %w", bucket, err), http.StatusNotFound) return } @@ -1197,7 +1202,7 @@ func (w *worker) multipartUploadHandlerPUT(jc jape.Context) { // fetch upload from bus upload, err := w.bus.MultipartUpload(ctx, uploadID) - if isError(err, api.ErrMultipartUploadNotFound) { + if utils.IsErr(err, api.ErrMultipartUploadNotFound) { jc.Error(err, http.StatusNotFound) return } else if jc.Check("failed to fetch multipart upload", err) != nil { @@ -1257,7 +1262,7 @@ func (w *worker) objectsHandlerDELETE(jc jape.Context) { return } err := w.bus.DeleteObject(jc.Request.Context(), bucket, jc.PathParam("path"), api.DeleteObjectOptions{Batch: batch}) - if err != nil && strings.Contains(err.Error(), api.ErrObjectNotFound.Error()) { + if utils.IsErr(err, api.ErrObjectNotFound) { jc.Error(err, http.StatusNotFound) return } @@ -1540,22 +1545,22 @@ func discardTxnOnErr(ctx context.Context, bus Bus, l *zap.SugaredLogger, txn typ ctx, cancel := context.WithTimeout(ctx, 10*time.Second) if dErr := bus.WalletDiscard(ctx, txn); dErr != nil { - l.Errorf("%w: failed to discard txn: %v", *err, dErr) + l.Errorf("%w: %v, failed to discard txn: %v", *err, errContext, dErr) } cancel() } func isErrHostUnreachable(err error) bool { - return isError(err, os.ErrDeadlineExceeded) || - isError(err, context.DeadlineExceeded) || - isError(err, api.ErrHostOnPrivateNetwork) || - isError(err, errors.New("no route to host")) || - isError(err, errors.New("no such host")) || - isError(err, errors.New("connection refused")) || - isError(err, errors.New("unknown port")) || - isError(err, errors.New("cannot assign requested address")) + return utils.IsErr(err, os.ErrDeadlineExceeded) || + utils.IsErr(err, context.DeadlineExceeded) || + utils.IsErr(err, api.ErrHostOnPrivateNetwork) || + utils.IsErr(err, errors.New("no route to host")) || + utils.IsErr(err, errors.New("no such host")) || + utils.IsErr(err, errors.New("connection refused")) || + utils.IsErr(err, errors.New("unknown port")) || + utils.IsErr(err, 
errors.New("cannot assign requested address")) } func isErrDuplicateTransactionSet(err error) bool { - return err != nil && strings.Contains(err.Error(), modules.ErrDuplicateTransactionSet.Error()) + return utils.IsErr(err, modules.ErrDuplicateTransactionSet) }