diff --git a/host/contracts/contracts.go b/host/contracts/contracts.go index 209baca8..6c5cce9a 100644 --- a/host/contracts/contracts.go +++ b/host/contracts/contracts.go @@ -148,10 +148,10 @@ type ( once sync.Once done func() // done is called when the updater is closed. - sectors uint64 contractID types.FileContractID sectorActions []SectorChange sectorRoots []types.Hash256 + oldRoots []types.Hash256 } ) @@ -331,7 +331,7 @@ func (cu *ContractUpdater) Commit(revision SignedRevision, usage Usage) error { start := time.Now() // revise the contract - err := cu.store.ReviseContract(revision, usage, cu.sectorActions) + err := cu.store.ReviseContract(revision, cu.oldRoots, usage, cu.sectorActions) if err == nil { // clear the committed sector actions cu.sectorActions = cu.sectorActions[:0] diff --git a/host/contracts/manager.go b/host/contracts/manager.go index 5a8ef635..8b79d747 100644 --- a/host/contracts/manager.go +++ b/host/contracts/manager.go @@ -466,8 +466,8 @@ func (cm *ContractManager) ReviseContract(contractID types.FileContractID) (*Con rootsCache: cm.rootsCache, contractID: contractID, - sectors: uint64(len(roots)), - sectorRoots: roots, + sectorRoots: roots, // roots is already a deep copy + oldRoots: append([]types.Hash256(nil), roots...), done: done, // decrements the threadgroup counter after the updater is closed }, nil diff --git a/host/contracts/manager_test.go b/host/contracts/manager_test.go index 890006b7..4abc3ef3 100644 --- a/host/contracts/manager_test.go +++ b/host/contracts/manager_test.go @@ -1013,7 +1013,7 @@ func TestSectorRoots(t *testing.T) { defer release() // use the database method directly to avoid the sector cache - err = db.ReviseContract(rev, contracts.Usage{}, []contracts.SectorChange{ + err = db.ReviseContract(rev, roots, contracts.Usage{}, []contracts.SectorChange{ {Action: contracts.SectorActionAppend, Root: root}, }) if err != nil { diff --git a/host/contracts/persist.go b/host/contracts/persist.go index 13c2429a..e8bda305 100644 --- a/host/contracts/persist.go +++ b/host/contracts/persist.go @@ -48,7 +48,7 @@ type ( ContractAction(height uint64, contractFn func(types.FileContractID, uint64, string)) error // ReviseContract atomically updates a contract and its associated // sector roots. - ReviseContract(revision SignedRevision, usage Usage, sectorChanges []SectorChange) error + ReviseContract(revision SignedRevision, oldRoots []types.Hash256, usage Usage, sectorChanges []SectorChange) error // UpdateContractState atomically updates the contract manager's state. UpdateContractState(modules.ConsensusChangeID, uint64, func(UpdateStateTransaction) error) error // ExpireContractSectors removes sector roots for any contracts that are diff --git a/persist/sqlite/contracts.go b/persist/sqlite/contracts.go index 0cdbdeb5..802088ac 100644 --- a/persist/sqlite/contracts.go +++ b/persist/sqlite/contracts.go @@ -259,17 +259,8 @@ func (s *Store) RenewContract(renewal contracts.SignedRevision, clearing contrac }) } -func contractSectorRoots(tx txn, contractID int64) (uint64, error) { - var index uint64 - err := tx.QueryRow(`SELECT COUNT(*) FROM contract_sector_roots WHERE contract_id=$1`, contractID).Scan(&index) - if errors.Is(err, sql.ErrNoRows) { - return 0, nil - } - return index, err -} - // ReviseContract atomically updates a contract's revision and sectors -func (s *Store) ReviseContract(revision contracts.SignedRevision, usage contracts.Usage, sectorChanges []contracts.SectorChange) error { +func (s *Store) ReviseContract(revision contracts.SignedRevision, oldRoots []types.Hash256, usage contracts.Usage, sectorChanges []contracts.SectorChange) error { return s.transaction(func(tx txn) error { // revise the contract contractID, err := reviseContract(tx, revision) @@ -286,11 +277,7 @@ func (s *Store) ReviseContract(revision contracts.SignedRevision, usage contract } // update the sector roots - sectors, err := contractSectorRoots(tx, contractID) - if err != nil { - return fmt.Errorf("failed to get sector index: %w", err) - } - + sectors := uint64(len(oldRoots)) for _, change := range sectorChanges { switch change.Action { case contracts.SectorActionAppend: diff --git a/persist/sqlite/contracts_test.go b/persist/sqlite/contracts_test.go index dee660ab..dcc5ceaf 100644 --- a/persist/sqlite/contracts_test.go +++ b/persist/sqlite/contracts_test.go @@ -37,7 +37,7 @@ func rootsEqual(a, b []types.Hash256) error { return nil } -func runRevision(db *Store, revision contracts.SignedRevision, changes []contracts.SectorChange) error { +func runRevision(db *Store, revision contracts.SignedRevision, roots []types.Hash256, changes []contracts.SectorChange) error { for _, change := range changes { switch change.Action { // store a sector in the database for the append or update actions @@ -52,7 +52,7 @@ func runRevision(db *Store, revision contracts.SignedRevision, changes []contrac } } - return db.ReviseContract(revision, contracts.Usage{}, changes) + return db.ReviseContract(revision, roots, contracts.Usage{}, changes) } func TestReviseContract(t *testing.T) { @@ -261,6 +261,7 @@ func TestReviseContract(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + oldRoots := append([]types.Hash256(nil), roots...) // update the expected roots for i, change := range test.changes { switch change.Action { @@ -301,7 +302,7 @@ func TestReviseContract(t *testing.T) { } } - if err := runRevision(db, contract, test.changes); err != nil { + if err := runRevision(db, contract, oldRoots, test.changes); err != nil { if test.errors { t.Log("received error:", err) return diff --git a/persist/sqlite/volumes_test.go b/persist/sqlite/volumes_test.go index 005231e5..382ef9e8 100644 --- a/persist/sqlite/volumes_test.go +++ b/persist/sqlite/volumes_test.go @@ -652,7 +652,7 @@ func TestPrune(t *testing.T) { Action: contracts.SectorActionAppend, }) } - err = db.ReviseContract(c, contracts.Usage{}, changes) + err = db.ReviseContract(c, []types.Hash256{}, contracts.Usage{}, changes) if err != nil { t.Fatal(err) } @@ -776,11 +776,12 @@ func TestPrune(t *testing.T) { changes = []contracts.SectorChange{ {Action: contracts.SectorActionTrim, A: uint64(len(contractSectors) / 2)}, } - if err := db.ReviseContract(c, contracts.Usage{}, changes); err != nil { + if err := db.ReviseContract(c, contractSectors, contracts.Usage{}, changes); err != nil { t.Fatal(err) } + contractSectors = contractSectors[:len(contractSectors)/2] - if err := checkConsistency(contractSectors[:len(contractSectors)/2], nil, nil, roots[50:]); err != nil { + if err := checkConsistency(contractSectors, nil, nil, roots[50:]); err != nil { t.Fatal(err) }