Skip to content

Commit

Permalink
Use WithContext to respect the contexts passed down to the SQL store (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisSchinnerl authored Mar 15, 2024
2 parents b54a0c0 + 82a740f commit 7efeb63
Show file tree
Hide file tree
Showing 39 changed files with 242 additions and 276 deletions.
3 changes: 3 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ linters:
- typecheck
- whitespace
- tagliatelle
- unused
- unparam
- deadcode

issues:
# Maximum issues count per one linter. Set to 0 to disable. Default is 50.
Expand Down
12 changes: 7 additions & 5 deletions alerts/alerts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,29 @@ type testWebhookStore struct {
listed int
}

func (s *testWebhookStore) DeleteWebhook(wb webhooks.Webhook) error {
func (s *testWebhookStore) DeleteWebhook(_ context.Context, wb webhooks.Webhook) error {
s.mu.Lock()
defer s.mu.Unlock()
s.deleted++
return nil
}

func (s *testWebhookStore) AddWebhook(wb webhooks.Webhook) error {
func (s *testWebhookStore) AddWebhook(_ context.Context, wb webhooks.Webhook) error {
s.mu.Lock()
defer s.mu.Unlock()
s.added++
return nil
}

func (s *testWebhookStore) Webhooks() ([]webhooks.Webhook, error) {
func (s *testWebhookStore) Webhooks(_ context.Context) ([]webhooks.Webhook, error) {
s.mu.Lock()
defer s.mu.Unlock()
s.listed++
return nil, nil
}

var _ webhooks.WebhookStore = (*testWebhookStore)(nil)

func TestWebhooks(t *testing.T) {
store := &testWebhookStore{}
mgr, err := webhooks.NewManager(zap.NewNop().Sugar(), store)
Expand Down Expand Up @@ -75,7 +77,7 @@ func TestWebhooks(t *testing.T) {
if hookID := wh.String(); hookID != fmt.Sprintf("%v.%v.%v", wh.URL, wh.Module, "") {
t.Fatalf("wrong result for wh.String(): %v != %v", wh.String(), hookID)
}
err = mgr.Register(wh)
err = mgr.Register(context.Background(), wh)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -110,7 +112,7 @@ func TestWebhooks(t *testing.T) {
}

// unregister hook
if err := mgr.Delete(webhooks.Webhook{
if err := mgr.Delete(context.Background(), webhooks.Webhook{
Event: hooks[0].Event,
Module: hooks[0].Module,
URL: hooks[0].URL,
Expand Down
4 changes: 2 additions & 2 deletions autopilot/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (a *accounts) refillWorkerAccounts(ctx context.Context, w Worker) {
go func(contract api.ContractMetadata) {
rCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
accountID, refilled, rerr := refillWorkerAccount(rCtx, a.a, w, workerID, contract)
accountID, refilled, rerr := refillWorkerAccount(rCtx, a.a, w, contract)
if rerr != nil {
if rerr.Is(errMaxDriftExceeded) {
// register the alert if error is errMaxDriftExceeded
Expand Down Expand Up @@ -184,7 +184,7 @@ func (err *refillError) Is(target error) bool {
return errors.Is(err.err, target)
}

func refillWorkerAccount(ctx context.Context, a AccountStore, w Worker, workerID string, contract api.ContractMetadata) (accountID rhpv3.Account, refilled bool, rerr *refillError) {
func refillWorkerAccount(ctx context.Context, a AccountStore, w Worker, contract api.ContractMetadata) (accountID rhpv3.Account, refilled bool, rerr *refillError) {
wrapErr := func(err error, keysAndValues ...interface{}) *refillError {
if err == nil {
return nil
Expand Down
4 changes: 2 additions & 2 deletions autopilot/autopilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,11 @@ func (ap *Autopilot) Run() error {
}

// migration
ap.m.tryPerformMigrations(ap.shutdownCtx, ap.workers)
ap.m.tryPerformMigrations(ap.workers)

// pruning
if ap.state.cfg.Contracts.Prune {
ap.c.tryPerformPruning(ap.shutdownCtx, ap.workers)
ap.c.tryPerformPruning(ap.workers)
} else {
ap.logger.Debug("pruning disabled")
}
Expand Down
13 changes: 5 additions & 8 deletions autopilot/contract_spending.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (c *contractor) contractSpending(ctx context.Context, contract api.Contract
return total, nil
}

func (c *contractor) currentPeriodSpending(contracts []api.Contract, currentPeriod uint64) (types.Currency, error) {
func (c *contractor) currentPeriodSpending(contracts []api.Contract, currentPeriod uint64) types.Currency {
totalCosts := make(map[types.FileContractID]types.Currency)
for _, c := range contracts {
totalCosts[c.ID] = c.TotalCost
Expand All @@ -41,22 +41,19 @@ func (c *contractor) currentPeriodSpending(contracts []api.Contract, currentPeri
for _, contract := range filtered {
totalAllocated = totalAllocated.Add(contract.TotalCost)
}
return totalAllocated, nil
return totalAllocated
}

func (c *contractor) remainingFunds(contracts []api.Contract) (types.Currency, error) {
func (c *contractor) remainingFunds(contracts []api.Contract) types.Currency {
state := c.ap.State()

// find out how much we spent in the current period
spent, err := c.currentPeriodSpending(contracts, state.period)
if err != nil {
return types.ZeroCurrency, err
}
spent := c.currentPeriodSpending(contracts, state.period)

// figure out remaining funds
var remaining types.Currency
if state.cfg.Contracts.Allowance.Cmp(spent) > 0 {
remaining = state.cfg.Contracts.Allowance.Sub(spent)
}
return remaining, nil
return remaining
}
21 changes: 7 additions & 14 deletions autopilot/contractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ func (c *contractor) performContractMaintenance(ctx context.Context, w Worker) (
// min score to pass checks
var minScore float64
if len(hosts) > 0 {
minScore = c.calculateMinScore(ctx, candidates, state.cfg.Contracts.Amount)
minScore = c.calculateMinScore(candidates, state.cfg.Contracts.Amount)
} else {
c.logger.Warn("could not calculate min score, no hosts found")
}
Expand Down Expand Up @@ -324,10 +324,7 @@ func (c *contractor) performContractMaintenance(ctx context.Context, w Worker) (
}

// calculate remaining funds
remaining, err := c.remainingFunds(contracts)
if err != nil {
return false, err
}
remaining := c.remainingFunds(contracts)

// calculate 'limit' amount of contracts we want to renew
var limit int
Expand Down Expand Up @@ -1140,7 +1137,7 @@ func (c *contractor) initialContractFunding(settings rhpv2.HostSettings, txnFee,
return funding
}

func (c *contractor) refreshFundingEstimate(ctx context.Context, cfg api.AutopilotConfig, ci contractInfo, fee types.Currency) (types.Currency, error) {
func (c *contractor) refreshFundingEstimate(cfg api.AutopilotConfig, ci contractInfo, fee types.Currency) types.Currency {
// refresh with 1.2x the funds
refreshAmount := ci.contract.TotalCost.Mul64(6).Div64(5)

Expand All @@ -1159,7 +1156,7 @@ func (c *contractor) refreshFundingEstimate(ctx context.Context, cfg api.Autopil
"fcid", ci.contract.ID,
"refreshAmount", refreshAmount,
"refreshAmountCapped", refreshAmountCapped)
return refreshAmountCapped, nil
return refreshAmountCapped
}

func (c *contractor) renewFundingEstimate(ctx context.Context, ci contractInfo, fee types.Currency, renewing bool) (types.Currency, error) {
Expand Down Expand Up @@ -1249,7 +1246,7 @@ func (c *contractor) renewFundingEstimate(ctx context.Context, ci contractInfo,
return cappedEstimatedCost, nil
}

func (c *contractor) calculateMinScore(ctx context.Context, candidates []scoredHost, numContracts uint64) float64 {
func (c *contractor) calculateMinScore(candidates []scoredHost, numContracts uint64) float64 {
// return early if there's no hosts
if len(candidates) == 0 {
c.logger.Warn("min host score is set to the smallest non-zero float because there are no candidate hosts")
Expand Down Expand Up @@ -1475,11 +1472,7 @@ func (c *contractor) refreshContract(ctx context.Context, w Worker, ci contractI
// calculate the renter funds
var renterFunds types.Currency
if isOutOfFunds(state.cfg, ci.priceTable, ci.contract) {
renterFunds, err = c.refreshFundingEstimate(ctx, state.cfg, ci, state.fee)
if err != nil {
c.logger.Errorw(fmt.Sprintf("could not get refresh funding estimate, err: %v", err), "hk", hk, "fcid", fcid)
return api.ContractMetadata{}, true, err
}
renterFunds = c.refreshFundingEstimate(state.cfg, ci, state.fee)
} else {
renterFunds = rev.ValidRenterPayout() // don't increase funds
}
Expand Down Expand Up @@ -1599,7 +1592,7 @@ func (c *contractor) formContract(ctx context.Context, w Worker, host hostdb.Hos
return formedContract, true, nil
}

func (c *contractor) tryPerformPruning(ctx context.Context, wp *workerPool) {
func (c *contractor) tryPerformPruning(wp *workerPool) {
c.mu.Lock()
if c.pruning || c.ap.isStopped() {
c.mu.Unlock()
Expand Down
7 changes: 3 additions & 4 deletions autopilot/contractor_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package autopilot

import (
"context"
"math"
"testing"

Expand All @@ -19,19 +18,19 @@ func TestCalculateMinScore(t *testing.T) {
}

// Test with 100 hosts which makes for a random set size of 250
minScore := c.calculateMinScore(context.Background(), candidates, 100)
minScore := c.calculateMinScore(candidates, 100)
if minScore != 0.002 {
t.Fatalf("expected minScore to be 0.002 but was %v", minScore)
}

// Test with 0 hosts
minScore = c.calculateMinScore(context.Background(), []scoredHost{}, 100)
minScore = c.calculateMinScore([]scoredHost{}, 100)
if minScore != math.SmallestNonzeroFloat64 {
t.Fatalf("expected minScore to be math.SmallestNonzeroFLoat64 but was %v", minScore)
}

// Test with 300 hosts which is 50 more than we have
minScore = c.calculateMinScore(context.Background(), candidates, 300)
minScore = c.calculateMinScore(candidates, 300)
if minScore != math.SmallestNonzeroFloat64 {
t.Fatalf("expected minScore to be math.SmallestNonzeroFLoat64 but was %v", minScore)
}
Expand Down
8 changes: 4 additions & 4 deletions autopilot/hostscore.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func hostScore(cfg api.AutopilotConfig, h hostdb.Host, storedData uint64, expect
Collateral: collateralScore(cfg, h.PriceTable.HostPriceTable, uint64(allocationPerHost)),
Interactions: interactionScore(h),
Prices: priceAdjustmentScore(hostPeriodCost, cfg),
StorageRemaining: storageRemainingScore(cfg, h.Settings, storedData, expectedRedundancy, allocationPerHost),
StorageRemaining: storageRemainingScore(h.Settings, storedData, allocationPerHost),
Uptime: uptimeScore(h),
Version: versionScore(h.Settings),
}
Expand Down Expand Up @@ -74,7 +74,7 @@ func priceAdjustmentScore(hostCostPerPeriod types.Currency, cfg api.AutopilotCon
panic("unreachable")
}

func storageRemainingScore(cfg api.AutopilotConfig, h rhpv2.HostSettings, storedData uint64, expectedRedundancy, allocationPerHost float64) float64 {
func storageRemainingScore(h rhpv2.HostSettings, storedData uint64, allocationPerHost float64) float64 {
// hostExpectedStorage is the amount of storage that we expect to be able to
// store on this host overall, which should include the stored data that is
// already on the host.
Expand Down Expand Up @@ -291,7 +291,7 @@ func uploadCostForScore(cfg api.AutopilotConfig, h hostdb.Host, bytes uint64) ty
return uploadSectorCostRHPv3.Mul64(numSectors)
}

func downloadCostForScore(cfg api.AutopilotConfig, h hostdb.Host, bytes uint64) types.Currency {
func downloadCostForScore(h hostdb.Host, bytes uint64) types.Currency {
rsc := h.PriceTable.BaseCost().Add(h.PriceTable.ReadSectorCost(rhpv2.SectorSize))
downloadSectorCostRHPv3, _ := rsc.Total()
numSectors := bytesToSectors(bytes)
Expand All @@ -314,7 +314,7 @@ func hostPeriodCostForScore(h hostdb.Host, cfg api.AutopilotConfig, expectedRedu
hostCollateral := rhpv2.ContractFormationCollateral(cfg.Contracts.Period, storagePerHost, h.Settings)
hostContractPrice := contractPriceForScore(h)
hostUploadCost := uploadCostForScore(cfg, h, uploadPerHost)
hostDownloadCost := downloadCostForScore(cfg, h, downloadPerHost)
hostDownloadCost := downloadCostForScore(h, downloadPerHost)
hostStorageCost := storageCostForScore(cfg, h, storagePerHost)
siafundFee := hostCollateral.
Add(hostContractPrice).
Expand Down
2 changes: 1 addition & 1 deletion autopilot/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (m *migrator) slabMigrationEstimate(remaining int) time.Duration {
return time.Duration(totalNumMS) * time.Millisecond
}

func (m *migrator) tryPerformMigrations(ctx context.Context, wp *workerPool) {
func (m *migrator) tryPerformMigrations(wp *workerPool) {
m.mu.Lock()
if m.migrating || m.ap.isStopped() {
m.mu.Unlock()
Expand Down
8 changes: 4 additions & 4 deletions autopilot/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ func (s *scanner) isInterrupted() bool {
}
}

func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bool) bool {
func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bool) {
if s.ap.isStopped() {
return false
return
}

scanType := "host scan"
Expand All @@ -185,7 +185,7 @@ func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bo
s.interruptScanChan = make(chan struct{})
} else if s.scanning || !s.isScanRequired() {
s.mu.Unlock()
return false
return
}
s.scanningLastStart = time.Now()
s.scanning = true
Expand Down Expand Up @@ -229,7 +229,7 @@ func (s *scanner) tryPerformHostScan(ctx context.Context, w scanWorker, force bo
s.logger.Debugf("%s finished after %v", st, time.Since(s.scanningLastStart))
s.mu.Unlock()
}(scanType)
return true
return
}

func (s *scanner) tryUpdateTimeout() {
Expand Down
4 changes: 2 additions & 2 deletions autopilot/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestScanner(t *testing.T) {
// init new scanner
b := &mockBus{hosts: hosts}
w := &mockWorker{blockChan: make(chan struct{})}
s := newTestScanner(b, w)
s := newTestScanner(b)

// assert it started a host scan
s.tryPerformHostScan(context.Background(), w, false)
Expand Down Expand Up @@ -139,7 +139,7 @@ func (s *scanner) isScanning() bool {
return s.scanning
}

func newTestScanner(b *mockBus, w *mockWorker) *scanner {
func newTestScanner(b *mockBus) *scanner {
ap := &Autopilot{}
ap.shutdownCtx, ap.shutdownCtxCancel = context.WithCancel(context.Background())
return &scanner{
Expand Down
8 changes: 4 additions & 4 deletions bus/bus.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ type (
EphemeralAccountStore interface {
Accounts(context.Context) ([]api.Account, error)
SaveAccounts(context.Context, []api.Account) error
SetUncleanShutdown() error
SetUncleanShutdown(context.Context) error
}

MetricsStore interface {
Expand Down Expand Up @@ -2022,7 +2022,7 @@ func (b *bus) webhookHandlerDelete(jc jape.Context) {
if jc.Decode(&wh) != nil {
return
}
err := b.hooks.Delete(wh)
err := b.hooks.Delete(jc.Request.Context(), wh)
if errors.Is(err, webhooks.ErrWebhookNotFound) {
jc.Error(fmt.Errorf("webhook for URL %v and event %v.%v not found", wh.URL, wh.Module, wh.Event), http.StatusNotFound)
return
Expand All @@ -2044,7 +2044,7 @@ func (b *bus) webhookHandlerPost(jc jape.Context) {
if jc.Decode(&req) != nil {
return
}
err := b.hooks.Register(webhooks.Webhook{
err := b.hooks.Register(jc.Request.Context(), webhooks.Webhook{
Event: req.Event,
Module: req.Module,
URL: req.URL,
Expand Down Expand Up @@ -2412,7 +2412,7 @@ func New(s Syncer, am *alerts.Manager, hm *webhooks.Manager, cm ChainManager, tp

// mark the shutdown as unclean, this will be overwritten when/if the
// accounts are saved on shutdown
if err := eas.SetUncleanShutdown(); err != nil {
if err := eas.SetUncleanShutdown(ctx); err != nil {
return nil, fmt.Errorf("failed to mark account shutdown as unclean: %w", err)
}
return b, nil
Expand Down
2 changes: 1 addition & 1 deletion bus/client/objects.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (c *Client) SearchObjects(ctx context.Context, bucket string, opts api.Sear
}

func (c *Client) renameObjects(ctx context.Context, bucket, from, to, mode string, force bool) (err error) {
err = c.c.POST("/objects/rename", api.ObjectsRenameRequest{
err = c.c.WithContext(ctx).POST("/objects/rename", api.ObjectsRenameRequest{
Bucket: bucket,
Force: force,
From: from,
Expand Down
Loading

0 comments on commit 7efeb63

Please sign in to comment.