Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Worker Cleanup #815

Merged
merged 5 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion worker/contract_lock.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func newContractLock(fcid types.FileContractID, lockID uint64, d time.Duration,
return cl
}

func (w *worker) acquireContractLock(ctx context.Context, fcid types.FileContractID, priority int) (_ revisionUnlocker, err error) {
func (w *worker) acquireContractLock(ctx context.Context, fcid types.FileContractID, priority int) (_ *contractLock, err error) {
lockID, err := w.bus.AcquireContract(ctx, fcid, priority, w.contractLockingDuration)
if err != nil {
return nil, err
Expand Down
106 changes: 61 additions & 45 deletions worker/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"math"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -40,10 +41,9 @@ type (
id [8]byte

downloadManager struct {
hm HostManager
mm MemoryManager
hp hostProvider
pss partialSlabStore
slm sectorLostMarker
os ObjectStore
logger *zap.SugaredLogger

maxOverdrive uint64
Expand All @@ -52,26 +52,26 @@ type (
statsOverdrivePct *stats.DataPoints
statsSlabDownloadSpeedBytesPerMS *stats.DataPoints

stopChan chan struct{}
shutdownCtx context.Context

mu sync.Mutex
downloaders map[types.PublicKey]*downloader
lastRecompute time.Time
}

downloader struct {
host hostV3
host Host

statsDownloadSpeedBytesPerMS *stats.DataPoints // keep track of this separately for stats (no decay is applied)
statsSectorDownloadEstimateInMS *stats.DataPoints

signalWorkChan chan struct{}
stopChan chan struct{}
shutdownCtx context.Context

mu sync.Mutex
consecutiveFailures uint64
queue []*sectorDownloadReq
numDownloads uint64
queue []*sectorDownloadReq
}

downloaderStats struct {
Expand All @@ -80,13 +80,8 @@ type (
numDownloads uint64
}

sectorLostMarker interface {
DeleteHostSector(ctx context.Context, hk types.PublicKey, root types.Hash256) error
}

slabDownload struct {
mgr *downloadManager
slm sectorLostMarker

minShards int
offset uint32
Expand Down Expand Up @@ -165,15 +160,14 @@ func (w *worker) initDownloadManager(maxMemory, maxOverdrive uint64, overdriveTi
}

mm := newMemoryManager(logger, maxMemory)
w.downloadManager = newDownloadManager(w, w, mm, w.bus, maxOverdrive, overdriveTimeout, logger)
w.downloadManager = newDownloadManager(w.shutdownCtx, w, mm, w.bus, maxOverdrive, overdriveTimeout, logger)
}

func newDownloadManager(hp hostProvider, pss partialSlabStore, mm MemoryManager, slm sectorLostMarker, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.SugaredLogger) *downloadManager {
func newDownloadManager(ctx context.Context, hm HostManager, mm MemoryManager, os ObjectStore, maxOverdrive uint64, overdriveTimeout time.Duration, logger *zap.SugaredLogger) *downloadManager {
return &downloadManager{
hp: hp,
hm: hm,
mm: mm,
pss: pss,
slm: slm,
os: os,
logger: logger,

maxOverdrive: maxOverdrive,
Expand All @@ -182,21 +176,21 @@ func newDownloadManager(hp hostProvider, pss partialSlabStore, mm MemoryManager,
statsOverdrivePct: stats.NoDecay(),
statsSlabDownloadSpeedBytesPerMS: stats.NoDecay(),

stopChan: make(chan struct{}),
shutdownCtx: ctx,

downloaders: make(map[types.PublicKey]*downloader),
}
}

func newDownloader(host hostV3) *downloader {
func newDownloader(ctx context.Context, host Host) *downloader {
return &downloader{
host: host,

statsSectorDownloadEstimateInMS: stats.Default(),
statsDownloadSpeedBytesPerMS: stats.NoDecay(),

signalWorkChan: make(chan struct{}, 1),
stopChan: make(chan struct{}),
shutdownCtx: ctx,

queue: make([]*sectorDownloadReq, 0),
}
Expand Down Expand Up @@ -231,7 +225,7 @@ func (mgr *downloadManager) DownloadObject(ctx context.Context, w io.Writer, o o
if !slabs[i].PartialSlab {
continue
}
data, slab, err := mgr.pss.PartialSlab(ctx, slabs[i].SlabSlice.Key, slabs[i].SlabSlice.Offset, slabs[i].SlabSlice.Length)
data, slab, err := mgr.fetchPartialSlab(ctx, slabs[i].SlabSlice.Key, slabs[i].SlabSlice.Offset, slabs[i].SlabSlice.Length)
if err != nil {
return fmt.Errorf("failed to fetch partial slab data: %w", err)
}
Expand Down Expand Up @@ -288,7 +282,7 @@ func (mgr *downloadManager) DownloadObject(ctx context.Context, w io.Writer, o o
select {
case <-ctx.Done():
return
case <-mgr.stopChan:
case <-mgr.shutdownCtx.Done():
return
default:
}
Expand Down Expand Up @@ -345,7 +339,7 @@ outer:
for {
var resp *slabDownloadResponse
select {
case <-mgr.stopChan:
case <-mgr.shutdownCtx.Done():
return errDownloadManagerStopped
case <-ctx.Done():
return errors.New("download timed out")
Expand Down Expand Up @@ -483,9 +477,8 @@ func (mgr *downloadManager) Stats() downloadManagerStats {
func (mgr *downloadManager) Stop() {
mgr.mu.Lock()
defer mgr.mu.Unlock()
close(mgr.stopChan)
for _, d := range mgr.downloaders {
close(d.stopChan)
d.Stop()
}
}

Expand All @@ -509,6 +502,24 @@ func (mgr *downloadManager) numDownloaders() int {
return len(mgr.downloaders)
}

// fetchPartialSlab fetches the data of a partial slab from the bus. It will
// fall back to ask the bus for the slab metadata in case the slab wasn't found
// in the partial slab buffer.
// fetchPartialSlab fetches the data of a partial slab from the bus. It will
// fall back to ask the bus for the slab metadata in case the slab wasn't found
// in the partial slab buffer.
//
// Returns either the raw partial-slab data, or (when the slab was already
// uploaded and evicted from the buffer) the slab metadata — never both.
func (mgr *downloadManager) fetchPartialSlab(ctx context.Context, key object.EncryptionKey, offset, length uint32) ([]byte, *object.Slab, error) {
	data, err := mgr.os.FetchPartialSlab(ctx, key, offset, length)
	// NOTE: the error crossed an API boundary so we match on the message
	// rather than using errors.Is against api.ErrObjectNotFound directly.
	if err != nil && strings.Contains(err.Error(), api.ErrObjectNotFound.Error()) {
		// Check if slab was already uploaded.
		slab, err := mgr.os.Slab(ctx, key)
		if err != nil {
			// wrap with %w so callers can still unwrap the underlying error
			return nil, nil, fmt.Errorf("failed to fetch uploaded partial slab: %w", err)
		}
		return nil, &slab, nil
	} else if err != nil {
		return nil, nil, err
	}
	return data, nil, nil
}

func (mgr *downloadManager) refreshDownloaders(contracts []api.ContractMetadata) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
Expand All @@ -523,7 +534,7 @@ func (mgr *downloadManager) refreshDownloaders(contracts []api.ContractMetadata)
for hk := range mgr.downloaders {
_, wanted := want[hk]
if !wanted {
close(mgr.downloaders[hk].stopChan)
mgr.downloaders[hk].Stop()
delete(mgr.downloaders, hk)
continue
}
Expand All @@ -534,10 +545,10 @@ func (mgr *downloadManager) refreshDownloaders(contracts []api.ContractMetadata)
// update downloaders
for _, c := range want {
// create a host
host := mgr.hp.newHostV3(c.ID, c.HostKey, c.SiamuxAddr)
downloader := newDownloader(host)
host := mgr.hm.Host(c.HostKey, c.ID, c.SiamuxAddr)
downloader := newDownloader(mgr.shutdownCtx, host)
mgr.downloaders[c.HostKey] = downloader
go downloader.processQueue(mgr.hp)
go downloader.processQueue(mgr.hm)
}
}

Expand All @@ -554,7 +565,6 @@ func (mgr *downloadManager) newSlabDownload(ctx context.Context, dID id, slice o
// create slab download
return &slabDownload{
mgr: mgr,
slm: mgr.slm,

minShards: int(slice.MinShards),
offset: offset,
Expand Down Expand Up @@ -583,6 +593,18 @@ func (mgr *downloadManager) downloadSlab(ctx context.Context, dID id, slice obje
return slab.download(ctx)
}

// Stop drains the downloader's queue, failing every request that has not
// completed yet so blocked callers are unblocked during shutdown.
func (d *downloader) Stop() {
	for req := d.pop(); req != nil; req = d.pop() {
		if req.done() {
			continue
		}
		req.fail(errors.New("downloader stopped"))
	}
}

func (d *downloader) stats() downloaderStats {
d.mu.Lock()
defer d.mu.Unlock()
Expand All @@ -593,15 +615,6 @@ func (d *downloader) stats() downloaderStats {
}
}

// isStopped reports whether the downloader's stop channel has been closed,
// without blocking.
func (d *downloader) isStopped() bool {
	select {
	case <-d.stopChan:
		return true
	default:
		return false
	}
}

func (d *downloader) fillBatch() (batch []*sectorDownloadReq) {
for len(batch) < maxConcurrentSectorsPerHost {
if req := d.pop(); req == nil {
Expand Down Expand Up @@ -639,8 +652,11 @@ func (d *downloader) processBatch(batch []*sectorDownloadReq) chan struct{} {
reqsChan := make(chan *sectorDownloadReq)
workerFn := func() {
for req := range reqsChan {
if d.isStopped() {
// check if we need to abort
select {
case <-d.shutdownCtx.Done():
break
default:
}

// update state
Expand Down Expand Up @@ -702,13 +718,13 @@ func (d *downloader) processBatch(batch []*sectorDownloadReq) chan struct{} {
return doneChan
}

func (d *downloader) processQueue(hp hostProvider) {
func (d *downloader) processQueue(hp HostManager) {
outer:
for {
// wait for work
select {
case <-d.signalWorkChan:
case <-d.stopChan:
case <-d.shutdownCtx.Done():
return
}

Expand All @@ -723,7 +739,7 @@ outer:
doneChan := d.processBatch(batch)
for {
select {
case <-d.stopChan:
case <-d.shutdownCtx.Done():
return
case <-doneChan:
continue outer
Expand Down Expand Up @@ -1015,7 +1031,7 @@ func (s *slabDownload) download(ctx context.Context) ([][]byte, bool, error) {
loop:
for s.inflight() > 0 && !done {
select {
case <-s.mgr.stopChan:
case <-s.mgr.shutdownCtx.Done():
return nil, false, errors.New("download stopped")
case <-ctx.Done():
return nil, false, ctx.Err()
Expand Down Expand Up @@ -1049,7 +1065,7 @@ loop:

// handle lost sectors
if isSectorNotFound(resp.err) {
if err := s.slm.DeleteHostSector(ctx, resp.req.hk, resp.req.root); err != nil {
if err := s.mgr.os.DeleteHostSector(ctx, resp.req.hk, resp.req.root); err != nil {
s.mgr.logger.Errorw("failed to mark sector as lost", "hk", resp.req.hk, "root", resp.req.root, zap.Error(err))
} else {
s.mgr.logger.Infow("successfully marked sector as lost", "hk", resp.req.hk, "root", resp.req.root)
Expand Down
Loading