diff --git a/autopilot/autopilot.go b/autopilot/autopilot.go index e5ddd8411..f224616c8 100644 --- a/autopilot/autopilot.go +++ b/autopilot/autopilot.go @@ -202,6 +202,13 @@ func (ap *Autopilot) Run() error { var forceScan bool var launchAccountRefillsOnce sync.Once for { + // check for shutdown right before starting a new iteration + select { + case <-ap.shutdownCtx.Done(): + return nil + default: + } + ap.logger.Info("autopilot iteration starting") tickerFired := make(chan struct{}) ap.workers.withWorker(func(w Worker) { @@ -220,7 +227,7 @@ func (ap *Autopilot) Run() error { close(tickerFired) return } - ap.logger.Error("autopilot stopped before consensus was synced") + ap.logger.Info("autopilot stopped before consensus was synced") return } else if blocked { if scanning, _ := ap.s.Status(); !scanning { @@ -234,7 +241,7 @@ func (ap *Autopilot) Run() error { close(tickerFired) return } - ap.logger.Error("autopilot stopped before it was able to confirm it was configured in the bus") + ap.logger.Info("autopilot stopped before it was able to confirm it was configured in the bus") return } diff --git a/cmd/renterd/main.go b/cmd/renterd/main.go index 79d1e31b4..7413a8572 100644 --- a/cmd/renterd/main.go +++ b/cmd/renterd/main.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/json" + "errors" "flag" "fmt" "log" @@ -626,25 +627,48 @@ func main() { logger.Fatal("Fatal autopilot error: " + err.Error()) } - // Give each service a fraction of the total shutdown timeout. One service - // timing out shouldn't prevent the others from attempting a shutdown. - timeout := cfg.ShutdownTimeout / time.Duration(len(shutdownFns)) - shutdown := func(fn func(ctx context.Context) error) error { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - return fn(ctx) - } - - // Shut down the autopilot first, then the rest of the services in reverse order and then + // Define a shutdown function that updates the exit code after shutting down + // a service and logs the outcome exitCode := 0 - for i := len(shutdownFns) - 1; i >= 0; i-- { - if err := shutdown(shutdownFns[i].fn); err != nil { - logger.Sugar().Errorf("Failed to shut down %v: %v", shutdownFns[i].name, err) + shutdown := func(ctx context.Context, fn shutdownFn) { + logger.Sugar().Infof("Shutting down %v...", fn.name) + start := time.Now() + if err := fn.fn(ctx); errors.Is(err, worker.ErrShutdownTimedOut) { + logger.Sugar().Errorf("%v shutdown timed out after %v", fn.name, time.Since(start)) + exitCode = 1 + } else if err != nil { + logger.Sugar().Errorf("%v shutdown failed after %v with err: %v", fn.name, time.Since(start), err) exitCode = 1 } else { - logger.Sugar().Infof("%v shut down successfully", shutdownFns[i].name) + logger.Sugar().Infof("%v shutdown successful after %v", fn.name, time.Since(start)) + } + } + + // Reserve a portion of the shutdown timeout to allow graceful shutdown of + // services after a potential timeout from a prior service that took too + // long to shut down. This way we allow all services to shut down + // gracefully. 
+ reserved := (cfg.ShutdownTimeout / 5).Round(5 * time.Second) + ctx, cancel := context.WithTimeoutCause(context.Background(), cfg.ShutdownTimeout-reserved, worker.ErrShuttingDown) + defer cancel() + + // Shut down the services in reverse order + for i := len(shutdownFns) - 1; i >= 0; i-- { + // use reserve context if necessary + select { + case <-ctx.Done(): + if reserved == 0 { + logger.Sugar().Errorf("%v shutdown skipped, node shutdown exceeded %v", shutdownFns[i].name, cfg.ShutdownTimeout) + exitCode = 1 + continue + } + ctx, _ = context.WithTimeoutCause(context.Background(), reserved, worker.ErrShuttingDown) + reserved = 0 + default: } + shutdown(ctx, shutdownFns[i]) } + logger.Info("Shutdown complete") os.Exit(exitCode) } diff --git a/internal/test/e2e/cluster_test.go b/internal/test/e2e/cluster_test.go index c546937eb..29a6f44ec 100644 --- a/internal/test/e2e/cluster_test.go +++ b/internal/test/e2e/cluster_test.go @@ -27,6 +27,7 @@ import ( "go.sia.tech/renterd/internal/test" "go.sia.tech/renterd/object" "go.sia.tech/renterd/wallet" + "go.sia.tech/renterd/worker" "go.uber.org/zap" "go.uber.org/zap/zapcore" "lukechampine.com/frand" @@ -2362,6 +2363,11 @@ func TestBusRecordedMetrics(t *testing.T) { t.Fatal("expected zero ListSpending") } + // Shut down everything but the bus to avoid an NDF in the following assertion. + cluster.ShutdownS3(context.Background()) + cluster.ShutdownWorker(context.Background()) + cluster.ShutdownAutopilot(context.Background()) + // Prune one of the metrics if err := cluster.Bus.PruneMetrics(context.Background(), api.MetricContract, time.Now()); err != nil { t.Fatal(err) @@ -2448,3 +2454,145 @@ func TestMultipartUploadWrappedByPartialSlabs(t *testing.T) { t.Fatal("unexpected data") } } + +type blockedWriter struct { + blockChan chan struct{} + writingChan chan struct{} + + once sync.Once + + mu sync.Mutex + buffer *bytes.Buffer +} + +func newBlockedWriter() *blockedWriter { + return &blockedWriter{ + buffer: new(bytes.Buffer), + blockChan: make(chan struct{}), + writingChan: make(chan struct{}), + } +} + +func (r *blockedWriter) Bytes() []byte { + r.mu.Lock() + defer r.mu.Unlock() + return r.buffer.Bytes() +} + +func (r *blockedWriter) Write(p []byte) (n int, err error) { + r.once.Do(func() { close(r.writingChan) }) + + <-r.blockChan + + r.mu.Lock() + defer r.mu.Unlock() + return r.buffer.Write(p) +} + +func (r *blockedWriter) waitForWriteStart() { + <-r.writingChan +} + +func (r *blockedWriter) unblock() { + close(r.blockChan) +} + +func TestGracefulShutdown(t *testing.T) { + // create cluster + cluster := newTestCluster(t, testClusterOptions{hosts: test.RedundancySettings.TotalShards}) + defer cluster.Shutdown() + + // convenience variables + b := cluster.Bus + w := cluster.Worker + tt := cluster.tt + + // shut down the autopilot, we don't need it + cluster.ShutdownAutopilot(context.Background()) + + // prepare an object to download + tt.OKAll(w.UploadObject(context.Background(), bytes.NewReader([]byte(t.Name())), api.DefaultBucketName, t.Name(), api.UploadObjectOptions{})) + + // prepare both a reader and a writer that blocks until we unblock them + data := make([]byte, 128) + frand.Read(data) + br := newBlockedReader(data) + bw := newBlockedWriter() + + // upload in separate thread + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if _, err := w.UploadObject(context.Background(), br, api.DefaultBucketName, t.Name()+"blocked", api.UploadObjectOptions{}); err != nil { + t.Error(err) + } + }() + + // download in separate thread + 
wg.Add(1)
+	go func() {
+		defer wg.Done()
+		if err := w.DownloadObject(context.Background(), bw, api.DefaultBucketName, t.Name(), api.DownloadObjectOptions{}); err != nil {
+			t.Error(err)
+		}
+	}()
+
+	// wait until we are sure both requests are blocked
+	br.waitForReadStart()
+	bw.waitForWriteStart()
+
+	// shut the worker down
+	shutdownDone := make(chan struct{})
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+		cluster.ShutdownWorker(ctx)
+		close(shutdownDone)
+		cancel()
+	}()
+
+	// assert shutdown is blocked
+	select {
+	case <-shutdownDone:
+		tt.Fatal("shut down")
+	case <-time.After(time.Second):
+	}
+
+	// unblock the upload and the download separately, this allows us to
+	// check for ErrShuttingDown explicitly; unblocking both at the same
+	// time would shut the worker down along with the server, resulting in
+	// a race between ErrShuttingDown and 'connection refused' errors
+	br.unblock()
+
+	// assert uploads after shutdown fail
+	_, err := w.UploadObject(context.Background(), br, api.DefaultBucketName, t.Name()+"blocked", api.UploadObjectOptions{})
+	if err == nil || !strings.Contains(err.Error(), worker.ErrShuttingDown.Error()) {
+		t.Error("new uploads should fail, err:", err)
+	}
+
+	// assert downloads after shutdown fail
+	var buf bytes.Buffer
+	err = w.DownloadObject(context.Background(), &buf, api.DefaultBucketName, t.Name(), api.DownloadObjectOptions{})
+	if err == nil || !strings.Contains(err.Error(), worker.ErrShuttingDown.Error()) {
+		t.Error("new downloads should fail, err:", err)
+	}
+
+	// unblock the download
+	bw.unblock()
+
+	// wait for all goroutines to finish
+	wg.Wait()
+
+	// check the download succeeded
+	if string(bw.Bytes()) != t.Name() {
+		t.Fatal("data mismatch")
+	}
+
+	// check the upload succeeded, we can use the bus for that
+	_, err = b.Object(context.Background(), api.DefaultBucketName, t.Name()+"blocked", api.GetObjectOptions{})
+	if err != nil {
+		t.Fatal(err)
+	}
+}
diff --git a/internal/test/e2e/uploads_test.go b/internal/test/e2e/uploads_test.go
index 3f83fd7e4..17fa71833 100644
--- a/internal/test/e2e/uploads_test.go
+++ b/internal/test/e2e/uploads_test.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"sync"
 	"testing"
 	"time"
 
@@ -17,20 +18,28 @@ import (
 type blockedReader struct {
 	remaining int
 	data      *bytes.Buffer
-	readChan  chan struct{}
-	blockChan chan struct{}
+
+	blockChan   chan struct{}
+	readChan    chan struct{}
+	readingChan chan struct{}
+
+	once sync.Once
 }
 
 func newBlockedReader(data []byte) *blockedReader {
 	return &blockedReader{
 		remaining: len(data),
 		data:      bytes.NewBuffer(data),
-		blockChan: make(chan struct{}),
-		readChan:  make(chan struct{}),
+
+		blockChan:   make(chan struct{}),
+		readChan:    make(chan struct{}),
+		readingChan: make(chan struct{}),
 	}
 }
 
 func (r *blockedReader) Read(buf []byte) (n int, err error) {
+	r.once.Do(func() { close(r.readingChan) })
+
 	select {
 	case <-r.readChan:
 		<-r.blockChan
@@ -44,6 +53,14 @@ func (r *blockedReader) Read(buf []byte) (n int, err error) {
 	return
 }
 
+func (r *blockedReader) waitForReadStart() {
+	<-r.readingChan
+}
+
+func (r *blockedReader) unblock() {
+	close(r.blockChan)
+}
+
 func TestUploadingSectorsCache(t *testing.T) {
 	if testing.Short() {
 		t.SkipNow()
diff --git a/worker/download.go b/worker/download.go
index 3a58bbc98..8747d64f8 100644
--- a/worker/download.go
+++ b/worker/download.go
@@ -25,8 +25,9 @@ const (
 )
 
 var (
-	errDownloadNotEnoughHosts = errors.New("not enough hosts available to download the slab")
errDownloadCancelled = errors.New("download was cancelled") + errDownloadNotEnoughHosts = errors.New("not enough hosts available to download the slab") + errDownloaderStopped = errors.New("downloader was stopped") ) type ( @@ -42,13 +43,22 @@ type ( statsOverdrivePct *stats.DataPoints statsSlabDownloadSpeedBytesPerMS *stats.DataPoints + // NOTE: this context is being cancelled immediately upon worker + // shutdown, misuse might prevent graceful shutdown of downloads shutdownCtx context.Context + wg sync.WaitGroup mu sync.Mutex downloaders map[types.PublicKey]*downloader lastRecompute time.Time } + downloadManagerStats struct { + avgDownloadSpeedMBPS float64 + avgOverdrivePct float64 + downloaders map[types.PublicKey]downloaderStats + } + downloaderStats struct { avgSpeedMBPS float64 healthy bool @@ -88,6 +98,12 @@ type ( err error } + slabSlice struct { + object.SlabSlice + PartialSlab bool + Data []byte + } + sectorDownloadReq struct { ctx context.Context @@ -119,12 +135,6 @@ type ( object.Sector index int } - - downloadManagerStats struct { - avgDownloadSpeedMBPS float64 - avgOverdrivePct float64 - downloaders map[types.PublicKey]downloaderStats - } ) func (w *worker) initDownloadManager(maxMemory, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.SugaredLogger) { @@ -156,6 +166,10 @@ func newDownloadManager(ctx context.Context, hm HostManager, mm MemoryManager, o } func (mgr *downloadManager) DownloadObject(ctx context.Context, w io.Writer, o object.Object, offset, length uint64, contracts []api.ContractMetadata) (err error) { + // register the download in the waitgroup + mgr.wg.Add(1) + defer mgr.wg.Done() + // calculate what slabs we need var ss []slabSlice for _, s := range o.Slabs { @@ -232,8 +246,6 @@ func (mgr *downloadManager) DownloadObject(ctx context.Context, w io.Writer, o o select { case <-ctx.Done(): return - case <-mgr.shutdownCtx.Done(): - return default: } @@ -261,6 +273,23 @@ func (mgr *downloadManager) DownloadObject(ctx context.Context, w io.Writer, o o return // interrupted } + // check again if we need to abort + select { + case <-ctx.Done(): + mem.Release() // release memory if we're interrupted + return + case <-mgr.shutdownCtx.Done(): + // if we are in the middle of a download while shutting down we + // attempt to gracefully shut down and complete the download + // first, if we have yet to start the download we want to avoid + // filling up the queues with jobs + if slabIndex == 0 { + mem.Release() // relase memory if we're shutting down + return + } + default: + } + // launch the download wg.Add(1) go func(index int) { @@ -275,7 +304,7 @@ func (mgr *downloadManager) DownloadObject(ctx context.Context, w io.Writer, o o err: err, }: case <-ctx.Done(): - mem.Release() // relase memory if we're interrupted + mem.Release() // release memory if we're interrupted } }(slabIndex) } @@ -289,8 +318,6 @@ outer: for { var resp *slabDownloadResponse select { - case <-mgr.shutdownCtx.Done(): - return ErrShuttingDown case <-ctx.Done(): return errDownloadCancelled case resp = <-responseChan: @@ -355,6 +382,10 @@ outer: } func (mgr *downloadManager) DownloadSlab(ctx context.Context, slab object.Slab, contracts []api.ContractMetadata) ([][]byte, bool, error) { + // register the download in the waitgroup + mgr.wg.Add(1) + defer mgr.wg.Done() + // refresh the downloaders mgr.refreshDownloaders(contracts) @@ -421,12 +452,27 @@ func (mgr *downloadManager) Stats() downloadManagerStats { } } -func (mgr *downloadManager) Stop() { +func (mgr *downloadManager) Stop(ctx 
context.Context) { + // wait on all ongoing downloads to finish + doneChan := make(chan struct{}) + go func() { + mgr.wg.Wait() + close(doneChan) + }() + + // allow the context to interrupt the wait + select { + case <-ctx.Done(): + case <-doneChan: + } + + // stop all downloaders mgr.mu.Lock() defer mgr.mu.Unlock() for _, d := range mgr.downloaders { - d.Stop() + d.Stop(context.Cause(ctx)) } + mgr.downloaders = nil } func (mgr *downloadManager) tryRecomputeStats() { @@ -481,7 +527,7 @@ func (mgr *downloadManager) refreshDownloaders(contracts []api.ContractMetadata) for hk := range mgr.downloaders { _, wanted := want[hk] if !wanted { - mgr.downloaders[hk].Stop() + mgr.downloaders[hk].Stop(errDownloaderStopped) delete(mgr.downloaders, hk) continue } @@ -491,11 +537,8 @@ func (mgr *downloadManager) refreshDownloaders(contracts []api.ContractMetadata) // update downloaders for _, c := range want { - // create a host - host := mgr.hm.Host(c.HostKey, c.ID, c.SiamuxAddr) - downloader := newDownloader(mgr.shutdownCtx, host) - mgr.downloaders[c.HostKey] = downloader - go downloader.processQueue(mgr.hm) + mgr.downloaders[c.HostKey] = newDownloader(mgr.hm.Host(c.HostKey, c.ID, c.SiamuxAddr)) + go mgr.downloaders[c.HostKey].processQueue(mgr.hm) } } @@ -728,8 +771,6 @@ func (s *slabDownload) download(ctx context.Context) ([][]byte, bool, error) { loop: for s.inflight() > 0 && !done { select { - case <-s.mgr.shutdownCtx.Done(): - return nil, false, errors.New("download stopped") case <-ctx.Done(): return nil, false, ctx.Err() case <-resps.c: @@ -899,12 +940,6 @@ func (mgr *downloadManager) fastest(hosts []types.PublicKey) (fastest *downloade return } -type slabSlice struct { - object.SlabSlice - PartialSlab bool - Data []byte -} - func slabsForDownload(slabs []slabSlice, offset, length uint64) []slabSlice { // declare a helper to cast a uint64 to uint32 with overflow detection. This // could should never produce an overflow. 
diff --git a/worker/downloader.go b/worker/downloader.go index 24be245fc..eebb6be00 100644 --- a/worker/downloader.go +++ b/worker/downloader.go @@ -2,8 +2,6 @@ package worker import ( "bytes" - "context" - "errors" "sync" "sync/atomic" "time" @@ -18,10 +16,6 @@ const ( maxConcurrentSectorsPerHost = 3 ) -var ( - errDownloaderStopped = errors.New("downloader was stopped") -) - type ( downloader struct { host Host @@ -30,7 +24,6 @@ type ( statsSectorDownloadEstimateInMS *stats.DataPoints signalWorkChan chan struct{} - shutdownCtx context.Context mu sync.Mutex consecutiveFailures uint64 @@ -40,7 +33,7 @@ type ( } ) -func newDownloader(ctx context.Context, host Host) *downloader { +func newDownloader(host Host) *downloader { return &downloader{ host: host, @@ -48,7 +41,6 @@ func newDownloader(ctx context.Context, host Host) *downloader { statsDownloadSpeedBytesPerMS: stats.NoDecay(), signalWorkChan: make(chan struct{}, 1), - shutdownCtx: ctx, queue: make([]*sectorDownloadReq, 0), } @@ -58,7 +50,7 @@ func (d *downloader) PublicKey() types.PublicKey { return d.host.PublicKey() } -func (d *downloader) Stop() { +func (d *downloader) Stop(err error) { d.mu.Lock() d.stopped = true d.mu.Unlock() @@ -69,7 +61,7 @@ func (d *downloader) Stop() { break } if !download.done() { - download.fail(errDownloaderStopped) + download.fail(err) } } } @@ -179,13 +171,6 @@ func (d *downloader) processBatch(batch []*sectorDownloadReq) chan struct{} { reqsChan := make(chan *sectorDownloadReq) workerFn := func() { for req := range reqsChan { - // check if we need to abort - select { - case <-d.shutdownCtx.Done(): - return - default: - } - // update state mu.Lock() if start.IsZero() { @@ -249,11 +234,7 @@ func (d *downloader) processQueue(hp HostManager) { outer: for { // wait for work - select { - case <-d.signalWorkChan: - case <-d.shutdownCtx.Done(): - return - } + <-d.signalWorkChan for { // try fill a batch of requests @@ -264,14 +245,8 @@ outer: // process the batch doneChan := d.processBatch(batch) - for { - select { - case <-d.shutdownCtx.Done(): - return - case <-doneChan: - continue outer - } - } + <-doneChan + continue outer } } } diff --git a/worker/downloader_test.go b/worker/downloader_test.go index 8097b8304..0f58d8a18 100644 --- a/worker/downloader_test.go +++ b/worker/downloader_test.go @@ -12,11 +12,10 @@ func TestDownloaderStopped(t *testing.T) { // convenience variables dm := w.downloadManager - h := hosts[0] dm.refreshDownloaders(w.Contracts()) - dl := w.downloadManager.downloaders[h.PublicKey()] - dl.Stop() + dl := dm.downloaders[hosts[0].PublicKey()] + dl.Stop(ErrShuttingDown) req := sectorDownloadReq{ resps: §orResponses{ diff --git a/worker/host.go b/worker/host.go index fceeaba00..8d582ec71 100644 --- a/worker/host.go +++ b/worker/host.go @@ -318,3 +318,18 @@ func (h *host) preparePriceTableContractPayment(rev *types.FileContractRevision) return &payment, nil } } + +func isSuccessfulInteraction(err error) bool { + // No error always means success. + if err == nil { + return true + } + // List of errors that are considered successful interactions. 
+ if isInsufficientFunds(err) { + return true + } + if isBalanceInsufficient(err) { + return true + } + return false +} diff --git a/worker/interactions.go b/worker/interactions.go deleted file mode 100644 index 2107ae582..000000000 --- a/worker/interactions.go +++ /dev/null @@ -1,27 +0,0 @@ -package worker - -import ( - "go.sia.tech/renterd/hostdb" -) - -type ( - HostInteractionRecorder interface { - RecordHostScan(...hostdb.HostScan) - RecordPriceTableUpdate(...hostdb.PriceTableUpdate) - } -) - -func isSuccessfulInteraction(err error) bool { - // No error always means success. - if err == nil { - return true - } - // List of errors that are considered successful interactions. - if isInsufficientFunds(err) { - return true - } - if isBalanceInsufficient(err) { - return true - } - return false -} diff --git a/worker/spending.go b/worker/spending.go index 87d2ec17d..15d2f9064 100644 --- a/worker/spending.go +++ b/worker/spending.go @@ -123,5 +123,6 @@ func (r *contractSpendingRecorder) flush() { r.contractSpendings = make(map[types.FileContractID]api.ContractSpendingRecord) } } + r.flushTimer = nil } diff --git a/worker/upload.go b/worker/upload.go index d95c9db9e..45ee5023f 100644 --- a/worker/upload.go +++ b/worker/upload.go @@ -51,10 +51,14 @@ type ( statsOverdrivePct *stats.DataPoints statsSlabUploadSpeedBytesPerMS *stats.DataPoints + // NOTE: this context is being cancelled immediately upon worker + // shutdown, misuse might prevent graceful shutdown of uploads shutdownCtx context.Context + wg sync.WaitGroup mu sync.Mutex uploaders []*uploader + stopped bool } // TODO: should become a metric @@ -73,8 +77,6 @@ type ( contractLockPriority int contractLockDuration time.Duration - - shutdownCtx context.Context } slabUpload struct { @@ -152,7 +154,7 @@ func (w *worker) initUploadManager(maxMemory, maxOverdrive uint64, overdriveTime w.uploadManager = newUploadManager(w.shutdownCtx, w, mm, w.bus, w.bus, w.bus, maxOverdrive, overdriveTimeout, w.contractLockingDuration, logger) } -func (w *worker) upload(ctx context.Context, r io.Reader, contracts []api.ContractMetadata, up uploadParameters, opts ...UploadOption) (_ string, err error) { +func (w *worker) upload(ctx context.Context, r io.Reader, contracts []api.ContractMetadata, up uploadParameters, opts ...UploadOption) (string, error) { // apply the options for _, opt := range opts { opt(&up) @@ -164,9 +166,10 @@ func (w *worker) upload(ctx context.Context, r io.Reader, contracts []api.Contra // if mime type is still not known, wrap the reader with a mime reader if up.mimeType == "" { + var err error up.mimeType, r, err = newMimeReader(r) if err != nil { - return + return "", err } } } @@ -362,6 +365,12 @@ func (mgr *uploadManager) newUploader(os ObjectStore, cl ContractLocker, cs Cont } } +func (mgr *uploadManager) isStopped() bool { + mgr.mu.Lock() + defer mgr.mu.Unlock() + return mgr.stopped +} + func (mgr *uploadManager) Stats() uploadManagerStats { mgr.mu.Lock() defer mgr.mu.Unlock() @@ -385,15 +394,35 @@ func (mgr *uploadManager) Stats() uploadManagerStats { } } -func (mgr *uploadManager) Stop() { +func (mgr *uploadManager) Stop(ctx context.Context) { + // wait on all ongoing uploads to finish + doneChan := make(chan struct{}) + go func() { + mgr.wg.Wait() + close(doneChan) + }() + + // allow the context to interrupt the wait + select { + case <-ctx.Done(): + case <-doneChan: + } + + // stop uploaders mgr.mu.Lock() defer mgr.mu.Unlock() for _, u := range mgr.uploaders { - u.Stop(ErrShuttingDown) + u.Stop(context.Cause(ctx)) } + 
mgr.uploaders = nil + mgr.stopped = true } func (mgr *uploadManager) Upload(ctx context.Context, r io.Reader, contracts []api.ContractMetadata, up uploadParameters, lockPriority int) (bufferSizeLimitReached bool, eTag string, err error) { + // register the upload + mgr.wg.Add(1) + defer mgr.wg.Done() + // cancel all in-flight requests when the upload is done ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -417,15 +446,7 @@ func (mgr *uploadManager) Upload(ctx context.Context, r io.Reader, contracts []a if err := mgr.os.TrackUpload(ctx, upload.id); err != nil { return false, "", fmt.Errorf("failed to track upload '%v', err: %w", upload.id, err) } - - // defer a function that finishes the upload - defer func() { - ctx, cancel := context.WithTimeout(mgr.shutdownCtx, time.Minute) - if err := mgr.os.FinishUpload(ctx, upload.id); err != nil { - mgr.logger.Errorf("failed to mark upload %v as finished: %v", upload.id, err) - } - cancel() - }() + defer mgr.finishUpload(upload.id) // create the response channel respChan := make(chan slabUploadResponse) @@ -443,18 +464,32 @@ func (mgr *uploadManager) Upload(ctx context.Context, r io.Reader, contracts []a var slabIndex int for { select { - case <-mgr.shutdownCtx.Done(): - return // interrupted case <-ctx.Done(): return // interrupted default: } + // acquire memory mem := mgr.mm.AcquireMemory(ctx, slabSize) if mem == nil { return // interrupted } + // check again if we need to abort + select { + case <-ctx.Done(): + mem.Release() // release memory if we're interrupted + return + case <-mgr.shutdownCtx.Done(): + // if, on worker shutdown, we have yet to initiate the first + // slab upload we don't bother finishing the upload and return + if slabIndex == 0 { + mem.Release() + return + } + default: + } + // read next slab's data data := make([]byte, slabSizeNoRedundancy) length, err := io.ReadFull(io.LimitReader(cr, int64(slabSizeNoRedundancy)), data) @@ -507,8 +542,6 @@ func (mgr *uploadManager) Upload(ctx context.Context, r io.Reader, contracts []a numSlabs := math.MaxInt32 for len(responses) < numSlabs { select { - case <-mgr.shutdownCtx.Done(): - return false, "", ErrShuttingDown case <-ctx.Done(): return false, "", errUploadInterrupted case numSlabs = <-numSlabsChan: @@ -561,6 +594,10 @@ func (mgr *uploadManager) Upload(ctx context.Context, r io.Reader, contracts []a } func (mgr *uploadManager) UploadPackedSlab(ctx context.Context, rs api.RedundancySettings, ps api.PackedSlab, mem Memory, contracts []api.ContractMetadata, bh uint64, lockPriority int) (err error) { + // register the upload + mgr.wg.Add(1) + defer mgr.wg.Done() + // cancel all in-flight requests when the upload is done ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -578,15 +615,7 @@ func (mgr *uploadManager) UploadPackedSlab(ctx context.Context, rs api.Redundanc if err := mgr.os.TrackUpload(ctx, upload.id); err != nil { return fmt.Errorf("failed to track upload '%v', err: %w", upload.id, err) } - - // defer a function that finishes the upload - defer func() { - ctx, cancel := context.WithTimeout(mgr.shutdownCtx, time.Minute) - if err := mgr.os.FinishUpload(ctx, upload.id); err != nil { - mgr.logger.Errorf("failed to mark upload %v as finished: %v", upload.id, err) - } - cancel() - }() + defer mgr.finishUpload(upload.id) // upload the shards sectors, uploadSpeed, overdrivePct, err := upload.uploadShards(ctx, shards, mgr.candidates(upload.allowed), mem, mgr.maxOverdrive, mgr.overdriveTimeout) @@ -609,6 +638,10 @@ func (mgr *uploadManager) UploadPackedSlab(ctx 
context.Context, rs api.Redundanc } func (mgr *uploadManager) UploadShards(ctx context.Context, s *object.Slab, shardIndices []int, shards [][]byte, contractSet string, contracts []api.ContractMetadata, bh uint64, lockPriority int, mem Memory) (err error) { + // register the upload + mgr.wg.Add(1) + defer mgr.wg.Done() + // cancel all in-flight requests when the upload is done ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -623,15 +656,7 @@ func (mgr *uploadManager) UploadShards(ctx context.Context, s *object.Slab, shar if err := mgr.os.TrackUpload(ctx, upload.id); err != nil { return fmt.Errorf("failed to track upload '%v', err: %w", upload.id, err) } - - // defer a function that finishes the upload - defer func() { - ctx, cancel := context.WithTimeout(mgr.shutdownCtx, time.Minute) - if err := mgr.os.FinishUpload(ctx, upload.id); err != nil { - mgr.logger.Errorf("failed to mark upload %v as finished: %v", upload.id, err) - } - cancel() - }() + defer mgr.finishUpload(upload.id) // upload the shards uploaded, uploadSpeed, overdrivePct, err := upload.uploadShards(ctx, shards, mgr.candidates(upload.allowed), mem, mgr.maxOverdrive, mgr.overdriveTimeout) @@ -686,6 +711,24 @@ func (mgr *uploadManager) candidates(allowed map[types.PublicKey]struct{}) (cand return } +func (mgr *uploadManager) finishUpload(uID api.UploadID) error { + // if the manager is stopped, meaning we are passed the graceful shutdown + // window, we don't finish the upload + if mgr.isStopped() { + return nil + } + + // use a sane context to finish uploads + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + err := mgr.os.FinishUpload(ctx, uID) + if err != nil { + mgr.logger.Errorf("failed to mark upload %v as finished: %v", uID, err) + } + return err +} + func (mgr *uploadManager) newUpload(ctx context.Context, totalShards int, contracts []api.ContractMetadata, bh uint64, lockPriority int) (*upload, error) { mgr.mu.Lock() defer mgr.mu.Unlock() @@ -710,7 +753,6 @@ func (mgr *uploadManager) newUpload(ctx context.Context, totalShards int, contra allowed: allowed, contractLockDuration: mgr.contractLockDuration, contractLockPriority: lockPriority, - shutdownCtx: mgr.shutdownCtx, }, nil } @@ -889,8 +931,6 @@ func (u *upload) uploadShards(ctx context.Context, shards [][]byte, candidates [ loop: for slab.numInflight > 0 && !done { select { - case <-u.shutdownCtx.Done(): - return nil, 0, 0, errors.New("upload stopped") case <-ctx.Done(): return nil, 0, 0, ctx.Err() case resp := <-respChan: diff --git a/worker/uploader.go b/worker/uploader.go index 28b04033d..e212848ab 100644 --- a/worker/uploader.go +++ b/worker/uploader.go @@ -31,10 +31,14 @@ type ( hm HostManager logger *zap.SugaredLogger - hk types.PublicKey - siamuxAddr string + hk types.PublicKey + siamuxAddr string + signalNewUpload chan struct{} - shutdownCtx context.Context + + // NOTE: this context is being cancelled immediately upon worker + // shutdown, misuse might prevent graceful shutdown of uploads + shutdownCtx context.Context mu sync.Mutex endHeight uint64 @@ -84,20 +88,9 @@ func (u *uploader) Start() { outer: for { // wait for work - select { - case <-u.signalNewUpload: - case <-u.shutdownCtx.Done(): - return - } + <-u.signalNewUpload for { - // check if we are stopped - select { - case <-u.shutdownCtx.Done(): - return - default: - } - // pop the next upload req req := u.pop() if req == nil { @@ -211,7 +204,8 @@ func (u *uploader) execute(req *sectorUploadReq) (time.Duration, error) { // defer the release lock := 
newContractLock(u.shutdownCtx, fcid, lockID, req.contractLockDuration, u.cl, u.logger) defer func() { - ctx, cancel := context.WithTimeout(u.shutdownCtx, 10*time.Second) + // use a sane context to release contract locks + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) lock.Release(ctx) cancel() }() diff --git a/worker/worker.go b/worker/worker.go index b335a5f6c..6a6ffb927 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -52,7 +52,8 @@ const ( ) var ( - ErrShuttingDown = errors.New("worker is shutting down") + ErrShutdownTimedOut = errors.New("worker shutdown timed out") + ErrShuttingDown = errors.New("worker is shutting down") ) // re-export the client @@ -217,7 +218,7 @@ type worker struct { contractLockingDuration time.Duration shutdownCtx context.Context - shutdownCtxCancel context.CancelFunc + shutdownCtxCancel context.CancelCauseFunc logger *zap.SugaredLogger } @@ -243,11 +244,14 @@ func (w *worker) withRevision(ctx context.Context, fetchTimeout time.Duration, f } func (w *worker) registerAlert(a alerts.Alert) { - ctx, cancel := context.WithTimeout(w.shutdownCtx, time.Minute) - if err := w.alerts.RegisterAlert(ctx, a); err != nil { + // apply a sane timeout + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + err := w.alerts.RegisterAlert(ctx, a) + if err != nil { w.logger.Errorf("failed to register alert, err: %v", err) } - cancel() } func (w *worker) rhpScanHandler(jc jape.Context) { @@ -1317,7 +1321,7 @@ func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlush } l = l.Named("worker").Named(id) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancelCause(context.Background()) w := &worker{ alerts: alerts.WithOrigin(b, fmt.Sprintf("worker.%s", id)), allowPrivateIPs: allowPrivateIPs, @@ -1345,7 +1349,7 @@ func New(masterKey [32]byte, id string, b Bus, contractLockingDuration, busFlush // Handler returns an HTTP handler that serves the worker API. func (w *worker) Handler() http.Handler { - return jape.Mux(map[string]jape.Handler{ + return jape.Mux(shutdownMiddleware(w.shutdownCtx, map[string]jape.Handler{ "GET /account/:hostkey": w.accountHandlerGET, "GET /id": w.idHandlerGET, @@ -1373,21 +1377,28 @@ func (w *worker) Handler() http.Handler { "PUT /multipart/*path": w.multipartUploadHandlerPUT, "GET /state": w.stateHandlerGET, - }) + })) } // Shutdown shuts down the worker. 
func (w *worker) Shutdown(ctx context.Context) error { // cancel shutdown context - w.shutdownCtxCancel() + w.shutdownCtxCancel(ErrShuttingDown) // stop uploads and downloads - w.downloadManager.Stop() - w.uploadManager.Stop() + w.downloadManager.Stop(ctx) + w.uploadManager.Stop(ctx) // stop recorders w.contractSpendingRecorder.Stop(ctx) - return nil + + // return error on timeout + select { + case <-ctx.Done(): + return ErrShutdownTimedOut + default: + return nil + } } func (w *worker) scanHost(ctx context.Context, hostKey types.PublicKey, hostIP string) (rhpv2.HostSettings, rhpv3.HostPriceTable, time.Duration, error) { @@ -1482,6 +1493,26 @@ func (w *worker) scanHost(ctx context.Context, hostKey types.PublicKey, hostIP s return settings, pt, duration, err } +func shutdownMiddleware(shutdownCtx context.Context, routes map[string]jape.Handler) map[string]jape.Handler { + for route, handler := range routes { + routes[route] = jape.Adapt(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // if the shutdown context has been cancelled, return a 503 to + // ensure we don't process any new incoming request during + // graceful shutdown + select { + case <-shutdownCtx.Done(): + http.Error(w, ErrShuttingDown.Error(), http.StatusServiceUnavailable) + return + default: + h.ServeHTTP(w, r) + } + }) + })(handler) + } + return routes +} + func discardTxnOnErr(ctx context.Context, bus Bus, l *zap.SugaredLogger, txn types.Transaction, errContext string, err *error) { if *err == nil { return
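
Note on the shutdown budget in cmd/renterd/main.go: the idea of carving a reserve out of the total timeout can be illustrated outside renterd. A minimal sketch, assuming only the standard library; shutdownAll, errShuttingDown and the example services are hypothetical names, not part of the diff:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// errShuttingDown stands in for the sentinel used as the cancellation cause;
// the name is hypothetical and not taken from the diff above.
var errShuttingDown = errors.New("shutting down")

// shutdownAll calls the shutdown functions in reverse order. They share a
// context of total-reserved; once that budget is spent, a one-shot context of
// length reserved is handed to the remaining services so they still get a
// bounded chance to shut down gracefully.
func shutdownAll(total, reserved time.Duration, fns []func(context.Context) error) {
	ctx, cancel := context.WithTimeoutCause(context.Background(), total-reserved, errShuttingDown)
	defer cancel()

	for i := len(fns) - 1; i >= 0; i-- {
		select {
		case <-ctx.Done():
			if reserved == 0 {
				fmt.Printf("skipping shutdown of service %d, budget exhausted\n", i)
				continue
			}
			// switch to the reserved budget exactly once
			ctx, cancel = context.WithTimeoutCause(context.Background(), reserved, errShuttingDown)
			defer cancel()
			reserved = 0
		default:
		}
		if err := fns[i](ctx); err != nil {
			fmt.Printf("service %d failed to shut down: %v\n", i, err)
		}
	}
}

func main() {
	fns := []func(context.Context) error{
		func(ctx context.Context) error { return nil },                              // e.g. the bus
		func(ctx context.Context) error { <-ctx.Done(); return context.Cause(ctx) }, // e.g. a worker that drains until the deadline
	}
	shutdownAll(3*time.Second, time.Second, fns)
}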
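
The Stop(ctx) methods of the download and upload managers share one drain pattern: register every operation in a sync.WaitGroup and wait on it in a goroutine so the caller's context can cut the wait short. A standalone sketch of that pattern, with hypothetical names (drainer, do, stop):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// drainer tracks in-flight operations and supports a bounded, graceful stop.
type drainer struct {
	wg sync.WaitGroup
}

// do registers an operation for the duration of fn, mirroring the
// wg.Add/wg.Done pair at the top of the manager methods.
func (d *drainer) do(fn func()) {
	d.wg.Add(1)
	defer d.wg.Done()
	fn()
}

// stop blocks until all registered operations have finished or ctx expires,
// whichever comes first, and reports whether the drain completed.
func (d *drainer) stop(ctx context.Context) bool {
	done := make(chan struct{})
	go func() {
		d.wg.Wait()
		close(done)
	}()
	select {
	case <-ctx.Done():
		return false
	case <-done:
		return true
	}
}

func main() {
	var d drainer
	started := make(chan struct{})
	go d.do(func() { close(started); time.Sleep(100 * time.Millisecond) })
	<-started // the operation is registered and running

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println("drained:", d.stop(ctx)) // true: the operation finishes within the budget
}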
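
worker.Shutdown now cancels its context via context.WithCancelCause and the managers hand context.Cause(ctx) to the individual uploaders and downloaders. A minimal sketch of how a cancellation cause travels, assuming nothing beyond the standard library:

package main

import (
	"context"
	"errors"
	"fmt"
)

func main() {
	// mirrors the sentinel the worker uses; redeclared here so the sketch is
	// self-contained
	errShuttingDown := errors.New("worker is shutting down")

	ctx, cancel := context.WithCancelCause(context.Background())
	cancel(errShuttingDown)

	// ctx.Err() is always context.Canceled; the cause carries the reason and
	// is what the managers pass on to uploaders and downloaders via Stop
	fmt.Println(ctx.Err())                                      // context canceled
	fmt.Println(context.Cause(ctx))                             // worker is shutting down
	fmt.Println(errors.Is(context.Cause(ctx), errShuttingDown)) // true
}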
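
shutdownMiddleware wraps every jape route so that requests arriving after the shutdown context is cancelled get a 503. The same guard written against plain net/http, as a sketch; the route and port are made up for the example:

package main

import (
	"context"
	"errors"
	"net/http"
	"time"
)

// rejectDuringShutdown serves h until shutdownCtx is cancelled and answers
// 503 Service Unavailable afterwards, so no new work enters the server while
// in-flight requests are drained.
func rejectDuringShutdown(shutdownCtx context.Context, h http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		select {
		case <-shutdownCtx.Done():
			http.Error(w, context.Cause(shutdownCtx).Error(), http.StatusServiceUnavailable)
		default:
			h.ServeHTTP(w, r)
		}
	})
}

func main() {
	shutdownCtx, startShutdown := context.WithCancelCause(context.Background())

	mux := http.NewServeMux()
	mux.HandleFunc("/state", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("ok")) })

	// simulate a shutdown two seconds after startup; requests arriving after
	// that point receive a 503 while earlier ones are served normally
	time.AfterFunc(2*time.Second, func() { startShutdown(errors.New("shutting down")) })

	http.ListenAndServe(":9980", rejectDuringShutdown(shutdownCtx, mux))
}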
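
TestGracefulShutdown keeps one upload and one download in flight with the blockedReader/blockedWriter helpers: signal the first call exactly once, then block until released. The mechanism reduced to a small, hypothetical gate type (a sketch, not the test's helpers):

package main

import (
	"bytes"
	"fmt"
	"io"
	"sync"
)

// gate blocks callers until unblock is called and signals, exactly once, that
// the first caller has arrived.
type gate struct {
	once    sync.Once
	started chan struct{}
	release chan struct{}
}

func newGate() *gate {
	return &gate{started: make(chan struct{}), release: make(chan struct{})}
}

func (g *gate) wait()        { g.once.Do(func() { close(g.started) }); <-g.release }
func (g *gate) waitStarted() { <-g.started }
func (g *gate) unblock()     { close(g.release) }

// blockingWriter wires the gate into an io.Writer so a download can be held
// mid-flight from a test.
type blockingWriter struct {
	g *gate
	w io.Writer
}

func (bw blockingWriter) Write(p []byte) (int, error) {
	bw.g.wait()
	return bw.w.Write(p)
}

func main() {
	g := newGate()
	var buf bytes.Buffer
	bw := blockingWriter{g: g, w: &buf}

	done := make(chan struct{})
	go func() { bw.Write([]byte("hello")); close(done) }()

	g.waitStarted() // the write is in flight and blocked
	g.unblock()     // release it
	<-done
	fmt.Println(buf.String()) // hello
}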