diff --git a/api/contract.go b/api/contract.go index 94f8c998a4..92775b268b 100644 --- a/api/contract.go +++ b/api/contract.go @@ -32,6 +32,8 @@ var ( ErrContractSetNotFound = errors.New("couldn't find contract set") ) +type ContractState string + type ( // A Contract wraps the contract metadata with the latest contract revision. Contract struct { diff --git a/chain/manager.go b/chain/manager.go new file mode 100644 index 0000000000..5614604da5 --- /dev/null +++ b/chain/manager.go @@ -0,0 +1,10 @@ +package chain + +import ( + "go.sia.tech/coreutils/chain" +) + +type ( + Manager = chain.Manager + HostAnnouncement = chain.HostAnnouncement +) diff --git a/chain/subscriber.go b/chain/subscriber.go new file mode 100644 index 0000000000..431abb902d --- /dev/null +++ b/chain/subscriber.go @@ -0,0 +1,44 @@ +package chain + +import ( + "context" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.sia.tech/renterd/api" +) + +type ( + ChainManager interface { + Tip() types.ChainIndex + OnReorg(fn func(types.ChainIndex)) (cancel func()) + UpdatesSince(index types.ChainIndex, max int) (rus []chain.RevertUpdate, aus []chain.ApplyUpdate, err error) + } + + ChainStore interface { + BeginChainUpdateTx() (ChainUpdateTx, error) + ChainIndex() (types.ChainIndex, error) + } + + ChainUpdateTx interface { + Commit() error + Rollback() error + + ContractState(fcid types.FileContractID) (api.ContractState, error) + UpdateChainIndex(index types.ChainIndex) error + UpdateContract(fcid types.FileContractID, revisionHeight, revisionNumber, size uint64) error + UpdateContractState(fcid types.FileContractID, state api.ContractState) error + UpdateContractProofHeight(fcid types.FileContractID, proofHeight uint64) error + UpdateFailedContracts(blockHeight uint64) error + UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, bh uint64, blockID types.BlockID, ts time.Time) error + } + + ContractStore interface { + AddContractStoreSubscriber(context.Context, ContractStoreSubscriber) (map[types.FileContractID]struct{}, func(), error) + } + + ContractStoreSubscriber interface { + AddContractID(fcid types.FileContractID) + } +) diff --git a/stores/chain.go b/stores/chain.go new file mode 100644 index 0000000000..1b77e5d761 --- /dev/null +++ b/stores/chain.go @@ -0,0 +1,356 @@ +package stores + +import ( + "fmt" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/chain" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +var ( + _ chain.ChainStore = (*SQLStore)(nil) + _ chain.ChainUpdateTx = (*chainUpdateTx)(nil) +) + +// chainUpdateTx implements the ChainUpdateTx interface. +type chainUpdateTx struct { + tx *gorm.DB +} + +// BeginChainUpdateTx starts a transaction and wraps it in a chainUpdateTx. This +// transaction will be used to process a chain update in the subscriber. +func (s *SQLStore) BeginChainUpdateTx() (chain.ChainUpdateTx, error) { + tx := s.db.Begin() + if tx.Error != nil { + return nil, tx.Error + } + return &chainUpdateTx{tx: tx}, nil +} + +// ApplyIndex is called with the chain index that is being applied. Any +// transactions and siacoin elements that were created by the index should be +// added and any siacoin elements that were spent should be removed. +func (u *chainUpdateTx) ApplyIndex(index types.ChainIndex, created, spent []types.SiacoinElement, events []wallet.Event) error { + // remove spent outputs + for _, e := range spent { + // TODO: check if rows affected is > 0? + if err := u.tx. + Where("output_id", hash256(e.ID)). + Delete(&dbWalletOutput{}). + Error; err != nil { + return err + } + } + + // create outputs + for _, e := range created { + if err := u.tx. + Clauses(clause.OnConflict{ + DoNothing: true, + Columns: []clause.Column{{Name: "output_id"}}, + }). + Create(&dbWalletOutput{ + OutputID: hash256(e.ID), + LeafIndex: e.StateElement.LeafIndex, + MerkleProof: merkleProof{proof: e.StateElement.MerkleProof}, + Value: currency(e.SiacoinOutput.Value), + Address: hash256(e.SiacoinOutput.Address), + MaturityHeight: e.MaturityHeight, + Height: index.Height, + BlockID: hash256(index.ID), + }).Error; err != nil { + return nil + } + } + + // create events + for _, e := range events { + if err := u.tx. + Clauses(clause.OnConflict{ + DoNothing: true, + Columns: []clause.Column{{Name: "event_id"}}, + }). + Create(&dbWalletEvent{ + EventID: hash256(e.ID), + Inflow: currency(e.Inflow), + Outflow: currency(e.Outflow), + Transaction: e.Transaction, + MaturityHeight: e.MaturityHeight, + Source: string(e.Source), + Timestamp: e.Timestamp.Unix(), + Height: e.Index.Height, + BlockID: hash256(e.Index.ID), + }).Error; err != nil { + return err + } + } + return nil +} + +// Commit commits the updates to the database. +func (u *chainUpdateTx) Commit() error { + return u.tx.Commit().Error +} + +// Rollback rolls back the transaction +func (u *chainUpdateTx) Rollback() error { + return u.tx.Rollback().Error +} + +// ContractState returns the state of a file contract. +func (u *chainUpdateTx) ContractState(fcid types.FileContractID) (api.ContractState, error) { + var state contractState + err := u.tx. + Select("state"). + Model(&dbContract{}). + Where("fcid = ?", fileContractID(fcid)). + Scan(&state). + Error + + if err == gorm.ErrRecordNotFound { + err = u.tx. + Select("state"). + Model(&dbArchivedContract{}). + Where("fcid = ?", fileContractID(fcid)). + Scan(&state). + Error + } + + if err != nil { + return "", err + } + return api.ContractState(state.String()), nil +} + +// RemoveSiacoinElements is called with all siacoin elements that were spent in +// the update. +func (u *chainUpdateTx) RemoveSiacoinElements(ids []types.SiacoinOutputID) error { + for _, id := range ids { + if err := u.tx. + Where("output_id", hash256(id)). + Delete(&dbWalletOutput{}). + Error; err != nil { + return err + } + } + return nil +} + +// RevertIndex is called with the chain index that is being reverted. Any +// transactions and siacoin elements that were created by the index should be +// removed. +func (u *chainUpdateTx) RevertIndex(index types.ChainIndex, removed, unspent []types.SiacoinElement) error { + // recreate unspent outputs + for _, e := range unspent { + if err := u.tx. + Clauses(clause.OnConflict{ + DoNothing: true, + Columns: []clause.Column{{Name: "output_id"}}, + }). + Create(&dbWalletOutput{ + OutputID: hash256(e.ID), + LeafIndex: e.StateElement.LeafIndex, + MerkleProof: merkleProof{proof: e.StateElement.MerkleProof}, + Value: currency(e.SiacoinOutput.Value), + Address: hash256(e.SiacoinOutput.Address), + MaturityHeight: e.MaturityHeight, + Height: index.Height, + BlockID: hash256(index.ID), + }).Error; err != nil { + return nil + } + } + + // remove outputs created at the reverted index + for _, e := range removed { + if err := u.tx. + Where("output_id", hash256(e.ID)). + Delete(&dbWalletOutput{}). + Error; err != nil { + return err + } + } + + // remove events created at the reverted index + return u.tx. + Model(&dbWalletEvent{}). + Where("height = ? AND block_id = ?", index.Height, hash256(index.ID)). + Delete(&dbWalletEvent{}). + Error +} + +// UpdateChainIndex updates the chain index in the database. +func (u *chainUpdateTx) UpdateChainIndex(index types.ChainIndex) error { + return u.tx. + Model(&dbConsensusInfo{}). + Where(&dbConsensusInfo{Model: Model{ID: consensusInfoID}}). + Updates(map[string]interface{}{ + "height": index.Height, + "block_id": hash256(index.ID), + }). + Error +} + +// UpdateContract updates the revision height, revision number, and size the +// contract with given fcid. +func (u *chainUpdateTx) UpdateContract(fcid types.FileContractID, revisionHeight, revisionNumber, size uint64) error { + updates := map[string]interface{}{ + "revision_height": revisionHeight, + } + if revisionNumber > 0 { + updates["revision_number"] = fmt.Sprint(revisionNumber) + } + if size > 0 { + updates["size"] = size + } + + if err := u.tx. + Model(&dbContract{}). + Where("fcid = ?", fileContractID(fcid)). + Updates(updates). + Error; err != nil { + return err + } + return u.tx. + Model(&dbArchivedContract{}). + Where("fcid = ?", fileContractID(fcid)). + Updates(updates). + Error +} + +// UpdateContractState updates the state of the contract with given fcid. +func (u *chainUpdateTx) UpdateContractState(fcid types.FileContractID, state api.ContractState) error { + var cs contractState + if err := cs.LoadString(string(state)); err != nil { + return err + } + + if err := u.tx. + Model(&dbContract{}). + Where("fcid = ?", fileContractID(fcid)). + Update("state", cs). + Error; err != nil { + return err + } + return u.tx. + Model(&dbArchivedContract{}). + Where("fcid = ?", fileContractID(fcid)). + Update("state", cs). + Error +} + +// UpdateContractProofHeight updates the proof height of the contract with given +// fcid. +func (u *chainUpdateTx) UpdateContractProofHeight(fcid types.FileContractID, proofHeight uint64) error { + if err := u.tx. + Model(&dbContract{}). + Where("fcid = ?", fileContractID(fcid)). + Update("proof_height", proofHeight). + Error; err != nil { + return err + } + return u.tx. + Model(&dbArchivedContract{}). + Where("fcid = ?", fileContractID(fcid)). + Update("proof_height", proofHeight). + Error +} + +// UpdateFailedContracts marks active contract as failed if the current +// blockheight surposses their window_end. +func (u *chainUpdateTx) UpdateFailedContracts(blockHeight uint64) error { + return u.tx. + Model(&dbContract{}). + Where("state = ? AND ? > window_end", contractStateActive, blockHeight). + Update("state", contractStateFailed). + Error +} + +// UpdateHost creates the announcement and upserts the host in the database. +func (u *chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, bh uint64, blockID types.BlockID, ts time.Time) error { + // create the announcement + if err := u.tx.Create(&dbAnnouncement{ + HostKey: publicKey(hk), + BlockHeight: bh, + BlockID: blockID.String(), + NetAddress: ha.NetAddress, + }).Error; err != nil { + return err + } + + // create the host + if err := u.tx.Create(&dbHost{ + PublicKey: publicKey(hk), + LastAnnouncement: ts.UTC(), + NetAddress: ha.NetAddress, + }).Error; err != nil { + return err + } + + // fetch blocklists + allowlist, blocklist, err := getBlocklists(u.tx) + if err != nil { + return fmt.Errorf("%w; failed to fetch blocklists", err) + } + + // return early if there are no allowlist or blocklist entries + if len(allowlist)+len(blocklist) == 0 { + return nil + } + + // update blocklist + if err := updateBlocklist(u.tx, hk, allowlist, blocklist); err != nil { + return fmt.Errorf("%w; failed to update blocklist for host %v", err, hk) + } + + return nil +} + +// UpdateStateElements updates the proofs of all state elements affected by the +// update. +func (u *chainUpdateTx) UpdateStateElements(elements []types.StateElement) error { + for _, se := range elements { + if err := u.tx. + Model(&dbWalletOutput{}). + Where("output_id", hash256(se.ID)). + Updates(map[string]interface{}{ + "merkle_proof": merkleProof{proof: se.MerkleProof}, + "leaf_index": se.LeafIndex, + }).Error; err != nil { + return err + } + } + return nil +} + +// WalletStateElements implements the ChainStore interface and returns all state +// elements in the database. +func (u *chainUpdateTx) WalletStateElements() ([]types.StateElement, error) { + type row struct { + ID hash256 + LeafIndex uint64 + MerkleProof merkleProof + } + var rows []row + if err := u.tx. + Model(&dbWalletOutput{}). + Select("output_id AS id", "leaf_index", "merkle_proof"). + Find(&rows). + Error; err != nil { + return nil, err + } + elements := make([]types.StateElement, 0, len(rows)) + for _, r := range rows { + elements = append(elements, types.StateElement{ + ID: types.Hash256(r.ID), + LeafIndex: r.LeafIndex, + MerkleProof: r.MerkleProof.proof, + }) + } + return elements, nil +} diff --git a/stores/chain_test.go b/stores/chain_test.go new file mode 100644 index 0000000000..f4f6104eb2 --- /dev/null +++ b/stores/chain_test.go @@ -0,0 +1,131 @@ +package stores + +import ( + "context" + "testing" + + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/chain" +) + +// TestChainUpdateTx tests the chain update transaction. +func TestChainUpdateTx(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + + // add test host and contract + hks, err := ss.addTestHosts(1) + if err != nil { + t.Fatal(err) + } + fcids, _, err := ss.addTestContracts(hks) + if err != nil { + t.Fatal(err) + } else if len(fcids) != 1 { + t.Fatal("expected one contract", len(fcids)) + } + fcid := fcids[0] + + // assert commit with no changes is successful + tx, err := ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } + if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } + + // assert rollback with no changes is successful + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } + if err := tx.Rollback(); err != nil { + t.Fatal("unexpected error", err) + } + + // assert contract state returns the correct state + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } + state, err := tx.ContractState(fcid) + if err != nil { + t.Fatal("unexpected error", err) + } else if state != api.ContractStatePending { + t.Fatal("expected pending state", state) + } + + // assert update chain index is successful + if curr, err := ss.ChainIndex(); err != nil { + t.Fatal("unexpected error", err) + } else if curr.Height != 0 { + t.Fatal("unexpected height", curr.Height) + } + index := types.ChainIndex{Height: 1} + if err := tx.UpdateChainIndex(index); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } + if got, err := ss.ChainIndex(); err != nil { + t.Fatal("unexpected error", err) + } else if got.Height != index.Height { + t.Fatal("unexpected height", got.Height) + } + + // assert update contract is successful + var we uint64 + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateContract(fcid, 1, 2, 3); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateContractState(fcid, api.ContractStateActive); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateContractProofHeight(fcid, 4); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } else if c, err := ss.contract(context.Background(), fileContractID(fcid)); err != nil { + t.Fatal("unexpected error", err) + } else if c.RevisionHeight != 1 { + t.Fatal("unexpected revision height", c.RevisionHeight) + } else if c.RevisionNumber != "2" { + t.Fatal("unexpected revision number", c.RevisionNumber) + } else if c.Size != 3 { + t.Fatal("unexpected size", c.Size) + } else if c.State.String() != api.ContractStateActive { + t.Fatal("unexpected state", c.State) + } else { + we = c.WindowEnd + } + + // assert update failed contracts is successful + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateFailedContracts(we + 1); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } else if c, err := ss.contract(context.Background(), fileContractID(fcid)); err != nil { + t.Fatal("unexpected error", err) + } else if c.State.String() != api.ContractStateFailed { + t.Fatal("unexpected state", c.State) + } + + // assert update host is successful + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateHost(hks[0], chain.HostAnnouncement{NetAddress: "foo"}, 1, types.BlockID{}, types.CurrentTimestamp()); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } else if h, err := ss.Host(context.Background(), hks[0]); err != nil { + t.Fatal("unexpected error", err) + } else if h.NetAddress != "foo" { + t.Fatal("unexpected net address", h.NetAddress) + } +} diff --git a/stores/hostdb.go b/stores/hostdb.go index 5ca6b3c347..8f70e1f61d 100644 --- a/stores/hostdb.go +++ b/stores/hostdb.go @@ -1160,6 +1160,26 @@ func updateActiveAndArchivedContract(tx *gorm.DB, fcid types.FileContractID, upd return nil } +func getBlocklists(tx *gorm.DB) ([]dbAllowlistEntry, []dbBlocklistEntry, error) { + var allowlist []dbAllowlistEntry + if err := tx. + Model(&dbAllowlistEntry{}). + Find(&allowlist). + Error; err != nil { + return nil, nil, err + } + + var blocklist []dbBlocklistEntry + if err := tx. + Model(&dbBlocklistEntry{}). + Find(&blocklist). + Error; err != nil { + return nil, nil, err + } + + return allowlist, blocklist, nil +} + func updateBlocklist(tx *gorm.DB, hk types.PublicKey, allowlist []dbAllowlistEntry, blocklist []dbBlocklistEntry) error { // fetch the host var host dbHost