From cb287a1c6bc6c35d080af766133f7c25b12c6e5e Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Sun, 4 Feb 2024 15:51:54 -0800 Subject: [PATCH 1/3] sqlite: verify sector change consistency --- persist/sqlite/contracts.go | 197 ++++++++++++++++++++----------- persist/sqlite/contracts_test.go | 20 +--- persist/sqlite/sectors.go | 37 +++--- persist/sqlite/store.go | 3 +- 4 files changed, 148 insertions(+), 109 deletions(-) diff --git a/persist/sqlite/contracts.go b/persist/sqlite/contracts.go index acccfffd..1ccd6820 100644 --- a/persist/sqlite/contracts.go +++ b/persist/sqlite/contracts.go @@ -35,6 +35,7 @@ type ( contractSectorRootRef struct { dbID int64 sectorID int64 + root types.Hash256 } ) @@ -260,7 +261,7 @@ func (s *Store) RenewContract(renewal contracts.SignedRevision, clearing contrac } // ReviseContract atomically updates a contract's revision and sectors -func (s *Store) ReviseContract(revision contracts.SignedRevision, oldRoots []types.Hash256, usage contracts.Usage, sectorChanges []contracts.SectorChange) error { +func (s *Store) ReviseContract(revision contracts.SignedRevision, roots []types.Hash256, usage contracts.Usage, sectorChanges []contracts.SectorChange) error { return s.transaction(func(tx txn) error { // revise the contract contractID, err := reviseContract(tx, revision) @@ -277,7 +278,8 @@ func (s *Store) ReviseContract(revision contracts.SignedRevision, oldRoots []typ } // update the sector roots - sectors := uint64(len(oldRoots)) + sectors := uint64(len(roots)) + roots := append([]types.Hash256(nil), roots...) for _, change := range sectorChanges { switch change.Action { case contracts.SectorActionAppend: @@ -285,23 +287,48 @@ func (s *Store) ReviseContract(revision contracts.SignedRevision, oldRoots []typ return fmt.Errorf("failed to append sector: %w", err) } sectors++ + roots = append(roots, change.Root) case contracts.SectorActionTrim: if sectors < change.A { return fmt.Errorf("cannot trim %v sectors from contract with %v sectors", change.A, sectors) } - if err := trimSectors(tx, contractID, change.A, s.log); err != nil { + trimmed, err := trimSectors(tx, contractID, change.A, s.log) + if err != nil { return fmt.Errorf("failed to trim sectors: %w", err) } sectors -= change.A + removed := roots[len(roots)-int(change.A):] + for _, root := range removed { + if !trimmed[root] { + return fmt.Errorf("inconsistent sector trim: expected %s to be trimmed", root) + } + } + roots = roots[:len(roots)-int(change.A)] case contracts.SectorActionUpdate: - if err := updateSector(tx, contractID, change.Root, change.A); err != nil { + oldRoot, err := updateSector(tx, contractID, change.Root, change.A) + if err != nil { return fmt.Errorf("failed to update sector: %w", err) + } else if roots[change.A] != oldRoot { + return fmt.Errorf("inconsistent sector update (%d): expected old sector %s, got %s", change.A, roots[change.A], oldRoot) } + roots[change.A] = change.Root case contracts.SectorActionSwap: - if err := swapSectors(tx, contractID, change.A, change.B); err != nil { + if change.A > change.B { + change.A, change.B = change.B, change.A + } + + swapped, err := swapSectors(tx, contractID, change.A, change.B) + if err != nil { return fmt.Errorf("failed to swap sectors: %w", err) } + oldA, oldB := roots[change.A], roots[change.B] + if swapped[0] != oldA { + return fmt.Errorf("inconsistent sector swap: expected %s, got %s", oldA, swapped[0]) + } else if swapped[1] != oldB { + return fmt.Errorf("inconsistent sector swap: expected %s, got %s", oldB, swapped[1]) + } + roots[change.A], roots[change.B] = roots[change.B], roots[change.A] } } return nil @@ -545,83 +572,92 @@ func appendSector(tx txn, contractID int64, root types.Hash256, index uint64) er return nil } -func updateSector(tx txn, contractID int64, root types.Hash256, index uint64) error { - var oldSectorID int64 - if err := tx.QueryRow(`SELECT sector_id FROM contract_sector_roots WHERE contract_id=$1 AND root_index=$2`, contractID, index).Scan(&oldSectorID); err != nil { - return fmt.Errorf("failed to get old sector id: %w", err) +func updateSector(tx txn, contractID int64, root types.Hash256, index uint64) (types.Hash256, error) { + row := tx.QueryRow(`SELECT csr.id, csr.sector_id, ss.sector_root +FROM contract_sector_roots csr +INNER JOIN stored_sectors ss ON (csr.sector_id = ss.id) +WHERE contract_id=$1 AND root_index=$2`, contractID, index) + ref, err := scanContractSectorRootRef(row) + if err != nil { + return types.Hash256{}, fmt.Errorf("failed to get old sector id: %w", err) } - const query = `WITH sector AS ( - SELECT id FROM stored_sectors WHERE sector_root=$1 -) -UPDATE contract_sector_roots -SET sector_id=sector.id -FROM sector -WHERE contract_id=$2 AND root_index=$3 -RETURNING sector_id;` + // update the sector ID var newSectorID int64 - err := tx.QueryRow(query, sqlHash256(root), contractID, index).Scan(&newSectorID) + err = tx.QueryRow(`WITH sector AS ( + SELECT id FROM stored_sectors WHERE sector_root=$1 + ) + UPDATE contract_sector_roots + SET sector_id=sector.id + FROM sector + WHERE contract_sector_roots.id=$2 + RETURNING sector_id;`, sqlHash256(root), ref.dbID).Scan(&newSectorID) if err != nil { - return err - } else if err := pruneSectorRef(tx, oldSectorID); err != nil { - return fmt.Errorf("failed to prune sector ref: %w", err) + return types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err) } - return nil + // prune the old sector ID + if _, err := pruneSectorRef(tx, ref.sectorID); err != nil { + return types.Hash256{}, fmt.Errorf("failed to prune old sector: %w", err) + } + return ref.root, nil } -func swapSectors(tx txn, contractID int64, i, j uint64) error { +func swapSectors(tx txn, contractID int64, i, j uint64) ([2]types.Hash256, error) { if i == j { - return nil + return [2]types.Hash256{}, nil } var records []contractSectorRootRef - rows, err := tx.Query(`SELECT id, sector_id FROM contract_sector_roots WHERE contract_id=$1 AND root_index IN ($2, $3);`, contractID, i, j) + rows, err := tx.Query(`SELECT csr.id, csr.sector_id, ss.sector_root +FROM contract_sector_roots csr +INNER JOIN stored_sectors ss ON (ss.id = csr.sector_id) +WHERE contract_id=$1 AND root_index IN ($2, $3) +ORDER BY root_index ASC;`, contractID, i, j) if err != nil { - return fmt.Errorf("failed to query sector IDs: %w", err) + return [2]types.Hash256{}, fmt.Errorf("failed to query sector IDs: %w", err) } defer rows.Close() for rows.Next() { - var record contractSectorRootRef - if err := rows.Scan(&record.dbID, &record.sectorID); err != nil { - return fmt.Errorf("failed to scan sector ID: %w", err) + ref, err := scanContractSectorRootRef(rows) + if err != nil { + return [2]types.Hash256{}, fmt.Errorf("failed to scan sector ref: %w", err) } - records = append(records, record) + records = append(records, ref) } if len(records) != 2 { - return errors.New("failed to find both sectors") + return [2]types.Hash256{}, errors.New("failed to find both sectors") } - stmt, err := tx.Prepare(`UPDATE contract_sector_roots SET sector_id=$1 WHERE id=$2`) + stmt, err := tx.Prepare(`UPDATE contract_sector_roots SET sector_id=$1 WHERE id=$2 RETURNING sector_id;`) if err != nil { - return fmt.Errorf("failed to prepare update statement: %w", err) + return [2]types.Hash256{}, fmt.Errorf("failed to prepare update statement: %w", err) } defer stmt.Close() - res, err := stmt.Exec(records[1].sectorID, records[0].dbID) + var newSectorID int64 + err = stmt.QueryRow(records[1].sectorID, records[0].dbID).Scan(&newSectorID) if err != nil { - return fmt.Errorf("failed to update sector ID: %w", err) - } else if rows, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) - } else if rows != 1 { - return fmt.Errorf("expected 1 row affected, got %v", rows) + return [2]types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err) + } else if newSectorID != records[1].sectorID { + return [2]types.Hash256{}, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) } - res, err = stmt.Exec(records[0].sectorID, records[1].dbID) + err = stmt.QueryRow(records[0].sectorID, records[1].dbID).Scan(&newSectorID) if err != nil { - return fmt.Errorf("failed to update sector ID: %w", err) - } else if rows, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) - } else if rows != 1 { - return fmt.Errorf("expected 1 row affected, got %v", rows) + return [2]types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err) + } else if newSectorID != records[0].sectorID { + return [2]types.Hash256{}, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) } - return nil + return [2]types.Hash256{records[0].root, records[1].root}, nil } // lastContractSectors returns the last n sector IDs for a contract. func lastContractSectors(tx txn, contractID int64, n uint64) (roots []contractSectorRootRef, err error) { - const query = `SELECT id, sector_id FROM contract_sector_roots WHERE contract_id=$1 ORDER BY root_index DESC LIMIT $2;` + const query = `SELECT csr.id, csr.sector_id, ss.sector_root FROM contract_sector_roots csr +INNER JOIN stored_sectors ss ON (csr.sector_id=ss.id) +WHERE contract_id=$1 ORDER BY root_index DESC LIMIT $2;` rows, err := tx.Query(query, contractID, n) if err != nil { return nil, err @@ -629,9 +665,9 @@ func lastContractSectors(tx txn, contractID int64, n uint64) (roots []contractSe defer rows.Close() for rows.Next() { - var ref contractSectorRootRef - if err := rows.Scan(&ref.dbID, &ref.sectorID); err != nil { - return nil, err + ref, err := scanContractSectorRootRef(rows) + if err != nil { + return nil, fmt.Errorf("failed to scan sector ref: %w", err) } roots = append(roots, ref) } @@ -647,14 +683,26 @@ func deleteContractSectors(tx txn, refs []contractSectorRootRef) (int, error) { } // delete the sector roots - query := `DELETE FROM contract_sector_roots WHERE id IN (` + queryPlaceHolders(len(rootIDs)) + `);` - res, err := tx.Exec(query, queryArgs(rootIDs)...) + query := `DELETE FROM contract_sector_roots WHERE id IN (` + queryPlaceHolders(len(rootIDs)) + `) RETURNING id;` + rows, err := tx.Query(query, queryArgs(rootIDs)...) if err != nil { return 0, fmt.Errorf("failed to delete sectors: %w", err) - } else if rows, err := res.RowsAffected(); err != nil { - return 0, fmt.Errorf("failed to get rows affected: %w", err) - } else if rows != int64(len(refs)) { - return 0, fmt.Errorf("failed to delete all sectors: %w", err) + } + deleted := make(map[int64]bool) + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return 0, fmt.Errorf("failed to scan deleted sector: %w", err) + } + deleted[id] = true + } + if len(deleted) != len(rootIDs) { + return 0, errors.New("failed to delete all sectors") + } + for _, rootID := range rootIDs { + if !deleted[rootID] { + return 0, errors.New("failed to delete all sectors") + } } // decrement the contract metrics @@ -665,25 +713,30 @@ func deleteContractSectors(tx txn, refs []contractSectorRootRef) (int, error) { // attempt to prune the deleted sectors var pruned int for _, ref := range refs { - if err := pruneSectorRef(tx, ref.sectorID); errors.Is(err, errSectorHasRefs) { - continue - } else if err != nil { + deleted, err := pruneSectorRef(tx, ref.sectorID) + if err != nil { return 0, fmt.Errorf("failed to prune sector ref: %w", err) + } else if deleted { + pruned++ } - pruned++ } return pruned, nil } // trimSectors deletes the last n sector roots for a contract. -func trimSectors(tx txn, contractID int64, n uint64, log *zap.Logger) error { +func trimSectors(tx txn, contractID int64, n uint64, log *zap.Logger) (map[types.Hash256]bool, error) { refs, err := lastContractSectors(tx, contractID, n) if err != nil { - return fmt.Errorf("failed to get sector IDs: %w", err) + return nil, fmt.Errorf("failed to get sector IDs: %w", err) + } else if _, err = deleteContractSectors(tx, refs); err != nil { + return nil, fmt.Errorf("failed to delete sectors: %w", err) } - _, err = deleteContractSectors(tx, refs) - return err + roots := make(map[types.Hash256]bool) + for _, ref := range refs { + roots[ref.root] = true + } + return roots, nil } // clearContract clears a contract and returns its ID @@ -1159,8 +1212,18 @@ func setContractStatus(tx txn, id types.FileContractID, status contracts.Contrac return nil } +func scanContractSectorRef(s scanner) (ref contractSectorRef, err error) { + err = s.Scan(&ref.ID, (*sqlHash256)(&ref.ContractID), &ref.SectorID) + return +} + +func scanContractSectorRootRef(s scanner) (ref contractSectorRootRef, err error) { + err = s.Scan(&ref.dbID, &ref.sectorID, (*sqlHash256)(&ref.root)) + return +} + func expiredContractSectors(tx txn, height uint64, batchSize int64) (sectors []contractSectorRef, _ error) { - const query = `SELECT csr.id, c.contract_id, csr.sector_id FROM contract_sector_roots csr + const query = `SELECT csr.id, c.contract_id, csr.sector_id FROM contract_sector_roots csr INNER JOIN contracts c ON (csr.contract_id=c.id) -- past proof window or not confirmed and past the rebroadcast height WHERE c.window_end < $1 OR c.contract_status=$2 LIMIT $3;` @@ -1170,8 +1233,8 @@ WHERE c.window_end < $1 OR c.contract_status=$2 LIMIT $3;` } defer rows.Close() for rows.Next() { - var ref contractSectorRef - if err := rows.Scan(&ref.ID, (*sqlHash256)(&ref.ContractID), &ref.SectorID); err != nil { + ref, err := scanContractSectorRef(rows) + if err != nil { return nil, fmt.Errorf("failed to scan expired contract: %w", err) } sectors = append(sectors, ref) diff --git a/persist/sqlite/contracts_test.go b/persist/sqlite/contracts_test.go index dcc5ceaf..ee35222d 100644 --- a/persist/sqlite/contracts_test.go +++ b/persist/sqlite/contracts_test.go @@ -37,24 +37,6 @@ func rootsEqual(a, b []types.Hash256) error { return nil } -func runRevision(db *Store, revision contracts.SignedRevision, roots []types.Hash256, changes []contracts.SectorChange) error { - for _, change := range changes { - switch change.Action { - // store a sector in the database for the append or update actions - case contracts.SectorActionAppend, contracts.SectorActionUpdate: - root := frand.Entropy256() - release, err := db.StoreSector(root, func(loc storage.SectorLocation, exists bool) error { return nil }) - if err != nil { - return fmt.Errorf("failed to store sector: %w", err) - } - defer release() - change.Root = root - } - } - - return db.ReviseContract(revision, roots, contracts.Usage{}, changes) -} - func TestReviseContract(t *testing.T) { log := zaptest.NewLogger(t) db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) @@ -302,7 +284,7 @@ func TestReviseContract(t *testing.T) { } } - if err := runRevision(db, contract, oldRoots, test.changes); err != nil { + if err := db.ReviseContract(contract, oldRoots, contracts.Usage{}, test.changes); err != nil { if test.errors { t.Log("received error:", err) return diff --git a/persist/sqlite/sectors.go b/persist/sqlite/sectors.go index ebf4e552..ee6f14e0 100644 --- a/persist/sqlite/sectors.go +++ b/persist/sqlite/sectors.go @@ -11,8 +11,6 @@ import ( "go.uber.org/zap" ) -var errSectorHasRefs = errors.New("sector has references") - type tempSectorRef struct { ID int64 SectorID int64 @@ -48,13 +46,12 @@ func (s *Store) batchExpireTempSectors(height uint64) (refs []tempSectorRef, rec } for _, ref := range refs { - err := pruneSectorRef(tx, ref.SectorID) - if errors.Is(err, errSectorHasRefs) { - continue - } else if err != nil { + deleted, err := pruneSectorRef(tx, ref.SectorID) + if err != nil { return fmt.Errorf("failed to prune sector: %w", err) + } else if deleted { + reclaimed++ } - reclaimed++ } return nil }) @@ -264,40 +261,40 @@ func clearVolumeSector(tx txn, id int64) error { return nil } -func pruneSectorRef(tx txn, id int64) error { +func pruneSectorRef(tx txn, id int64) (bool, error) { var hasReference bool // check if the sector is referenced by a contract err := tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM contract_sector_roots WHERE sector_id=$1)`, id).Scan(&hasReference) if err != nil { - return fmt.Errorf("failed to check contract references: %w", err) + return false, fmt.Errorf("failed to check contract references: %w", err) } else if hasReference { - return fmt.Errorf("sector referenced by contract: %w", errSectorHasRefs) + return false, nil } // check if the sector is referenced by temp storage err = tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM temp_storage_sector_roots WHERE sector_id=$1)`, id).Scan(&hasReference) if err != nil { - return fmt.Errorf("failed to check temp references: %w", err) + return false, fmt.Errorf("failed to check temp references: %w", err) } else if hasReference { - return fmt.Errorf("sector referenced by temp storage: %w", errSectorHasRefs) + return false, nil } // check if the sector is locked err = tx.QueryRow(`SELECT EXISTS(SELECT 1 FROM locked_sectors WHERE sector_id=$1)`, id).Scan(&hasReference) if err != nil { - return fmt.Errorf("failed to check lock references: %w", err) + return false, fmt.Errorf("failed to check lock references: %w", err) } else if hasReference { - return fmt.Errorf("sector locked: %w", errSectorHasRefs) + return false, nil } // clear the volume sector reference if err = clearVolumeSector(tx, id); err != nil { - return fmt.Errorf("failed to clear volume sector: %w", err) + return false, fmt.Errorf("failed to clear volume sector: %w", err) } // delete the sector if _, err = tx.Exec(`DELETE FROM stored_sectors WHERE id=$1`, id); err != nil { - return fmt.Errorf("failed to delete sector: %w", err) + return false, fmt.Errorf("failed to delete sector: %w", err) } - return nil + return true, nil } func expiredTempSectors(tx txn, height uint64, limit int) (sectors []tempSectorRef, _ error) { @@ -362,10 +359,8 @@ func unlockSector(txn txn, lockIDs ...int64) error { } for _, sectorID := range sectorIDs { - err := pruneSectorRef(txn, sectorID) - if errors.Is(err, errSectorHasRefs) { - continue - } else if err != nil { + _, err := pruneSectorRef(txn, sectorID) + if err != nil { return fmt.Errorf("failed to prune sector: %w", err) } } diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go index 378e4220..6cf051d4 100644 --- a/persist/sqlite/store.go +++ b/persist/sqlite/store.go @@ -3,7 +3,6 @@ package sqlite import ( "database/sql" "encoding/hex" - "errors" "fmt" "math" "strings" @@ -171,7 +170,7 @@ func clearLockedSectors(tx txn) error { } for _, sectorID := range sectorIDs { - if err := pruneSectorRef(tx, sectorID); err != nil && !errors.Is(err, errSectorHasRefs) { + if _, err := pruneSectorRef(tx, sectorID); err != nil { return fmt.Errorf("failed to prune sector %d: %w", sectorID, err) } } From ac9ada1898644dde08d2aa0bace7e10ee5536c34 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 6 Feb 2024 14:22:57 -0800 Subject: [PATCH 2/3] sqlite: use map for swap --- persist/sqlite/contracts.go | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/persist/sqlite/contracts.go b/persist/sqlite/contracts.go index 1ccd6820..8599b132 100644 --- a/persist/sqlite/contracts.go +++ b/persist/sqlite/contracts.go @@ -323,10 +323,10 @@ func (s *Store) ReviseContract(revision contracts.SignedRevision, roots []types. return fmt.Errorf("failed to swap sectors: %w", err) } oldA, oldB := roots[change.A], roots[change.B] - if swapped[0] != oldA { - return fmt.Errorf("inconsistent sector swap: expected %s, got %s", oldA, swapped[0]) - } else if swapped[1] != oldB { - return fmt.Errorf("inconsistent sector swap: expected %s, got %s", oldB, swapped[1]) + for root := range swapped { + if root != oldA && root != oldB { + return fmt.Errorf("inconsistent sector swap: expected %s or %s, got %s", oldA, oldB, root) + } } roots[change.A], roots[change.B] = roots[change.B], roots[change.A] } @@ -602,9 +602,9 @@ WHERE contract_id=$1 AND root_index=$2`, contractID, index) return ref.root, nil } -func swapSectors(tx txn, contractID int64, i, j uint64) ([2]types.Hash256, error) { +func swapSectors(tx txn, contractID int64, i, j uint64) (map[types.Hash256]bool, error) { if i == j { - return [2]types.Hash256{}, nil + return nil, nil } var records []contractSectorRootRef @@ -614,43 +614,46 @@ INNER JOIN stored_sectors ss ON (ss.id = csr.sector_id) WHERE contract_id=$1 AND root_index IN ($2, $3) ORDER BY root_index ASC;`, contractID, i, j) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to query sector IDs: %w", err) + return nil, fmt.Errorf("failed to query sector IDs: %w", err) } defer rows.Close() for rows.Next() { ref, err := scanContractSectorRootRef(rows) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to scan sector ref: %w", err) + return nil, fmt.Errorf("failed to scan sector ref: %w", err) } records = append(records, ref) } if len(records) != 2 { - return [2]types.Hash256{}, errors.New("failed to find both sectors") + return nil, errors.New("failed to find both sectors") } stmt, err := tx.Prepare(`UPDATE contract_sector_roots SET sector_id=$1 WHERE id=$2 RETURNING sector_id;`) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to prepare update statement: %w", err) + return nil, fmt.Errorf("failed to prepare update statement: %w", err) } defer stmt.Close() var newSectorID int64 err = stmt.QueryRow(records[1].sectorID, records[0].dbID).Scan(&newSectorID) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err) + return nil, fmt.Errorf("failed to update sector ID: %w", err) } else if newSectorID != records[1].sectorID { - return [2]types.Hash256{}, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) + return nil, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) } err = stmt.QueryRow(records[0].sectorID, records[1].dbID).Scan(&newSectorID) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err) + return nil, fmt.Errorf("failed to update sector ID: %w", err) } else if newSectorID != records[0].sectorID { - return [2]types.Hash256{}, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) + return nil, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) } - return [2]types.Hash256{records[0].root, records[1].root}, nil + return map[types.Hash256]bool{ + records[0].root: true, + records[1].root: true, + }, nil } // lastContractSectors returns the last n sector IDs for a contract. From 46468719a27dd771d46dde79b0a0ab605d1f54f6 Mon Sep 17 00:00:00 2001 From: Nate Maninger Date: Tue, 6 Feb 2024 14:29:55 -0800 Subject: [PATCH 3/3] contracts: prevent root cache update if revise fails --- host/contracts/contracts.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/host/contracts/contracts.go b/host/contracts/contracts.go index 6c5cce9a..5831b62a 100644 --- a/host/contracts/contracts.go +++ b/host/contracts/contracts.go @@ -332,12 +332,14 @@ func (cu *ContractUpdater) Commit(revision SignedRevision, usage Usage) error { start := time.Now() // revise the contract err := cu.store.ReviseContract(revision, cu.oldRoots, usage, cu.sectorActions) - if err == nil { - // clear the committed sector actions - cu.sectorActions = cu.sectorActions[:0] + if err != nil { + return err } + + // clear the committed sector actions + cu.sectorActions = cu.sectorActions[:0] // update the roots cache - cu.rootsCache.Add(revision.Revision.ParentID, cu.sectorRoots[:]) + cu.rootsCache.Add(revision.Revision.ParentID, append([]types.Hash256(nil), cu.sectorRoots...)) cu.log.Debug("contract update committed", zap.String("contractID", revision.Revision.ParentID.String()), zap.Uint64("revision", revision.Revision.RevisionNumber), zap.Duration("elapsed", time.Since(start))) - return err + return nil }