diff --git a/worker/host_test.go b/worker/host_test.go index 88a3bb058a..dcef507b90 100644 --- a/worker/host_test.go +++ b/worker/host_test.go @@ -93,16 +93,26 @@ func (h *testHost) DownloadSector(ctx context.Context, w io.Writer, root types.H } func (h *testHost) UploadSector(ctx context.Context, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) error { - h.AddSector(sectorRoot, sector) - if h.uploadErr != nil { - return h.uploadErr - } else if h.uploadDelay > 0 { + // sleep if necessary + if h.uploadDelay > 0 { select { case <-time.After(h.uploadDelay): case <-ctx.Done(): return context.Cause(ctx) } } + + // check for cancellation + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + } + + if h.uploadErr != nil { + return h.uploadErr + } + h.AddSector(sectorRoot, sector) return nil } diff --git a/worker/mocks_test.go b/worker/mocks_test.go index 897d96cdb0..ef4e62c6f4 100644 --- a/worker/mocks_test.go +++ b/worker/mocks_test.go @@ -195,7 +195,7 @@ func newContractStoreMock() *contractStoreMock { } func (*contractStoreMock) RenewedContract(context.Context, types.FileContractID) (api.ContractMetadata, error) { - return api.ContractMetadata{}, nil + return api.ContractMetadata{}, api.ErrContractNotFound } func (*contractStoreMock) Contract(context.Context, types.FileContractID) (api.ContractMetadata, error) { diff --git a/worker/rhpv2.go b/worker/rhpv2.go index 1a6bd3cfd9..e0bc11abb0 100644 --- a/worker/rhpv2.go +++ b/worker/rhpv2.go @@ -14,6 +14,7 @@ import ( rhpv2 "go.sia.tech/core/rhp/v2" "go.sia.tech/core/types" "go.sia.tech/renterd/api" + "go.sia.tech/renterd/internal/utils" "go.sia.tech/siad/build" "go.sia.tech/siad/crypto" "lukechampine.com/frand" @@ -85,9 +86,12 @@ func (hes HostErrorSet) Error() string { return "\n" + strings.Join(strs, "\n") } -func wrapErr(err *error, fnName string) { +func wrapErr(ctx context.Context, fnName string, err *error) { if *err != nil { *err = fmt.Errorf("%s: %w", fnName, *err) + if cause := context.Cause(ctx); cause != nil && !utils.IsErr(*err, cause) { + *err = fmt.Errorf("%w; %w", cause, *err) + } } } @@ -133,7 +137,7 @@ func updateRevisionOutputs(rev *types.FileContractRevision, cost, collateral typ // RPCSettings calls the Settings RPC, returning the host's reported settings. func RPCSettings(ctx context.Context, t *rhpv2.Transport) (settings rhpv2.HostSettings, err error) { - defer wrapErr(&err, "Settings") + defer wrapErr(ctx, "Settings", &err) var resp rhpv2.RPCSettingsResponse if err := t.Call(rhpv2.RPCSettingsID, nil, &resp); err != nil { @@ -147,7 +151,7 @@ func RPCSettings(ctx context.Context, t *rhpv2.Transport) (settings rhpv2.HostSe // RPCFormContract forms a contract with a host. func RPCFormContract(ctx context.Context, t *rhpv2.Transport, renterKey types.PrivateKey, txnSet []types.Transaction) (_ rhpv2.ContractRevision, _ []types.Transaction, err error) { - defer wrapErr(&err, "FormContract") + defer wrapErr(ctx, "FormContract", &err) // strip our signatures before sending parents, txn := txnSet[:len(txnSet)-1], txnSet[len(txnSet)-1] diff --git a/worker/rhpv3.go b/worker/rhpv3.go index dc483c3404..a6950e98d5 100644 --- a/worker/rhpv3.go +++ b/worker/rhpv3.go @@ -623,7 +623,7 @@ type PriceTablePaymentFunc func(pt rhpv3.HostPriceTable) (rhpv3.PaymentMethod, e // RPCPriceTable calls the UpdatePriceTable RPC. func RPCPriceTable(ctx context.Context, t *transportV3, paymentFunc PriceTablePaymentFunc) (_ api.HostPriceTable, err error) { - defer wrapErr(&err, "PriceTable") + defer wrapErr(ctx, "PriceTable", &err) s, err := t.DialStream(ctx) if err != nil { @@ -660,7 +660,7 @@ func RPCPriceTable(ctx context.Context, t *transportV3, paymentFunc PriceTablePa // RPCAccountBalance calls the AccountBalance RPC. func RPCAccountBalance(ctx context.Context, t *transportV3, payment rhpv3.PaymentMethod, account rhpv3.Account, settingsID rhpv3.SettingsID) (bal types.Currency, err error) { - defer wrapErr(&err, "AccountBalance") + defer wrapErr(ctx, "AccountBalance", &err) s, err := t.DialStream(ctx) if err != nil { return types.ZeroCurrency, err @@ -685,7 +685,7 @@ func RPCAccountBalance(ctx context.Context, t *transportV3, payment rhpv3.Paymen // RPCFundAccount calls the FundAccount RPC. func RPCFundAccount(ctx context.Context, t *transportV3, payment rhpv3.PaymentMethod, account rhpv3.Account, settingsID rhpv3.SettingsID) (err error) { - defer wrapErr(&err, "FundAccount") + defer wrapErr(ctx, "FundAccount", &err) s, err := t.DialStream(ctx) if err != nil { return err @@ -712,7 +712,7 @@ func RPCFundAccount(ctx context.Context, t *transportV3, payment rhpv3.PaymentMe // fetching a pricetable using the fetched revision to pay for it. If // paymentFunc returns 'nil' as payment, the host is not paid. func RPCLatestRevision(ctx context.Context, t *transportV3, contractID types.FileContractID, paymentFunc func(rev *types.FileContractRevision) (rhpv3.HostPriceTable, rhpv3.PaymentMethod, error)) (_ types.FileContractRevision, err error) { - defer wrapErr(&err, "LatestRevision") + defer wrapErr(ctx, "LatestRevision", &err) s, err := t.DialStream(ctx) if err != nil { return types.FileContractRevision{}, err @@ -738,7 +738,7 @@ func RPCLatestRevision(ctx context.Context, t *transportV3, contractID types.Fil // RPCReadSector calls the ExecuteProgram RPC with a ReadSector instruction. func RPCReadSector(ctx context.Context, t *transportV3, w io.Writer, pt rhpv3.HostPriceTable, payment rhpv3.PaymentMethod, offset, length uint32, merkleRoot types.Hash256) (cost, refund types.Currency, err error) { - defer wrapErr(&err, "ReadSector") + defer wrapErr(ctx, "ReadSector", &err) s, err := t.DialStream(ctx) if err != nil { return types.ZeroCurrency, types.ZeroCurrency, err @@ -803,7 +803,7 @@ func RPCReadSector(ctx context.Context, t *transportV3, w io.Writer, pt rhpv3.Ho } func RPCAppendSector(ctx context.Context, t *transportV3, renterKey types.PrivateKey, pt rhpv3.HostPriceTable, rev *types.FileContractRevision, payment rhpv3.PaymentMethod, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte) (cost types.Currency, err error) { - defer wrapErr(&err, "AppendSector") + defer wrapErr(ctx, "AppendSector", &err) // sanity check revision first if rev.RevisionNumber == math.MaxUint64 { @@ -941,7 +941,8 @@ func RPCAppendSector(ctx context.Context, t *transportV3, renterKey types.Privat } func RPCRenew(ctx context.Context, rrr api.RHPRenewRequest, bus Bus, t *transportV3, pt *rhpv3.HostPriceTable, rev types.FileContractRevision, renterKey types.PrivateKey, l *zap.SugaredLogger) (_ rhpv2.ContractRevision, _ []types.Transaction, _ types.Currency, err error) { - defer wrapErr(&err, "RPCRenew") + defer wrapErr(ctx, "RPCRenew", &err) + s, err := t.DialStream(ctx) if err != nil { return rhpv2.ContractRevision{}, nil, types.ZeroCurrency, fmt.Errorf("failed to dial stream: %w", err) diff --git a/worker/upload.go b/worker/upload.go index 4a97099bba..aa76964af3 100644 --- a/worker/upload.go +++ b/worker/upload.go @@ -30,10 +30,11 @@ const ( ) var ( - errContractExpired = errors.New("contract expired") - errNoCandidateUploader = errors.New("no candidate uploader found") - errNotEnoughContracts = errors.New("not enough contracts to support requested redundancy") - errUploadInterrupted = errors.New("upload was interrupted") + errContractExpired = errors.New("contract expired") + errNoCandidateUploader = errors.New("no candidate uploader found") + errNotEnoughContracts = errors.New("not enough contracts to support requested redundancy") + errUploadInterrupted = errors.New("upload was interrupted") + errSectorUploadFinished = errors.New("sector upload already finished") ) type ( @@ -117,7 +118,7 @@ type ( root types.Hash256 ctx context.Context - cancel context.CancelFunc + cancel context.CancelCauseFunc mu sync.Mutex uploaded object.Sector @@ -750,7 +751,7 @@ func (u *upload) newSlabUpload(ctx context.Context, shards [][]byte, uploaders [ wg.Add(1) go func(idx int) { // create the ctx - sCtx, sCancel := context.WithCancel(ctx) + sCtx, sCancel := context.WithCancelCause(ctx) // create the sector // NOTE: we are computing the sector root here and pass it all the @@ -1087,7 +1088,7 @@ func (s *sectorUpload) finish(sector object.Sector) { s.mu.Lock() defer s.mu.Unlock() - s.cancel() + s.cancel(errSectorUploadFinished) s.uploaded = sector s.data = nil } diff --git a/worker/uploader.go b/worker/uploader.go index c264f7c452..94c74c916f 100644 --- a/worker/uploader.go +++ b/worker/uploader.go @@ -116,22 +116,35 @@ outer: } // execute it - elapsed, err := u.execute(req) - - // the uploader's contract got renewed, requeue the request - if errors.Is(err, errMaxRevisionReached) { - if u.tryRefresh(req.sector.ctx) { + start := time.Now() + duration, err := u.execute(req) + if err == nil { + // only track the time it took to upload the sector in the happy case + u.trackSectorUpload(true, duration) + u.trackConsecutiveFailures(true) + } else if errors.Is(err, errMaxRevisionReached) { + // the uploader's contract got renewed, requeue the request + if err := u.tryRefresh(); err == nil { u.enqueue(req) continue outer + } else if !utils.IsErr(err, context.Canceled) { + u.logger.Errorf("failed to refresh the uploader's contract %v, err: %v", u.ContractID(), err) } - } - - // track the error, ignore gracefully closed streams and canceled overdrives - canceledOverdrive := req.done() && req.overdrive && err != nil - if !canceledOverdrive && !isClosedStream(err) { - u.trackSectorUpload(err, elapsed) + u.logger.Debugw("skip tracking sector upload", "total", time.Since(start), "duration", duration, "overdrive", req.overdrive, "err", err) + } else if errors.Is(err, errSectorUploadFinished) && !req.overdrive { + // punish the slow host by tracking a multiple of the total time + // we lost on it, but we only do so if we weren't overdriving, + // also note we are not tracking consecutive failures here + // because we're not sure if we had a successful host + // interaction + u.trackSectorUpload(true, time.Since(start)*10) + } else if !errors.Is(err, errSectorUploadFinished) { + // punish the host for failing the upload + u.trackSectorUpload(false, time.Hour) + u.trackConsecutiveFailures(false) + u.logger.Debugw("penalising host for failing to upload sector", "hk", u.hk, "overdrive", req.overdrive, "err", err) } else { - u.logger.Debugw("not tracking sector upload metric", zap.Error(err)) + u.logger.Debugw("skip tracking sector upload", "total", time.Since(start), "duration", duration, "overdrive", req.overdrive, "err", err) } // send the response @@ -198,6 +211,8 @@ func (u *uploader) estimate() float64 { return numSectors * estimateP90 } +// execute executes the sector upload request, if the upload was successful it +// returns the time it took to upload the sector to the host func (u *uploader) execute(req *sectorUploadReq) (time.Duration, error) { // grab fields u.mu.Lock() @@ -233,19 +248,17 @@ func (u *uploader) execute(req *sectorUploadReq) (time.Duration, error) { // update the bus if err := u.os.AddUploadingSector(ctx, req.uploadID, fcid, req.sector.root); err != nil { - return 0, fmt.Errorf("failed to add uploading sector to contract %v, err: %v", fcid, err) + return 0, fmt.Errorf("failed to add uploading sector to contract %v; %w", fcid, err) } // upload the sector start := time.Now() err = host.UploadSector(ctx, req.sector.root, req.sector.sectorData(), rev) if err != nil { - return 0, fmt.Errorf("failed to upload sector to contract %v, err: %v", fcid, err) + return 0, fmt.Errorf("failed to upload sector to contract %v; %w", fcid, err) } - // calculate elapsed time - elapsed := time.Since(start) - return elapsed, nil + return time.Since(start), nil } func (u *uploader) pop() *sectorUploadReq { @@ -268,21 +281,34 @@ func (u *uploader) signalWork() { } } -func (u *uploader) trackSectorUpload(err error, d time.Duration) { +func (u *uploader) trackConsecutiveFailures(success bool) { u.mu.Lock() defer u.mu.Unlock() - if err != nil { - u.consecutiveFailures++ - u.statsSectorUploadEstimateInMS.Track(float64(time.Hour.Milliseconds())) - } else { - ms := d.Milliseconds() - if ms == 0 { - ms = 1 // avoid division by zero - } + // update consecutive failures + if success { u.consecutiveFailures = 0 - u.statsSectorUploadEstimateInMS.Track(float64(ms)) // duration in ms + } else { + u.consecutiveFailures++ + } +} + +func (u *uploader) trackSectorUpload(success bool, d time.Duration) { + u.mu.Lock() + defer u.mu.Unlock() + + // sanitize input + ms := d.Milliseconds() + if ms == 0 { + ms = 1 // avoid division by zero + } + + // update estimates + if success { + u.statsSectorUploadEstimateInMS.Track(float64(ms)) u.statsSectorUploadSpeedBytesPerMS.Track(float64(rhpv2.SectorSize / ms)) // bytes per ms + } else { + u.statsSectorUploadEstimateInMS.Track(float64(ms)) } } @@ -298,17 +324,18 @@ func (u *uploader) tryRecomputeStats() { u.statsSectorUploadSpeedBytesPerMS.Recompute() } -func (u *uploader) tryRefresh(ctx context.Context) bool { +func (u *uploader) tryRefresh() error { + // use a sane timeout + ctx, cancel := context.WithTimeout(u.shutdownCtx, 30*time.Second) + defer cancel() + // fetch the renewed contract renewed, err := u.cs.RenewedContract(ctx, u.ContractID()) - if utils.IsErr(err, api.ErrContractNotFound) || utils.IsErr(err, context.Canceled) { - return false - } else if err != nil { - u.logger.Errorf("failed to fetch renewed contract %v, err: %v", u.ContractID(), err) - return false + if err != nil { + return err } // renew the uploader with the renewed contract u.Refresh(renewed) - return true + return nil } diff --git a/worker/uploader_test.go b/worker/uploader_test.go index 24653a9c41..b7c795c8cc 100644 --- a/worker/uploader_test.go +++ b/worker/uploader_test.go @@ -1,7 +1,6 @@ package worker import ( - "bytes" "context" "errors" "strings" @@ -9,9 +8,9 @@ import ( "time" "go.sia.tech/core/types" + "go.sia.tech/renterd/stats" "go.uber.org/zap" "go.uber.org/zap/zaptest/observer" - "lukechampine.com/frand" ) func TestUploaderStopped(t *testing.T) { @@ -41,111 +40,103 @@ func TestUploaderStopped(t *testing.T) { } func TestUploaderTrackSectorUpload(t *testing.T) { + // create test worker w := newTestWorker(t) + h := w.AddHosts(1)[0] // convenience variables um := w.uploadManager - rs := testRedundancySettings // create custom logger to capture logs observedZapCore, observedLogs := observer.New(zap.DebugLevel) um.logger = zap.New(observedZapCore).Sugar() - // overdrive immediately after 50ms - um.overdriveTimeout = 50 * time.Millisecond - um.maxOverdrive = uint64(rs.TotalShards) + 1 - - // add hosts and add arificial delay of 150ms - hosts := w.AddHosts(rs.TotalShards) - for _, host := range hosts { - host.uploadDelay = 150 * time.Millisecond - } - - // create test data - data := frand.Bytes(128) - - // create upload params - params := testParameters(t.Name()) - - // upload data - _, _, err := um.Upload(context.Background(), bytes.NewReader(data), w.Contracts(), params, lockingPriorityUpload) - if err != nil { - t.Fatal(err) - } + // create uploader + um.refreshUploaders(w.Contracts(), 1) + ul := um.uploaders[0] + ul.statsSectorUploadEstimateInMS = stats.NoDecay() - // define a helper function to fetch an uploader for given host key - uploaderr := func(hk types.PublicKey) *uploader { + // executeReq is a helper that enqueues a sector upload request and returns + // a response, we don't need to alter the request because we'll configure + // the test host to see whether we properly handle sector uploads + executeReq := func(overdrive bool) sectorUploadResp { t.Helper() - um.refreshUploaders(w.Contracts(), 1) - for _, uploader := range um.uploaders { - if uploader.hk == hk { - return uploader - } - } - t.Fatal("uploader not found") - return nil + responseChan := make(chan sectorUploadResp) + ul.enqueue(§orUploadReq{ + contractLockDuration: time.Second, + contractLockPriority: lockingPriorityUpload, + overdrive: overdrive, + responseChan: responseChan, + sector: §orUpload{ctx: context.Background()}, + }) + return <-responseChan } - // define a helper function to fetch uploader stats - stats := func(u *uploader) (uint64, float64) { - t.Helper() - u.mu.Lock() - defer u.mu.Unlock() - return u.consecutiveFailures, u.statsSectorUploadEstimateInMS.P90() + // assert successful uploads reset consecutive failures and track duration + h.uploadDelay = 100 * time.Millisecond + ul.consecutiveFailures = 1 + res := executeReq(false) + if res.err != nil { + t.Fatal(res.err) + } else if ul.consecutiveFailures != 0 { + t.Fatal("unexpected consecutive failures") + } else if ul.statsSectorUploadEstimateInMS.Len() != 1 { + t.Fatal("unexpected number of data points") + } else if avg := ul.statsSectorUploadEstimateInMS.Average(); !(100 <= avg && avg < 200) { + t.Fatal("unexpected average", avg) } - // assert all uploaders have 0 failures and an estimate that roughly equals - // the upload delay - for _, h := range hosts { - if failures, estimate := stats(uploaderr(h.hk)); failures != 0 { - t.Fatal("unexpected failures", failures) - } else if !(estimate >= 150 && estimate < 300) { - t.Fatal("unexpected estimate", estimate) - } + // assert we refresh the underlying contract when max rev. is reached + h.rev.RevisionNumber = types.MaxRevisionNumber + res = executeReq(false) + if res.err == nil { + t.Fatal("expected error") + } else if logs := observedLogs.TakeAll(); len(logs) == 0 { + t.Fatal("missing log entry") + } else if !strings.Contains(logs[0].Message, "failed to refresh the uploader's contract") { + t.Fatal("ununexpected log line") } - // add a host with a 250ms delay - h := w.AddHost() - h.uploadDelay = 250 * time.Millisecond - - // make sure its estimate is not 0 and thus is not used for the upload, but - // instead it is used for the overdrive - ul := uploaderr(h.hk) - ul.statsSectorUploadEstimateInMS.Track(float64(h.uploadDelay.Milliseconds())) - ul.statsSectorUploadEstimateInMS.Recompute() - if ul.statsSectorUploadEstimateInMS.P90() == 0 { - t.Fatal("unexpected p90") + // assert we punish the host if he was slower than one of the other hosts + // overdriving his sector, also assert we don't track consecutive failures + ul.consecutiveFailures = 1 + h.rev.RevisionNumber = 0 // reset + h.uploadErr = errSectorUploadFinished + res = executeReq(false) + if res.err == nil { + t.Fatal("expected error") + } else if ul.consecutiveFailures != 1 { + t.Fatal("unexpected consecutive failures") + } else if ul.statsSectorUploadEstimateInMS.Len() != 2 { + t.Fatal("unexpected number of data points") + } else if avg := ul.statsSectorUploadEstimateInMS.Average(); !(500 <= avg && avg < 600) { + t.Fatal("unexpected average", avg) } - // upload data - _, _, err = um.Upload(context.Background(), bytes.NewReader(data), w.Contracts(), params, lockingPriorityUpload) - if err != nil { - t.Fatal(err) - } - time.Sleep(h.uploadDelay) - - // assert the new host has 0 failures and that we logged an entry indicating - // we skipped tracking the metric - if failures, _ := stats(uploaderr(h.hk)); failures != 0 { - t.Fatal("unexpected failures", failures) - } else if observedLogs.Filter(func(entry observer.LoggedEntry) bool { - return strings.Contains(entry.Message, "not tracking sector upload metric") - }).Len() == 0 { + // assert we don't punish the host if it itself was overdriving the sector + res = executeReq(true) + if res.err == nil { + t.Fatal("expected error") + } else if ul.consecutiveFailures != 1 { + t.Fatal("unexpected consecutive failures") + } else if ul.statsSectorUploadEstimateInMS.Len() != 2 { + t.Fatal("unexpected number of data points") + } else if logs := observedLogs.TakeAll(); len(logs) == 0 { t.Fatal("missing log entry") + } else if !strings.Contains(logs[0].Message, "skip tracking sector upload") { + t.Fatal("ununexpected log line") } - // upload data again but now have the host return an error + // assert we punish the host if it failed the upload for any other reason h.uploadErr = errors.New("host error") - _, _, err = um.Upload(context.Background(), bytes.NewReader(data), w.Contracts(), params, lockingPriorityUpload) - if err != nil { - t.Fatal(err) - } - - // assert the new host has 1 failure and its estimate includes the penalty - uploaderr(h.hk).statsSectorUploadEstimateInMS.Recompute() - if failures, estimate := stats(uploaderr(h.hk)); failures != 1 { - t.Fatal("unexpected failures", failures) - } else if estimate < float64(time.Minute.Milliseconds()) { - t.Fatal("unexpected estimate", estimate) + res = executeReq(false) + if res.err == nil { + t.Fatal("expected error") + } else if ul.consecutiveFailures != 2 { + t.Fatal("unexpected consecutive failures") + } else if ul.statsSectorUploadEstimateInMS.Len() != 3 { + t.Fatal("unexpected number of data points") + } else if avg := ul.statsSectorUploadEstimateInMS.Average(); avg < 1800 { + t.Fatal("unexpected average", avg) } }