diff --git a/api/contract.go b/api/contract.go index 94f8c998a..92775b268 100644 --- a/api/contract.go +++ b/api/contract.go @@ -32,6 +32,8 @@ var ( ErrContractSetNotFound = errors.New("couldn't find contract set") ) +type ContractState string + type ( // A Contract wraps the contract metadata with the latest contract revision. Contract struct { diff --git a/chain/manager.go b/chain/manager.go new file mode 100644 index 000000000..5614604da --- /dev/null +++ b/chain/manager.go @@ -0,0 +1,10 @@ +package chain + +import ( + "go.sia.tech/coreutils/chain" +) + +type ( + Manager = chain.Manager + HostAnnouncement = chain.HostAnnouncement +) diff --git a/chain/subscriber.go b/chain/subscriber.go new file mode 100644 index 000000000..5b3da7e70 --- /dev/null +++ b/chain/subscriber.go @@ -0,0 +1,29 @@ +package chain + +import ( + "time" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/chain" + "go.sia.tech/renterd/api" +) + +type ( + ChainStore interface { + BeginChainUpdateTx() (ChainUpdateTx, error) + ChainIndex() (types.ChainIndex, error) + } + + ChainUpdateTx interface { + Commit() error + Rollback() error + + ContractState(fcid types.FileContractID) (api.ContractState, error) + UpdateChainIndex(index types.ChainIndex) error + UpdateContract(fcid types.FileContractID, revisionHeight, revisionNumber, size uint64) error + UpdateContractState(fcid types.FileContractID, state api.ContractState) error + UpdateContractProofHeight(fcid types.FileContractID, proofHeight uint64) error + UpdateFailedContracts(blockHeight uint64) error + UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, bh uint64, blockID types.BlockID, ts time.Time) error + } +) diff --git a/internal/test/e2e/cluster.go b/internal/test/e2e/cluster.go index 656736c2e..9e02d7c17 100644 --- a/internal/test/e2e/cluster.go +++ b/internal/test/e2e/cluster.go @@ -459,6 +459,7 @@ func newTestCluster(t *testing.T, opts testClusterOptions) *TestCluster { if nHosts > 0 { cluster.AddHostsBlocking(nHosts) + cluster.WaitForPeers() cluster.WaitForContracts() cluster.WaitForContractSet(test.ContractSet, nHosts) cluster.WaitForAccounts() @@ -657,6 +658,19 @@ func (c *TestCluster) WaitForContractSetContracts(set string, n int) { }) } +func (c *TestCluster) WaitForPeers() { + c.tt.Helper() + c.tt.Retry(300, 100*time.Millisecond, func() error { + peers, err := c.Bus.SyncerPeers(context.Background()) + if err != nil { + return err + } else if len(peers) == 0 { + return errors.New("no peers found") + } + return nil + }) +} + func (c *TestCluster) RemoveHost(host *Host) { c.tt.Helper() c.tt.OK(host.Close()) @@ -687,15 +701,15 @@ func (c *TestCluster) AddHost(h *Host) { c.hosts = append(c.hosts, h) // Fund host from bus. - fundAmt := types.Siacoins(100e3) + fundAmt := types.Siacoins(25e3) var scos []types.SiacoinOutput for i := 0; i < 10; i++ { scos = append(scos, types.SiacoinOutput{ - Value: fundAmt, + Value: fundAmt.Div64(10), Address: h.WalletAddress(), }) } - c.tt.OK(c.Bus.SendSiacoins(context.Background(), scos, false)) + c.tt.OK(c.Bus.SendSiacoins(context.Background(), scos, true)) // Mine transaction. c.MineBlocks(1) diff --git a/internal/test/e2e/pruning_test.go b/internal/test/e2e/pruning_test.go index d2455b79f..8492bf9f1 100644 --- a/internal/test/e2e/pruning_test.go +++ b/internal/test/e2e/pruning_test.go @@ -12,7 +12,6 @@ import ( "go.sia.tech/core/types" "go.sia.tech/renterd/api" "go.sia.tech/renterd/internal/test" - "go.uber.org/zap/zapcore" ) func TestHostPruning(t *testing.T) { @@ -21,14 +20,11 @@ func TestHostPruning(t *testing.T) { } // create a new test cluster - opts := clusterOptsDefault - opts.logger = newTestLoggerCustom(zapcore.DebugLevel) - cluster := newTestCluster(t, opts) + cluster := newTestCluster(t, testClusterOptions{hosts: 1}) defer cluster.Shutdown() // convenience variables b := cluster.Bus - w := cluster.Worker a := cluster.Autopilot tt := cluster.tt @@ -48,26 +44,13 @@ func TestHostPruning(t *testing.T) { tt.OK(b.RecordHostScans(context.Background(), his)) } - // add a host - hosts := cluster.AddHosts(1) - h1 := hosts[0] - - // fetch the host - h, err := b.Host(context.Background(), h1.PublicKey()) - tt.OK(err) - - // scan the host (lastScan needs to be > 0 for downtime to start counting) - tt.OKAll(w.RHPScan(context.Background(), h1.PublicKey(), h.NetAddress, 0)) - - // block the host - tt.OK(b.UpdateHostBlocklist(context.Background(), []string{h1.PublicKey().String()}, nil, false)) + // shut down the worker manually, this will flush any interactions + cluster.ShutdownWorker(context.Background()) // remove it from the cluster manually + h1 := cluster.hosts[0] cluster.RemoveHost(h1) - // shut down the worker manually, this will flush any interactions - cluster.ShutdownWorker(context.Background()) - // record 9 failed interactions, right before the pruning threshold, and // wait for the autopilot loop to finish at least once recordFailedInteractions(9, h1.PublicKey()) diff --git a/stores/chain.go b/stores/chain.go new file mode 100644 index 000000000..2920c8b37 --- /dev/null +++ b/stores/chain.go @@ -0,0 +1,376 @@ +package stores + +import ( + "fmt" + "time" + + "go.sia.tech/core/types" + "go.sia.tech/coreutils/wallet" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/chain" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +var ( + _ chain.ChainStore = (*SQLStore)(nil) + _ chain.ChainUpdateTx = (*chainUpdateTx)(nil) +) + +// chainUpdateTx implements the ChainUpdateTx interface. +type chainUpdateTx struct { + tx *gorm.DB +} + +// BeginChainUpdateTx starts a transaction and wraps it in a chainUpdateTx. This +// transaction will be used to process a chain update in the subscriber. +func (s *SQLStore) BeginChainUpdateTx() (chain.ChainUpdateTx, error) { + tx := s.db.Begin() + if tx.Error != nil { + return nil, tx.Error + } + return &chainUpdateTx{tx: tx}, nil +} + +// ApplyIndex is called with the chain index that is being applied. Any +// transactions and siacoin elements that were created by the index should be +// added and any siacoin elements that were spent should be removed. +func (u *chainUpdateTx) ApplyIndex(index types.ChainIndex, created, spent []types.SiacoinElement, events []wallet.Event) error { + // remove spent outputs + for _, e := range spent { + if res := u.tx. + Where("output_id", hash256(e.ID)). + Delete(&dbWalletOutput{}); res.Error != nil { + return res.Error + } else if res.RowsAffected != 1 { + return fmt.Errorf("spent output with id %v not found ", e.ID) + } + } + + // create outputs + for _, e := range created { + if err := u.tx. + Clauses(clause.OnConflict{ + DoNothing: true, + Columns: []clause.Column{{Name: "output_id"}}, + }). + Create(&dbWalletOutput{ + OutputID: hash256(e.ID), + LeafIndex: e.StateElement.LeafIndex, + MerkleProof: merkleProof{proof: e.StateElement.MerkleProof}, + Value: currency(e.SiacoinOutput.Value), + Address: hash256(e.SiacoinOutput.Address), + MaturityHeight: e.MaturityHeight, + Height: index.Height, + BlockID: hash256(index.ID), + }).Error; err != nil { + return nil + } + } + + // create events + for _, e := range events { + if err := u.tx. + Clauses(clause.OnConflict{ + DoNothing: true, + Columns: []clause.Column{{Name: "event_id"}}, + }). + Create(&dbWalletEvent{ + EventID: hash256(e.ID), + Inflow: currency(e.Inflow), + Outflow: currency(e.Outflow), + Transaction: e.Transaction, + MaturityHeight: e.MaturityHeight, + Source: string(e.Source), + Timestamp: e.Timestamp.Unix(), + Height: e.Index.Height, + BlockID: hash256(e.Index.ID), + }).Error; err != nil { + return err + } + } + return nil +} + +// Commit commits the updates to the database. +func (u *chainUpdateTx) Commit() error { + return u.tx.Commit().Error +} + +// Rollback rolls back the transaction +func (u *chainUpdateTx) Rollback() error { + return u.tx.Rollback().Error +} + +// ContractState returns the state of a file contract. +func (u *chainUpdateTx) ContractState(fcid types.FileContractID) (api.ContractState, error) { + var state contractState + err := u.tx. + Select("state"). + Model(&dbContract{}). + Where("fcid", fileContractID(fcid)). + Scan(&state). + Error + + if err == gorm.ErrRecordNotFound { + err = u.tx. + Select("state"). + Model(&dbArchivedContract{}). + Where("fcid", fileContractID(fcid)). + Scan(&state). + Error + } + + if err != nil { + return "", err + } + return api.ContractState(state.String()), nil +} + +// RemoveSiacoinElements is called with all siacoin elements that were spent in +// the update. +func (u *chainUpdateTx) RemoveSiacoinElements(ids []types.SiacoinOutputID) error { + for _, id := range ids { + if err := u.tx. + Where("output_id", hash256(id)). + Delete(&dbWalletOutput{}). + Error; err != nil { + return err + } + } + return nil +} + +// RevertIndex is called with the chain index that is being reverted. Any +// transactions and siacoin elements that were created by the index should be +// removed. +func (u *chainUpdateTx) RevertIndex(index types.ChainIndex, removed, unspent []types.SiacoinElement) error { + // recreate unspent outputs + for _, e := range unspent { + if err := u.tx. + Clauses(clause.OnConflict{ + DoNothing: true, + Columns: []clause.Column{{Name: "output_id"}}, + }). + Create(&dbWalletOutput{ + OutputID: hash256(e.ID), + LeafIndex: e.StateElement.LeafIndex, + MerkleProof: merkleProof{proof: e.StateElement.MerkleProof}, + Value: currency(e.SiacoinOutput.Value), + Address: hash256(e.SiacoinOutput.Address), + MaturityHeight: e.MaturityHeight, + Height: index.Height, + BlockID: hash256(index.ID), + }).Error; err != nil { + return nil + } + } + + // remove outputs created at the reverted index + for _, e := range removed { + if err := u.tx. + Where("output_id", hash256(e.ID)). + Delete(&dbWalletOutput{}). + Error; err != nil { + return err + } + } + + // remove events created at the reverted index + return u.tx. + Model(&dbWalletEvent{}). + Where("height = ? AND block_id = ?", index.Height, hash256(index.ID)). + Delete(&dbWalletEvent{}). + Error +} + +// UpdateChainIndex updates the chain index in the database. +func (u *chainUpdateTx) UpdateChainIndex(index types.ChainIndex) error { + return u.tx. + Model(&dbConsensusInfo{}). + Where(&dbConsensusInfo{Model: Model{ID: consensusInfoID}}). + Updates(map[string]interface{}{ + "height": index.Height, + "block_id": hash256(index.ID), + }). + Error +} + +// UpdateContract updates the revision height, revision number, and size the +// contract with given fcid. +func (u *chainUpdateTx) UpdateContract(fcid types.FileContractID, revisionHeight, revisionNumber, size uint64) error { + // isUpdatedRevision indicates whether the given revision number is greater + // than the one currently set on the contract + isUpdatedRevision := func(currRevStr string) bool { + var currRev uint64 + _, _ = fmt.Sscan(currRevStr, &currRev) + return revisionNumber > currRev + } + + // update either active or archived contract + var update interface{} + var c dbContract + if err := u.tx. + Model(&dbContract{}). + Where("fcid", fileContractID(fcid)). + Take(&c).Error; err == nil { + c.RevisionHeight = revisionHeight + if isUpdatedRevision(c.RevisionNumber) { + c.RevisionNumber = fmt.Sprint(revisionNumber) + c.Size = size + } + update = c + } else if err == gorm.ErrRecordNotFound { + // try archived contracts + var ac dbArchivedContract + if err := u.tx. + Model(&dbArchivedContract{}). + Where("fcid", fileContractID(fcid)). + Take(&ac).Error; err == nil { + ac.RevisionHeight = revisionHeight + if isUpdatedRevision(ac.RevisionNumber) { + ac.RevisionNumber = fmt.Sprint(revisionNumber) + ac.Size = size + } + update = ac + } + } + if update == nil { + return nil + } + + return u.tx.Save(update).Error +} + +// UpdateContractState updates the state of the contract with given fcid. +func (u *chainUpdateTx) UpdateContractState(fcid types.FileContractID, state api.ContractState) error { + var cs contractState + if err := cs.LoadString(string(state)); err != nil { + return err + } + + if err := u.tx. + Model(&dbContract{}). + Where("fcid", fileContractID(fcid)). + Update("state", cs). + Error; err != nil { + return err + } + return u.tx. + Model(&dbArchivedContract{}). + Where("fcid", fileContractID(fcid)). + Update("state", cs). + Error +} + +// UpdateContractProofHeight updates the proof height of the contract with given +// fcid. +func (u *chainUpdateTx) UpdateContractProofHeight(fcid types.FileContractID, proofHeight uint64) error { + if err := u.tx. + Model(&dbContract{}). + Where("fcid", fileContractID(fcid)). + Update("proof_height", proofHeight). + Error; err != nil { + return err + } + return u.tx. + Model(&dbArchivedContract{}). + Where("fcid", fileContractID(fcid)). + Update("proof_height", proofHeight). + Error +} + +// UpdateFailedContracts marks active contract as failed if the current +// blockheight surposses their window_end. +func (u *chainUpdateTx) UpdateFailedContracts(blockHeight uint64) error { + return u.tx. + Model(&dbContract{}). + Where("window_end <= ?", blockHeight). + Where("state", contractStateActive). + Update("state", contractStateFailed). + Error +} + +// UpdateHost creates the announcement and upserts the host in the database. +func (u *chainUpdateTx) UpdateHost(hk types.PublicKey, ha chain.HostAnnouncement, bh uint64, blockID types.BlockID, ts time.Time) error { + // create the announcement + if err := u.tx.Create(&dbAnnouncement{ + HostKey: publicKey(hk), + BlockHeight: bh, + BlockID: blockID.String(), + NetAddress: ha.NetAddress, + }).Error; err != nil { + return err + } + + // create the host + if err := u.tx.Create(&dbHost{ + PublicKey: publicKey(hk), + LastAnnouncement: ts.UTC(), + NetAddress: ha.NetAddress, + }).Error; err != nil { + return err + } + + // fetch blocklists + allowlist, blocklist, err := getBlocklists(u.tx) + if err != nil { + return fmt.Errorf("%w; failed to fetch blocklists", err) + } + + // return early if there are no allowlist or blocklist entries + if len(allowlist)+len(blocklist) == 0 { + return nil + } + + // update blocklist + if err := updateBlocklist(u.tx, hk, allowlist, blocklist); err != nil { + return fmt.Errorf("%w; failed to update blocklist for host %v", err, hk) + } + + return nil +} + +// UpdateStateElements updates the proofs of all state elements affected by the +// update. +func (u *chainUpdateTx) UpdateStateElements(elements []types.StateElement) error { + for _, se := range elements { + if err := u.tx. + Model(&dbWalletOutput{}). + Where("output_id", hash256(se.ID)). + Updates(map[string]interface{}{ + "merkle_proof": merkleProof{proof: se.MerkleProof}, + "leaf_index": se.LeafIndex, + }).Error; err != nil { + return err + } + } + return nil +} + +// WalletStateElements implements the ChainStore interface and returns all state +// elements in the database. +func (u *chainUpdateTx) WalletStateElements() ([]types.StateElement, error) { + type row struct { + ID hash256 + LeafIndex uint64 + MerkleProof merkleProof + } + var rows []row + if err := u.tx. + Model(&dbWalletOutput{}). + Select("output_id AS id", "leaf_index", "merkle_proof"). + Find(&rows). + Error; err != nil { + return nil, err + } + elements := make([]types.StateElement, 0, len(rows)) + for _, r := range rows { + elements = append(elements, types.StateElement{ + ID: types.Hash256(r.ID), + LeafIndex: r.LeafIndex, + MerkleProof: r.MerkleProof.proof, + }) + } + return elements, nil +} diff --git a/stores/chain_test.go b/stores/chain_test.go new file mode 100644 index 000000000..11bada2c4 --- /dev/null +++ b/stores/chain_test.go @@ -0,0 +1,149 @@ +package stores + +import ( + "context" + "testing" + + "go.sia.tech/core/types" + "go.sia.tech/renterd/api" + "go.sia.tech/renterd/chain" +) + +// TestChainUpdateTx tests the chain update transaction. +func TestChainUpdateTx(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + + // add test host and contract + hks, err := ss.addTestHosts(1) + if err != nil { + t.Fatal(err) + } + fcids, _, err := ss.addTestContracts(hks) + if err != nil { + t.Fatal(err) + } else if len(fcids) != 1 { + t.Fatal("expected one contract", len(fcids)) + } + fcid := fcids[0] + + // assert commit with no changes is successful + tx, err := ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } + if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } + + // assert rollback with no changes is successful + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } + if err := tx.Rollback(); err != nil { + t.Fatal("unexpected error", err) + } + + // assert contract state returns the correct state + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } + state, err := tx.ContractState(fcid) + if err != nil { + t.Fatal("unexpected error", err) + } else if state != api.ContractStatePending { + t.Fatal("expected pending state", state) + } + + // assert update chain index is successful + if curr, err := ss.ChainIndex(); err != nil { + t.Fatal("unexpected error", err) + } else if curr.Height != 0 { + t.Fatal("unexpected height", curr.Height) + } + index := types.ChainIndex{Height: 1} + if err := tx.UpdateChainIndex(index); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } + if got, err := ss.ChainIndex(); err != nil { + t.Fatal("unexpected error", err) + } else if got.Height != index.Height { + t.Fatal("unexpected height", got.Height) + } + + // assert update contract is successful + var we uint64 + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateContract(fcid, 1, 2, 3); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateContractState(fcid, api.ContractStateActive); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateContractProofHeight(fcid, 4); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } else if c, err := ss.contract(context.Background(), fileContractID(fcid)); err != nil { + t.Fatal("unexpected error", err) + } else if c.RevisionHeight != 1 { + t.Fatal("unexpected revision height", c.RevisionHeight) + } else if c.RevisionNumber != "2" { + t.Fatal("unexpected revision number", c.RevisionNumber) + } else if c.Size != 3 { + t.Fatal("unexpected size", c.Size) + } else if c.State.String() != api.ContractStateActive { + t.Fatal("unexpected state", c.State) + } else { + we = c.WindowEnd + } + + // assert we only update revision height if the rev number doesn't increase + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateContract(fcid, 2, 2, 4); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } else if c, err := ss.contract(context.Background(), fileContractID(fcid)); err != nil { + t.Fatal("unexpected error", err) + } else if c.RevisionHeight != 2 { + t.Fatal("unexpected revision height", c.RevisionHeight) + } else if c.RevisionNumber != "2" { + t.Fatal("unexpected revision number", c.RevisionNumber) + } else if c.Size != 3 { + t.Fatal("unexpected size", c.Size) + } + + // assert update failed contracts is successful + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateFailedContracts(we + 1); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } else if c, err := ss.contract(context.Background(), fileContractID(fcid)); err != nil { + t.Fatal("unexpected error", err) + } else if c.State.String() != api.ContractStateFailed { + t.Fatal("unexpected state", c.State) + } + + // assert update host is successful + tx, err = ss.BeginChainUpdateTx() + if err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.UpdateHost(hks[0], chain.HostAnnouncement{NetAddress: "foo"}, 1, types.BlockID{}, types.CurrentTimestamp()); err != nil { + t.Fatal("unexpected error", err) + } else if err := tx.Commit(); err != nil { + t.Fatal("unexpected error", err) + } else if h, err := ss.Host(context.Background(), hks[0]); err != nil { + t.Fatal("unexpected error", err) + } else if h.NetAddress != "foo" { + t.Fatal("unexpected net address", h.NetAddress) + } +} diff --git a/stores/hostdb.go b/stores/hostdb.go index 5ca6b3c34..ae3b47e47 100644 --- a/stores/hostdb.go +++ b/stores/hostdb.go @@ -1119,11 +1119,46 @@ func insertAnnouncements(tx *gorm.DB, as []announcement) error { } func applyRevisionUpdate(db *gorm.DB, fcid types.FileContractID, rev revisionUpdate) error { - return updateActiveAndArchivedContract(db, fcid, map[string]interface{}{ - "revision_height": rev.height, - "revision_number": fmt.Sprint(rev.number), - "size": rev.size, - }) + // isUpdatedRevision indicates whether the given revision number is greater + // than the one currently set on the contract + isUpdatedRevision := func(currRevStr string) bool { + var currRev uint64 + _, _ = fmt.Sscan(currRevStr, &currRev) + return rev.number > currRev + } + + // update either active or archived contract + var update interface{} + var c dbContract + if err := db. + Model(&dbContract{}). + Where("fcid", fileContractID(fcid)). + Take(&c).Error; err == nil { + c.RevisionHeight = rev.height + if isUpdatedRevision(c.RevisionNumber) { + c.RevisionNumber = fmt.Sprint(rev.number) + c.Size = rev.size + } + update = c + } else if err == gorm.ErrRecordNotFound { + // try archived contracts + var ac dbArchivedContract + if err := db. + Model(&dbArchivedContract{}). + Where("fcid", fileContractID(fcid)). + Take(&ac).Error; err == nil { + ac.RevisionHeight = rev.height + if isUpdatedRevision(ac.RevisionNumber) { + ac.RevisionNumber = fmt.Sprint(rev.number) + ac.Size = rev.size + } + update = ac + } + } + if update == nil { + return nil + } + return db.Save(update).Error } func updateContractState(db *gorm.DB, fcid types.FileContractID, cs contractState) error { @@ -1149,10 +1184,10 @@ func updateProofHeight(db *gorm.DB, fcid types.FileContractID, blockHeight uint6 func updateActiveAndArchivedContract(tx *gorm.DB, fcid types.FileContractID, updates map[string]interface{}) error { err1 := tx.Model(&dbContract{}). - Where("fcid = ?", fileContractID(fcid)). + Where("fcid", fileContractID(fcid)). Updates(updates).Error err2 := tx.Model(&dbArchivedContract{}). - Where("fcid = ?", fileContractID(fcid)). + Where("fcid", fileContractID(fcid)). Updates(updates).Error if err1 != nil || err2 != nil { return fmt.Errorf("%s; %s", err1, err2) @@ -1160,6 +1195,26 @@ func updateActiveAndArchivedContract(tx *gorm.DB, fcid types.FileContractID, upd return nil } +func getBlocklists(tx *gorm.DB) ([]dbAllowlistEntry, []dbBlocklistEntry, error) { + var allowlist []dbAllowlistEntry + if err := tx. + Model(&dbAllowlistEntry{}). + Find(&allowlist). + Error; err != nil { + return nil, nil, err + } + + var blocklist []dbBlocklistEntry + if err := tx. + Model(&dbBlocklistEntry{}). + Find(&blocklist). + Error; err != nil { + return nil, nil, err + } + + return allowlist, blocklist, nil +} + func updateBlocklist(tx *gorm.DB, hk types.PublicKey, allowlist []dbAllowlistEntry, blocklist []dbBlocklistEntry) error { // fetch the host var host dbHost diff --git a/stores/metadata.go b/stores/metadata.go index a0f31cfe8..f168f48e5 100644 --- a/stores/metadata.go +++ b/stores/metadata.go @@ -1389,7 +1389,7 @@ func (s *SQLStore) RecordContractSpending(ctx context.Context, records []api.Con err := s.retryTransaction(ctx, func(tx *gorm.DB) error { var contract dbContract err := tx.Model(&dbContract{}). - Where("fcid = ?", fileContractID(fcid)). + Where("fcid", fileContractID(fcid)). Joins("Host"). Take(&contract).Error if errors.Is(err, gorm.ErrRecordNotFound) {