Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly handle contract renewals in the sectors cache #1106

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions bus/bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ func (b *bus) contractsPrunableDataHandlerGET(jc jape.Context) {
// adjust the amount of prunable data with the pending uploads, due to
// how we record contract spending a contract's size might already
// include pending sectors
pending := b.uploadingSectors.pending(fcid)
pending := b.uploadingSectors.Pending(fcid)
if pending > size.Prunable {
size.Prunable = 0
} else {
Expand Down Expand Up @@ -1064,7 +1064,7 @@ func (b *bus) contractSizeHandlerGET(jc jape.Context) {
// adjust the amount of prunable data with the pending uploads, due to how
// we record contract spending a contract's size might already include
// pending sectors
pending := b.uploadingSectors.pending(id)
pending := b.uploadingSectors.Pending(id)
if pending > size.Prunable {
size.Prunable = 0
} else {
Expand Down Expand Up @@ -1141,6 +1141,7 @@ func (b *bus) contractIDRenewedHandlerPOST(jc jape.Context) {
if jc.Check("couldn't store contract", err) == nil {
jc.Encode(r)
}
b.uploadingSectors.HandleRenewal(req.Contract.ID(), req.RenewedFrom)
}

func (b *bus) contractIDRootsHandlerGET(jc jape.Context) {
Expand All @@ -1153,7 +1154,7 @@ func (b *bus) contractIDRootsHandlerGET(jc jape.Context) {
if jc.Check("couldn't fetch contract sectors", err) == nil {
jc.Encode(api.ContractRootsResponse{
Roots: roots,
Uploading: b.uploadingSectors.sectors(id),
Uploading: b.uploadingSectors.Sectors(id),
})
}
}
Expand Down Expand Up @@ -1991,7 +1992,7 @@ func (b *bus) stateHandlerGET(jc jape.Context) {
// uploadTrackHandlerPOST registers an upload id with the uploading sectors
// cache so sectors added later can be associated with that upload. Responds
// with an error if the upload is already being tracked.
func (b *bus) uploadTrackHandlerPOST(jc jape.Context) {
	var id api.UploadID
	if jc.DecodeParam("id", &id) == nil {
		jc.Check("failed to track upload", b.uploadingSectors.StartUpload(id))
	}
}

Expand All @@ -2004,13 +2005,13 @@ func (b *bus) uploadAddSectorHandlerPOST(jc jape.Context) {
if jc.Decode(&req) != nil {
return
}
jc.Check("failed to add sector", b.uploadingSectors.addUploadingSector(id, req.ContractID, req.Root))
jc.Check("failed to add sector", b.uploadingSectors.AddSector(id, req.ContractID, req.Root))
}

// uploadFinishedHandlerDELETE removes a finished upload from the uploading
// sectors cache, dropping its pending sectors from size accounting.
func (b *bus) uploadFinishedHandlerDELETE(jc jape.Context) {
	var id api.UploadID
	if jc.DecodeParam("id", &id) == nil {
		b.uploadingSectors.FinishUpload(id)
	}
}

Expand Down
97 changes: 57 additions & 40 deletions bus/uploadingsectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,95 +19,105 @@ const (

type (
	// uploadingSectorsCache tracks sectors of in-flight uploads per contract
	// so contract size and root listings can account for data that has been
	// uploaded but not yet persisted. renewedTo maps a renewed contract's id
	// to the id of the contract it was renewed to.
	uploadingSectorsCache struct {
		mu        sync.Mutex
		uploads   map[api.UploadID]*ongoingUpload
		renewedTo map[types.FileContractID]types.FileContractID
	}

	// ongoingUpload holds the sector roots uploaded so far for one upload,
	// grouped by the contract they were uploaded to. started is used to
	// expire stale uploads after cacheExpiry.
	ongoingUpload struct {
		mu              sync.Mutex
		started         time.Time
		contractSectors map[types.FileContractID][]types.Hash256
	}
)

// newUploadingSectorsCache returns an empty, ready-to-use cache with both
// internal maps initialized.
func newUploadingSectorsCache() *uploadingSectorsCache {
	return &uploadingSectorsCache{
		uploads:   make(map[api.UploadID]*ongoingUpload),
		renewedTo: make(map[types.FileContractID]types.FileContractID),
	}
}

// addSector records root as an uploading sector for contract fcid.
func (ou *ongoingUpload) addSector(fcid types.FileContractID, root types.Hash256) {
	ou.mu.Lock()
	defer ou.mu.Unlock()
	ou.contractSectors[fcid] = append(ou.contractSectors[fcid], root)
}

// sectors returns a copy of the roots uploaded to contract fcid, but only
// while the upload is younger than cacheExpiry; an expired upload yields no
// roots.
func (ou *ongoingUpload) sectors(fcid types.FileContractID) (roots []types.Hash256) {
	ou.mu.Lock()
	defer ou.mu.Unlock()
	if sectors, exists := ou.contractSectors[fcid]; exists && time.Since(ou.started) < cacheExpiry {
		roots = append(roots, sectors...)
	}
	return
}

func (usc *uploadingSectorsCache) addUploadingSector(uID api.UploadID, fcid types.FileContractID, root types.Hash256) error {
// fetch ongoing upload
func (usc *uploadingSectorsCache) AddSector(uID api.UploadID, fcid types.FileContractID, root types.Hash256) error {
usc.mu.Lock()
ongoing, exists := usc.uploads[uID]
usc.mu.Unlock()
defer usc.mu.Unlock()

// add sector if upload exists
if exists {
ongoing.addSector(fcid, root)
return nil
ongoing, ok := usc.uploads[uID]
if !ok {
return fmt.Errorf("%w; id '%v'", api.ErrUnknownUpload, uID)
}

return fmt.Errorf("%w; id '%v'", api.ErrUnknownUpload, uID)
fcid = usc.latestFCID(fcid)
ongoing.addSector(fcid, root)
return nil
}

func (usc *uploadingSectorsCache) pending(fcid types.FileContractID) (size uint64) {
func (usc *uploadingSectorsCache) FinishUpload(uID api.UploadID) {
usc.mu.Lock()
var uploads []*ongoingUpload
for _, ongoing := range usc.uploads {
uploads = append(uploads, ongoing)
defer usc.mu.Unlock()
delete(usc.uploads, uID)

// prune expired uploads
for uID, ongoing := range usc.uploads {
if time.Since(ongoing.started) > cacheExpiry {
delete(usc.uploads, uID)
}
}
usc.mu.Unlock()

for _, ongoing := range uploads {
size += uint64(len(ongoing.sectors(fcid))) * rhp.SectorSize
// prune renewed to map
for old, new := range usc.renewedTo {
if _, exists := usc.renewedTo[new]; exists {
delete(usc.renewedTo, old)
}
}
return
}

func (usc *uploadingSectorsCache) sectors(fcid types.FileContractID) (roots []types.Hash256) {
func (usc *uploadingSectorsCache) HandleRenewal(fcid, renewedFrom types.FileContractID) {
usc.mu.Lock()
var uploads []*ongoingUpload
for _, ongoing := range usc.uploads {
uploads = append(uploads, ongoing)
defer usc.mu.Unlock()

for _, upload := range usc.uploads {
if _, exists := upload.contractSectors[renewedFrom]; exists {
upload.contractSectors[fcid] = upload.contractSectors[renewedFrom]
upload.contractSectors[renewedFrom] = nil
}
}
usc.mu.Unlock()
usc.renewedTo[renewedFrom] = fcid
}

for _, ongoing := range uploads {
roots = append(roots, ongoing.sectors(fcid)...)
func (usc *uploadingSectorsCache) Pending(fcid types.FileContractID) (size uint64) {
usc.mu.Lock()
defer usc.mu.Unlock()

fcid = usc.latestFCID(fcid)
for _, ongoing := range usc.uploads {
size += uint64(len(ongoing.sectors(fcid))) * rhp.SectorSize
}
return
}

func (usc *uploadingSectorsCache) finishUpload(uID api.UploadID) {
func (usc *uploadingSectorsCache) Sectors(fcid types.FileContractID) (roots []types.Hash256) {
usc.mu.Lock()
defer usc.mu.Unlock()
delete(usc.uploads, uID)

// prune expired uploads
for uID, ongoing := range usc.uploads {
if time.Since(ongoing.started) > cacheExpiry {
delete(usc.uploads, uID)
}
fcid = usc.latestFCID(fcid)
for _, ongoing := range usc.uploads {
roots = append(roots, ongoing.sectors(fcid)...)
}
return
}

func (usc *uploadingSectorsCache) trackUpload(uID api.UploadID) error {
func (usc *uploadingSectorsCache) StartUpload(uID api.UploadID) error {
usc.mu.Lock()
defer usc.mu.Unlock()

Expand All @@ -122,3 +132,10 @@ func (usc *uploadingSectorsCache) trackUpload(uID api.UploadID) error {
}
return nil
}

// latestFCID follows a single hop in the renewedTo map, returning the id the
// given contract was renewed to, or fcid itself if it was never renewed.
// Caller must hold usc.mu.
func (usc *uploadingSectorsCache) latestFCID(fcid types.FileContractID) types.FileContractID {
	if latest, ok := usc.renewedTo[fcid]; ok {
		return latest
	}
	return fcid
}
82 changes: 66 additions & 16 deletions bus/uploadingsectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"testing"

rhpv2 "go.sia.tech/core/rhp/v2"
"go.sia.tech/core/types"
"go.sia.tech/renterd/api"
"lukechampine.com/frand"
Expand All @@ -15,20 +16,24 @@ func TestUploadingSectorsCache(t *testing.T) {
uID1 := newTestUploadID()
uID2 := newTestUploadID()

c.trackUpload(uID1)
c.trackUpload(uID2)
fcid1 := types.FileContractID{1}
fcid2 := types.FileContractID{2}
fcid3 := types.FileContractID{3}

_ = c.addUploadingSector(uID1, types.FileContractID{1}, types.Hash256{1})
_ = c.addUploadingSector(uID1, types.FileContractID{2}, types.Hash256{2})
_ = c.addUploadingSector(uID2, types.FileContractID{2}, types.Hash256{3})
c.StartUpload(uID1)
c.StartUpload(uID2)

if roots1 := c.sectors(types.FileContractID{1}); len(roots1) != 1 || roots1[0] != (types.Hash256{1}) {
_ = c.AddSector(uID1, fcid1, types.Hash256{1})
_ = c.AddSector(uID1, fcid2, types.Hash256{2})
_ = c.AddSector(uID2, fcid2, types.Hash256{3})

if roots1 := c.Sectors(fcid1); len(roots1) != 1 || roots1[0] != (types.Hash256{1}) {
t.Fatal("unexpected cached sectors")
}
if roots2 := c.sectors(types.FileContractID{2}); len(roots2) != 2 {
if roots2 := c.Sectors(fcid2); len(roots2) != 2 {
t.Fatal("unexpected cached sectors", roots2)
}
if roots3 := c.sectors(types.FileContractID{3}); len(roots3) != 0 {
if roots3 := c.Sectors(fcid3); len(roots3) != 0 {
t.Fatal("unexpected cached sectors")
}

Expand All @@ -39,28 +44,73 @@ func TestUploadingSectorsCache(t *testing.T) {
t.Fatal("unexpected")
}

c.finishUpload(uID1)
if roots1 := c.sectors(types.FileContractID{1}); len(roots1) != 0 {
c.FinishUpload(uID1)
if roots1 := c.Sectors(fcid1); len(roots1) != 0 {
t.Fatal("unexpected cached sectors")
}
if roots2 := c.sectors(types.FileContractID{2}); len(roots2) != 1 || roots2[0] != (types.Hash256{3}) {
if roots2 := c.Sectors(fcid2); len(roots2) != 1 || roots2[0] != (types.Hash256{3}) {
t.Fatal("unexpected cached sectors")
}

c.finishUpload(uID2)
if roots2 := c.sectors(types.FileContractID{1}); len(roots2) != 0 {
c.FinishUpload(uID2)
if roots2 := c.Sectors(fcid1); len(roots2) != 0 {
t.Fatal("unexpected cached sectors")
}

if err := c.addUploadingSector(uID1, types.FileContractID{1}, types.Hash256{1}); !errors.Is(err, api.ErrUnknownUpload) {
if err := c.AddSector(uID1, fcid1, types.Hash256{1}); !errors.Is(err, api.ErrUnknownUpload) {
t.Fatal("unexpected error", err)
}
if err := c.trackUpload(uID1); err != nil {
if err := c.StartUpload(uID1); err != nil {
t.Fatal("unexpected error", err)
}
if err := c.trackUpload(uID1); !errors.Is(err, api.ErrUploadAlreadyExists) {
if err := c.StartUpload(uID1); !errors.Is(err, api.ErrUploadAlreadyExists) {
t.Fatal("unexpected error", err)
}

// reset cache
c = newUploadingSectorsCache()

// track upload that uploads across two contracts
c.StartUpload(uID1)
c.AddSector(uID1, fcid1, types.Hash256{1})
c.AddSector(uID1, fcid1, types.Hash256{2})
c.HandleRenewal(fcid2, fcid1)
c.AddSector(uID1, fcid2, types.Hash256{3})
c.AddSector(uID1, fcid2, types.Hash256{4})

// assert pending sizes for both contracts should be 4 sectors
p1 := c.Pending(fcid1)
p2 := c.Pending(fcid2)
if p1 != p2 || p1 != 4*rhpv2.SectorSize {
t.Fatal("unexpected pending size", p1/rhpv2.SectorSize, p2/rhpv2.SectorSize)
}

// assert sectors for both contracts contain 4 sectors
s1 := c.Sectors(fcid1)
s2 := c.Sectors(fcid2)
if len(s1) != 4 || len(s2) != 4 {
t.Fatal("unexpected sectors", len(s1), len(s2))
}

// finish upload
c.FinishUpload(uID1)
s1 = c.Sectors(fcid1)
s2 = c.Sectors(fcid2)
if len(s1) != 0 || len(s2) != 0 {
t.Fatal("unexpected sectors", len(s1), len(s2))
}

// renew the contract
c.HandleRenewal(fcid3, fcid2)

// trigger pruning
c.StartUpload(uID2)
c.FinishUpload(uID2)

// assert renewedTo gets pruned
if len(c.renewedTo) != 1 {
t.Fatal("unexpected", len(c.renewedTo))
}
}

func newTestUploadID() api.UploadID {
Expand Down
Loading