diff --git a/internal/bus/accounts.go b/internal/bus/accounts.go index 3586250cb..666398e57 100644 --- a/internal/bus/accounts.go +++ b/internal/bus/accounts.go @@ -182,7 +182,18 @@ func (a *AccountMgr) ResetDrift(id rhpv3.Account) error { return ErrAccountNotFound } a.mu.Unlock() + + account.mu.Lock() + driftBefore := account.Drift.String() + account.mu.Unlock() + account.resetDrift() + + a.logger.Infow("account drift was reset", + zap.Stringer("account", account.ID), + zap.Stringer("host", account.HostKey), + zap.String("driftBefore", driftBefore)) + return nil } @@ -193,29 +204,37 @@ func (a *AccountMgr) ResetDrift(id rhpv3.Account) error { func (a *AccountMgr) SetBalance(id rhpv3.Account, hk types.PublicKey, balance *big.Int) { acc := a.account(id, hk) - // Update balance and drift. acc.mu.Lock() - delta := new(big.Int).Sub(balance, acc.Balance) - balanceBefore := acc.Balance.String() - driftBefore := acc.Drift.String() + defer acc.mu.Unlock() + + // save previous values + prevBalance := new(big.Int).Set(acc.Balance) + prevDrift := new(big.Int).Set(acc.Drift) + + // update balance + acc.Balance.Set(balance) + + // update drift + drift := new(big.Int).Sub(balance, prevBalance) if acc.CleanShutdown { - acc.Drift = acc.Drift.Add(acc.Drift, delta) + acc.Drift = acc.Drift.Add(acc.Drift, drift) } - acc.Balance.Set(balance) + + // reset fields acc.CleanShutdown = true - acc.RequiresSync = false // resetting the balance resets the sync field - balanceAfter := acc.Balance.String() - acc.mu.Unlock() + acc.RequiresSync = false - // Log resets. + // log account changes a.logger.Infow("account balance was reset", - "account", acc.ID, - "host", acc.HostKey.String(), - "balanceBefore", balanceBefore, - "balanceAfter", balanceAfter, - "driftBefore", driftBefore, - "driftAfter", acc.Drift.String(), - "delta", delta.String()) + zap.Stringer("account", acc.ID), + zap.Stringer("host", acc.HostKey), + zap.Stringer("balanceBefore", prevBalance), + zap.Stringer("balanceAfter", balance), + zap.Stringer("driftBefore", prevDrift), + zap.Stringer("driftAfter", acc.Drift), + zap.Bool("firstDrift", acc.Drift.Cmp(big.NewInt(0)) != 0 && prevDrift.Cmp(big.NewInt(0)) == 0), + zap.Bool("cleanshutdown", acc.CleanShutdown), + zap.Stringer("drift", drift)) } // ScheduleSync sets the requiresSync flag of an account. @@ -296,7 +315,6 @@ func (a *AccountMgr) account(id rhpv3.Account, hk types.PublicKey) *account { a.mu.Lock() defer a.mu.Unlock() - // Create account if it doesn't exist. acc, exists := a.byID[id] if !exists { acc = &account{ @@ -306,7 +324,7 @@ func (a *AccountMgr) account(id rhpv3.Account, hk types.PublicKey) *account { HostKey: hk, Balance: big.NewInt(0), Drift: big.NewInt(0), - RequiresSync: false, + RequiresSync: true, // initial sync }, locks: map[uint64]*accountLock{}, } diff --git a/internal/rhp/v3/rhp.go b/internal/rhp/v3/rhp.go index 7b8629861..5ae5d9972 100644 --- a/internal/rhp/v3/rhp.go +++ b/internal/rhp/v3/rhp.go @@ -167,9 +167,8 @@ func (c *Client) Renew(ctx context.Context, rrr api.RHPRenewRequest, gougingChec return } -func (c *Client) SyncAccount(ctx context.Context, rev *types.FileContractRevision, hk types.PublicKey, siamuxAddr string, accID rhpv3.Account, pt rhpv3.SettingsID, rk types.PrivateKey) (types.Currency, error) { - var balance types.Currency - err := c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error { +func (c *Client) SyncAccount(ctx context.Context, rev *types.FileContractRevision, hk types.PublicKey, siamuxAddr string, accID rhpv3.Account, pt rhpv3.SettingsID, rk types.PrivateKey) (balance types.Currency, _ error) { + return balance, c.tpool.withTransport(ctx, hk, siamuxAddr, func(ctx context.Context, t *transportV3) error { payment, err := payByContract(rev, types.NewCurrency64(1), accID, rk) if err != nil { return err @@ -177,7 +176,6 @@ func (c *Client) SyncAccount(ctx context.Context, rev *types.FileContractRevisio balance, err = rpcAccountBalance(ctx, t, &payment, accID, pt) return err }) - return balance, err } func (c *Client) PriceTable(ctx context.Context, hk types.PublicKey, siamuxAddr string, paymentFn PriceTablePaymentFunc) (pt api.HostPriceTable, err error) { @@ -207,6 +205,7 @@ func (c *Client) ReadSector(ctx context.Context, offset, length uint32, root typ return err } + amount = cost // pessimistic cost estimate in case rpc fails payment := rhpv3.PayByEphemeralAccount(accID, cost, pt.HostBlockHeight+defaultWithdrawalExpiryBlocks, accKey) cost, refund, err := rpcReadSector(ctx, t, w, pt, &payment, offset, length, root) if err != nil { diff --git a/internal/sql/migrations.go b/internal/sql/migrations.go index 18214c07f..377bf6fc5 100644 --- a/internal/sql/migrations.go +++ b/internal/sql/migrations.go @@ -199,6 +199,12 @@ var ( return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00014_hosts_resolvedaddresses", log) }, }, + { + ID: "00015_reset_drift", + Migrate: func(tx Tx) error { + return performMigration(ctx, tx, migrationsFs, dbIdentifier, "00015_reset_drift", log) + }, + }, } } MetricsMigrations = func(ctx context.Context, migrationsFs embed.FS, log *zap.SugaredLogger) []Migration { diff --git a/internal/test/e2e/cluster_test.go b/internal/test/e2e/cluster_test.go index 81b69f6d7..f9ba9e018 100644 --- a/internal/test/e2e/cluster_test.go +++ b/internal/test/e2e/cluster_test.go @@ -1131,54 +1131,64 @@ func TestEphemeralAccounts(t *testing.T) { t.SkipNow() } - // Create cluster - cluster := newTestCluster(t, testClusterOptions{hosts: 1}) + // run without autopilot + opts := clusterOptsDefault + opts.skipRunningAutopilot = true + + // create cluster + cluster := newTestCluster(t, opts) defer cluster.Shutdown() + + // convenience variables + b := cluster.Bus + w := cluster.Worker tt := cluster.tt - // Shut down the autopilot to prevent it from interfering. - cluster.ShutdownAutopilot(context.Background()) + tt.OK(b.UpdateSetting(context.Background(), api.SettingRedundancy, api.RedundancySettings{ + MinShards: 1, + TotalShards: 1, + })) + // add a host + hosts := cluster.AddHosts(1) + h, err := b.Host(context.Background(), hosts[0].PublicKey()) + tt.OK(err) - // Wait for contract and accounts. - contract := cluster.WaitForContracts()[0] - accounts := cluster.WaitForAccounts() + // scan the host + tt.OKAll(w.RHPScan(context.Background(), h.PublicKey, h.NetAddress, 10*time.Second)) - // Shut down the autopilot to prevent it from interfering with the test. - cluster.ShutdownAutopilot(context.Background()) + // manually form a contract with the host + cs, _ := b.ConsensusState(context.Background()) + wallet, _ := b.Wallet(context.Background()) + rev, _, err := w.RHPForm(context.Background(), cs.BlockHeight+test.AutopilotConfig.Contracts.Period+test.AutopilotConfig.Contracts.RenewWindow, h.PublicKey, h.NetAddress, wallet.Address, types.Siacoins(10), types.Siacoins(1)) + tt.OK(err) + c, err := b.AddContract(context.Background(), rev, rev.Revision.MissedHostPayout().Sub(types.Siacoins(1)), types.Siacoins(1), cs.BlockHeight, api.ContractStatePending) + tt.OK(err) - // Newly created accounts are !cleanShutdown. Simulate a sync to change - // that. - for _, acc := range accounts { - if acc.CleanShutdown { - t.Fatal("new account should indicate an unclean shutdown") - } else if acc.RequiresSync { - t.Fatal("new account should not require a sync") - } - if err := cluster.Bus.SetBalance(context.Background(), acc.ID, acc.HostKey, types.Siacoins(1).Big()); err != nil { - t.Fatal(err) - } - } + tt.OK(b.SetContractSet(context.Background(), test.ContractSet, []types.FileContractID{c.ID})) + + // fund the account + fundAmt := types.Siacoins(1) + tt.OK(w.RHPFund(context.Background(), c.ID, c.HostKey, c.HostIP, c.SiamuxAddr, fundAmt)) - // Fetch accounts again. + // fetch accounts accounts, err := cluster.Bus.Accounts(context.Background()) tt.OK(err) + // assert account state acc := accounts[0] - if acc.Balance.Cmp(types.Siacoins(1).Big()) < 0 { - t.Fatalf("wrong balance %v", acc.Balance) - } if acc.ID == (rhpv3.Account{}) { t.Fatal("account id not set") - } - host := cluster.hosts[0] - if acc.HostKey != types.PublicKey(host.PublicKey()) { + } else if acc.CleanShutdown { + t.Fatal("account should indicate an unclean shutdown") + } else if !acc.RequiresSync { + t.Fatal("account should require a sync") + } else if acc.HostKey != h.PublicKey { t.Fatal("wrong host") - } - if !acc.CleanShutdown { - t.Fatal("account should indicate a clean shutdown") + } else if acc.Balance.Cmp(types.Siacoins(1).Big()) != 0 { + t.Fatalf("wrong balance %v", acc.Balance) } - // Fetch account from bus directly. + // fetch account from bus directly busAccounts, err := cluster.Bus.Accounts(context.Background()) tt.OK(err) if len(busAccounts) != 1 { @@ -1189,12 +1199,11 @@ func TestEphemeralAccounts(t *testing.T) { t.Fatal("bus account doesn't match worker account") } - // Check that the spending was recorded for the contract. The recorded + // check that the spending was recorded for the contract. The recorded // spending should be > the fundAmt since it consists of the fundAmt plus // fee. - fundAmt := types.Siacoins(1) tt.Retry(10, testBusFlushInterval, func() error { - cm, err := cluster.Bus.Contract(context.Background(), contract.ID) + cm, err := cluster.Bus.Contract(context.Background(), c.ID) tt.OK(err) if cm.Spending.FundAccount.Cmp(fundAmt) <= 0 { @@ -1203,7 +1212,24 @@ func TestEphemeralAccounts(t *testing.T) { return nil }) - // Update the balance to create some drift. + // sync the account + tt.OK(w.RHPSync(context.Background(), c.ID, acc.HostKey, c.HostIP, c.SiamuxAddr)) + + // assert account state + accounts, err = cluster.Bus.Accounts(context.Background()) + tt.OK(err) + + // assert account state + acc = accounts[0] + if !acc.CleanShutdown { + t.Fatal("account should indicate a clean shutdown") + } else if acc.RequiresSync { + t.Fatal("account should not require a sync") + } else if acc.Drift.Cmp(new(big.Int)) != 0 { + t.Fatalf("account shoult not have drift %v", acc.Drift) + } + + // update the balance to create some drift newBalance := fundAmt.Div64(2) newDrift := new(big.Int).Sub(newBalance.Big(), fundAmt.Big()) if err := cluster.Bus.SetBalance(context.Background(), busAcc.ID, acc.HostKey, newBalance.Big()); err != nil { @@ -1217,11 +1243,11 @@ func TestEphemeralAccounts(t *testing.T) { t.Fatalf("drift was %v but should be %v", busAcc.Drift, maxNewDrift) } - // Reboot cluster. + // reboot cluster cluster2 := cluster.Reboot(t) defer cluster2.Shutdown() - // Check that accounts were loaded from the bus. + // check that accounts were loaded from the bus accounts2, err := cluster2.Bus.Accounts(context.Background()) tt.OK(err) for _, acc := range accounts2 { @@ -1234,7 +1260,7 @@ func TestEphemeralAccounts(t *testing.T) { } } - // Reset drift again. + // reset drift again if err := cluster2.Bus.ResetDrift(context.Background(), acc.ID); err != nil { t.Fatal(err) } diff --git a/stores/sql/mysql/migrations/main/migration_00014_hosts_resolvedaddresses.sql b/stores/sql/mysql/migrations/main/migration_00014_hosts_resolvedaddresses.sql index c7657d407..74a5cebe6 100644 --- a/stores/sql/mysql/migrations/main/migration_00014_hosts_resolvedaddresses.sql +++ b/stores/sql/mysql/migrations/main/migration_00014_hosts_resolvedaddresses.sql @@ -1,3 +1,2 @@ ALTER TABLE hosts DROP COLUMN subnets; ALTER TABLE hosts ADD resolved_addresses varchar(255) NOT NULL DEFAULT ''; - diff --git a/stores/sql/mysql/migrations/main/migration_00015_reset_drift.sql b/stores/sql/mysql/migrations/main/migration_00015_reset_drift.sql new file mode 100644 index 000000000..c151d90a3 --- /dev/null +++ b/stores/sql/mysql/migrations/main/migration_00015_reset_drift.sql @@ -0,0 +1 @@ +UPDATE ephemeral_accounts SET drift = "0", clean_shutdown = 0, requires_sync = 1; \ No newline at end of file diff --git a/stores/sql/sqlite/migrations/main/migration_00014_hosts_resolvedaddresses.sql b/stores/sql/sqlite/migrations/main/migration_00014_hosts_resolvedaddresses.sql index 126d9d75f..9800c2b7b 100644 --- a/stores/sql/sqlite/migrations/main/migration_00014_hosts_resolvedaddresses.sql +++ b/stores/sql/sqlite/migrations/main/migration_00014_hosts_resolvedaddresses.sql @@ -1,3 +1,2 @@ ALTER TABLE hosts DROP COLUMN subnets; ALTER TABLE hosts ADD resolved_addresses TEXT NOT NULL DEFAULT ''; - diff --git a/stores/sql/sqlite/migrations/main/migration_00015_reset_drift.sql b/stores/sql/sqlite/migrations/main/migration_00015_reset_drift.sql new file mode 100644 index 000000000..c151d90a3 --- /dev/null +++ b/stores/sql/sqlite/migrations/main/migration_00015_reset_drift.sql @@ -0,0 +1 @@ +UPDATE ephemeral_accounts SET drift = "0", clean_shutdown = 0, requires_sync = 1; \ No newline at end of file diff --git a/worker/host.go b/worker/host.go index 43cbefbd2..9c65bd0c2 100644 --- a/worker/host.go +++ b/worker/host.go @@ -2,6 +2,7 @@ package worker import ( "context" + "errors" "fmt" "io" "time" @@ -22,7 +23,7 @@ type ( DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) error UploadSector(ctx context.Context, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) error - PriceTable(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, err error) + PriceTable(ctx context.Context, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) PriceTableUnpaid(ctx context.Context) (hpt api.HostPriceTable, err error) FetchRevision(ctx context.Context, fetchTimeout time.Duration) (types.FileContractRevision, error) @@ -78,30 +79,39 @@ func (w *Worker) Host(hk types.PublicKey, fcid types.FileContractID, siamuxAddr func (h *host) PublicKey() types.PublicKey { return h.hk } func (h *host) DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) (err error) { - pt, err := h.priceTables.fetch(ctx, h.hk, nil) - if err != nil { - return err - } - hpt := pt.HostPriceTable + var amount types.Currency + return h.acc.WithWithdrawal(ctx, func() (types.Currency, error) { + pt, uptc, err := h.priceTables.fetch(ctx, h.hk, nil) + if err != nil { + return types.ZeroCurrency, err + } + hpt := pt.HostPriceTable + amount = uptc - // check for download gouging specifically - gc, err := GougingCheckerFromContext(ctx, overpay) - if err != nil { - return err - } - if breakdown := gc.Check(nil, &hpt); breakdown.DownloadErr != "" { - return fmt.Errorf("%w: %v", gouging.ErrPriceTableGouging, breakdown.DownloadErr) - } + // check for download gouging specifically + gc, err := GougingCheckerFromContext(ctx, overpay) + if err != nil { + return amount, err + } + if breakdown := gc.Check(nil, &hpt); breakdown.DownloadErr != "" { + return amount, fmt.Errorf("%w: %v", gouging.ErrPriceTableGouging, breakdown.DownloadErr) + } - return h.acc.WithWithdrawal(ctx, func() (amount types.Currency, err error) { - return h.client.ReadSector(ctx, offset, length, root, w, h.hk, h.siamuxAddr, h.acc.id, h.accountKey, hpt) + cost, err := h.client.ReadSector(ctx, offset, length, root, w, h.hk, h.siamuxAddr, h.acc.id, h.accountKey, hpt) + if err != nil { + return amount, err + } + return amount.Add(cost), nil }) } -func (h *host) UploadSector(ctx context.Context, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) (err error) { +func (h *host) UploadSector(ctx context.Context, sectorRoot types.Hash256, sector *[rhpv2.SectorSize]byte, rev types.FileContractRevision) error { // fetch price table - pt, err := h.priceTable(ctx, nil) - if err != nil { + var pt rhpv3.HostPriceTable + if err := h.acc.WithWithdrawal(ctx, func() (amount types.Currency, err error) { + pt, amount, err = h.priceTable(ctx, nil) + return + }); err != nil { return err } // upload @@ -164,19 +174,24 @@ func (h *host) PriceTableUnpaid(ctx context.Context) (api.HostPriceTable, error) return h.client.PriceTableUnpaid(ctx, h.hk, h.siamuxAddr) } -func (h *host) PriceTable(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, err error) { +func (h *host) PriceTable(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, cost types.Currency, err error) { // fetchPT is a helper function that performs the RPC given a payment function fetchPT := func(paymentFn rhp3.PriceTablePaymentFunc) (api.HostPriceTable, error) { return h.client.PriceTable(ctx, h.hk, h.siamuxAddr, paymentFn) } - // pay by contract if a revision is given + // fetch the price table if rev != nil { - return fetchPT(rhp3.PreparePriceTableContractPayment(rev, h.acc.id, h.renterKey)) + hpt, err = fetchPT(rhp3.PreparePriceTableContractPayment(rev, h.acc.id, h.renterKey)) + } else { + hpt, err = fetchPT(rhp3.PreparePriceTableAccountPayment(h.accountKey)) } - // pay by account - return fetchPT(rhp3.PreparePriceTableAccountPayment(h.accountKey)) + // set the cost + if err == nil { + cost = hpt.UpdatePriceTableCost + } + return } // FetchRevision tries to fetch a contract revision from the host. @@ -190,49 +205,67 @@ func (h *host) FetchRevision(ctx context.Context, fetchTimeout time.Duration) (t return h.client.Revision(ctx, h.fcid, h.hk, h.siamuxAddr) } -func (h *host) FundAccount(ctx context.Context, balance types.Currency, rev *types.FileContractRevision) error { +func (h *host) FundAccount(ctx context.Context, desired types.Currency, rev *types.FileContractRevision) error { + log := h.logger.With( + zap.Stringer("host", h.hk), + zap.Stringer("account", h.acc.id), + ) + + // ensure we have at least 2H in the contract to cover the costs + if types.NewCurrency64(2).Cmp(rev.ValidRenterPayout()) >= 0 { + return fmt.Errorf("insufficient funds to fund account: %v <= %v", rev.ValidRenterPayout(), types.NewCurrency64(2)) + } + // fetch current balance - curr, err := h.acc.Balance(ctx) + balance, err := h.acc.Balance(ctx) if err != nil { return err } // return early if we have the desired balance - if curr.Cmp(balance) >= 0 { + if balance.Cmp(desired) >= 0 { return nil } - deposit := balance.Sub(curr) + // calculate the deposit amount + deposit := desired.Sub(balance) return h.acc.WithDeposit(ctx, func() (types.Currency, error) { // fetch pricetable directly to bypass the gouging check - pt, err := h.priceTables.fetch(ctx, h.hk, rev) + pt, _, err := h.priceTables.fetch(ctx, h.hk, rev) if err != nil { return types.ZeroCurrency, err } - // check whether we have money left in the contract + // cap the deposit by what's left in the contract cost := types.NewCurrency64(1) - if cost.Cmp(rev.ValidRenterPayout()) >= 0 { - return types.ZeroCurrency, fmt.Errorf("insufficient funds to fund account: %v <= %v", rev.ValidRenterPayout(), cost) - } availableFunds := rev.ValidRenterPayout().Sub(cost) - - // cap the deposit amount by the money that's left in the contract if deposit.Cmp(availableFunds) > 0 { deposit = availableFunds } + + // fund the account if err := h.client.FundAccount(ctx, rev, h.hk, h.siamuxAddr, deposit, h.acc.id, pt.HostPriceTable, h.renterKey); err != nil { + if rhp3.IsBalanceMaxExceeded(err) { + err = errors.Join(err, h.acc.as.ScheduleSync(ctx, h.acc.id, h.hk)) + } return types.ZeroCurrency, fmt.Errorf("failed to fund account with %v; %w", deposit, err) } + // record the spend h.contractSpendingRecorder.Record(*rev, api.ContractSpending{FundAccount: deposit.Add(cost)}) + + // log the account balance after funding + log.Debugw("fund account succeeded", + "balance", balance.ExactString(), + "deposit", deposit.ExactString(), + ) return deposit, nil }) } func (h *host) SyncAccount(ctx context.Context, rev *types.FileContractRevision) error { // fetch pricetable directly to bypass the gouging check - pt, err := h.priceTables.fetch(ctx, h.hk, rev) + pt, _, err := h.priceTables.fetch(ctx, h.hk, rev) if err != nil { return err } @@ -261,17 +294,17 @@ func (h *host) gougingChecker(ctx context.Context, criticalMigration bool) (goug // priceTable fetches a price table from the host. If a revision is provided, it // will be used to pay for the price table. The returned price table is // guaranteed to be safe to use. -func (h *host) priceTable(ctx context.Context, rev *types.FileContractRevision) (rhpv3.HostPriceTable, error) { - pt, err := h.priceTables.fetch(ctx, h.hk, rev) +func (h *host) priceTable(ctx context.Context, rev *types.FileContractRevision) (rhpv3.HostPriceTable, types.Currency, error) { + pt, cost, err := h.priceTables.fetch(ctx, h.hk, rev) if err != nil { - return rhpv3.HostPriceTable{}, err + return rhpv3.HostPriceTable{}, types.ZeroCurrency, err } gc, err := GougingCheckerFromContext(ctx, false) if err != nil { - return rhpv3.HostPriceTable{}, err + return rhpv3.HostPriceTable{}, cost, err } if breakdown := gc.Check(nil, &pt.HostPriceTable); breakdown.Gouging() { - return rhpv3.HostPriceTable{}, fmt.Errorf("%w: %v", gouging.ErrPriceTableGouging, breakdown) + return rhpv3.HostPriceTable{}, cost, fmt.Errorf("%w: %v", gouging.ErrPriceTableGouging, breakdown) } - return pt.HostPriceTable, nil + return pt.HostPriceTable, cost, nil } diff --git a/worker/host_test.go b/worker/host_test.go index 6a0a477e8..8bbecaeff 100644 --- a/worker/host_test.go +++ b/worker/host_test.go @@ -111,8 +111,8 @@ func (h *testHost) FetchRevision(ctx context.Context, fetchTimeout time.Duration return rev, nil } -func (h *testHost) PriceTable(ctx context.Context, rev *types.FileContractRevision) (api.HostPriceTable, error) { - return h.hptFn(), nil +func (h *testHost) PriceTable(ctx context.Context, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) { + return h.hptFn(), types.ZeroCurrency, nil } func (h *testHost) PriceTableUnpaid(ctx context.Context) (api.HostPriceTable, error) { diff --git a/worker/pricetables.go b/worker/pricetables.go index c884695b5..fc3901f67 100644 --- a/worker/pricetables.go +++ b/worker/pricetables.go @@ -76,7 +76,7 @@ func newPriceTables(hm HostManager, hs HostStore) *priceTables { } // fetch returns a price table for the given host -func (pts *priceTables) fetch(ctx context.Context, hk types.PublicKey, rev *types.FileContractRevision) (api.HostPriceTable, error) { +func (pts *priceTables) fetch(ctx context.Context, hk types.PublicKey, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) { pts.mu.Lock() pt, exists := pts.priceTables[hk] if !exists { @@ -106,7 +106,7 @@ func (pt *priceTable) ongoingUpdate() (bool, *priceTableUpdate) { return ongoing, pt.update } -func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, err error) { +func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) (hpt api.HostPriceTable, cost types.Currency, err error) { // grab the current price table p.mu.Lock() hpt = p.hpt @@ -116,7 +116,7 @@ func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) // current price table is considered to gouge on the block height gc, err := GougingCheckerFromContext(ctx, false) if err != nil { - return api.HostPriceTable{}, err + return api.HostPriceTable{}, types.ZeroCurrency, err } // figure out whether we should update the price table, if not we can return @@ -138,10 +138,10 @@ func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) } else if ongoing { select { case <-ctx.Done(): - return api.HostPriceTable{}, fmt.Errorf("%w; %w", errPriceTableUpdateTimedOut, context.Cause(ctx)) + return api.HostPriceTable{}, types.ZeroCurrency, fmt.Errorf("%w; %w", errPriceTableUpdateTimedOut, context.Cause(ctx)) case <-update.done: } - return update.hpt, update.err + return update.hpt, types.ZeroCurrency, update.err } // this thread is updating the price table @@ -167,12 +167,12 @@ func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) // sanity check the host has been scanned before fetching the price table if !host.Scanned { - return api.HostPriceTable{}, fmt.Errorf("host %v was not scanned", p.hk) + return api.HostPriceTable{}, types.ZeroCurrency, fmt.Errorf("host %v was not scanned", p.hk) } // otherwise fetch it h := p.hm.Host(p.hk, types.FileContractID{}, host.Settings.SiamuxAddr()) - hpt, err = h.PriceTable(ctx, rev) + hpt, cost, err = h.PriceTable(ctx, rev) // record it in the background if shouldRecordPriceTable(err) { @@ -190,7 +190,7 @@ func (p *priceTable) fetch(ctx context.Context, rev *types.FileContractRevision) // handle error after recording if err != nil { - return api.HostPriceTable{}, fmt.Errorf("failed to update pricetable, err %v", err) + return api.HostPriceTable{}, types.ZeroCurrency, fmt.Errorf("failed to update pricetable, err %v", err) } return } diff --git a/worker/pricetables_test.go b/worker/pricetables_test.go index 22c021ccb..ed918fa1a 100644 --- a/worker/pricetables_test.go +++ b/worker/pricetables_test.go @@ -58,14 +58,14 @@ func TestPriceTables(t *testing.T) { // update ctx, cancel := context.WithCancel(gCtx) cancel() - _, err := pts.fetch(ctx, h.hk, nil) + _, _, err := pts.fetch(ctx, h.hk, nil) if !errors.Is(err, errPriceTableUpdateTimedOut) { t.Fatal("expected errPriceTableUpdateTimedOut, got", err) } - // unblock and assert we receive a valid price table + // unblock and assert we paid for the price table close(fetchPTBlockChan) - update, err := pts.fetch(gCtx, h.hk, nil) + update, _, err := pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != validPT.UID { @@ -75,7 +75,7 @@ func TestPriceTables(t *testing.T) { // refresh the price table on the host, update again, assert we receive the // same price table as it hasn't expired yet h.hi.PriceTable = newTestHostPriceTable() - update, err = pts.fetch(gCtx, h.hk, nil) + update, _, err = pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != validPT.UID { @@ -86,7 +86,7 @@ func TestPriceTables(t *testing.T) { pts.priceTables[h.hk].hpt.Expiry = time.Now() // fetch it again and assert we updated the price table - update, err = pts.fetch(gCtx, h.hk, nil) + update, _, err = pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != h.hi.PriceTable.UID { @@ -97,7 +97,7 @@ func TestPriceTables(t *testing.T) { // the price table since it's not expired validPT = h.hi.PriceTable h.hi.PriceTable = newTestHostPriceTable() - update, err = pts.fetch(gCtx, h.hk, nil) + update, _, err = pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != validPT.UID { @@ -110,7 +110,7 @@ func TestPriceTables(t *testing.T) { cm.cs.BlockHeight = validPT.HostBlockHeight + uint64(blockHeightLeeway) - priceTableBlockHeightLeeway // fetch it again and assert we updated the price table - update, err = pts.fetch(gCtx, h.hk, nil) + update, _, err = pts.fetch(gCtx, h.hk, nil) if err != nil { t.Fatal(err) } else if update.UID != h.hi.PriceTable.UID { diff --git a/worker/worker.go b/worker/worker.go index da2710cbc..7073e0c63 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -96,7 +96,6 @@ type ( LockAccount(ctx context.Context, id rhpv3.Account, hostKey types.PublicKey, exclusive bool, duration time.Duration) (api.Account, uint64, error) UnlockAccount(ctx context.Context, id rhpv3.Account, lockID uint64) error - ResetDrift(ctx context.Context, id rhpv3.Account) error SetBalance(ctx context.Context, id rhpv3.Account, hk types.PublicKey, amt *big.Int) error ScheduleSync(ctx context.Context, id rhpv3.Account, hk types.PublicKey) error } @@ -637,7 +636,7 @@ func (w *Worker) rhpRenewHandler(jc jape.Context) { var renewed rhpv2.ContractRevision var txnSet []types.Transaction var contractPrice, fundAmount types.Currency - if jc.Check("couldn't renew contract", w.withRevision(ctx, defaultRevisionFetchTimeout, rrr.ContractID, rrr.HostKey, rrr.SiamuxAddr, lockingPriorityRenew, func(_ types.FileContractRevision) (err error) { + if jc.Check("couldn't renew contract", w.withContractLock(ctx, rrr.ContractID, lockingPriorityRenew, func() (err error) { h := w.Host(rrr.HostKey, rrr.ContractID, rrr.SiamuxAddr) renewed, txnSet, contractPrice, fundAmount, err = h.RenewContract(ctx, rrr) return err @@ -678,24 +677,8 @@ func (w *Worker) rhpFundHandler(jc jape.Context) { ctx = WithGougingChecker(ctx, w.bus, gp) // fund the account - jc.Check("couldn't fund account", w.withRevision(ctx, defaultRevisionFetchTimeout, rfr.ContractID, rfr.HostKey, rfr.SiamuxAddr, lockingPriorityFunding, func(rev types.FileContractRevision) (err error) { - h := w.Host(rfr.HostKey, rev.ParentID, rfr.SiamuxAddr) - err = h.FundAccount(ctx, rfr.Balance, &rev) - if rhp3.IsBalanceMaxExceeded(err) { - // sync the account - err = h.SyncAccount(ctx, &rev) - if err != nil { - w.logger.Infof(fmt.Sprintf("failed to sync account: %v", err), "host", rfr.HostKey) - return - } - - // try funding the account again - err = h.FundAccount(ctx, rfr.Balance, &rev) - if err != nil { - w.logger.Errorw(fmt.Sprintf("failed to fund account after syncing: %v", err), "host", rfr.HostKey, "balance", rfr.Balance) - } - } - return + jc.Check("couldn't fund account", w.withRevision(ctx, defaultRevisionFetchTimeout, rfr.ContractID, rfr.HostKey, rfr.SiamuxAddr, lockingPriorityFunding, func(rev types.FileContractRevision) error { + return w.Host(rfr.HostKey, rev.ParentID, rfr.SiamuxAddr).FundAccount(ctx, rfr.Balance, &rev) })) }