Skip to content

Commit

Permalink
internal: move pricetable implementations to internal
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisSchinnerl committed Dec 2, 2024
1 parent 03266c9 commit bc16c75
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 63 deletions.
12 changes: 6 additions & 6 deletions worker/prices.go → internal/prices/prices.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package worker
package prices

import (
"context"
Expand All @@ -24,7 +24,7 @@ type (
PublicKey() types.PublicKey
}

pricesCache struct {
PricesCache struct {
mu sync.Mutex
cache map[types.PublicKey]*cachedPrices
}
Expand All @@ -45,14 +45,14 @@ type (
}
)

func newPricesCache() *pricesCache {
return &pricesCache{
func NewPricesCache() *PricesCache {
return &PricesCache{
cache: make(map[types.PublicKey]*cachedPrices),
}
}

// fetch returns a price table for the given host
func (c *pricesCache) fetch(ctx context.Context, h PricesFetcher) (rhpv4.HostPrices, error) {
// Fetch returns a price table for the given host
func (c *PricesCache) Fetch(ctx context.Context, h PricesFetcher) (rhpv4.HostPrices, error) {
c.mu.Lock()
prices, exists := c.cache[h.PublicKey()]
if !exists {
Expand Down
68 changes: 49 additions & 19 deletions worker/prices_test.go → internal/prices/prices_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package worker
package prices

import (
"context"
Expand All @@ -10,13 +10,41 @@ import (
"go.sia.tech/core/types"
"go.sia.tech/renterd/api"
"go.sia.tech/renterd/internal/test/mocks"
"lukechampine.com/frand"
)

type pricesFetcher struct {
hk types.PublicKey
hptFn func() api.HostPriceTable
pFn func() rhpv4.HostPrices
}

func (pf *pricesFetcher) Prices(ctx context.Context) (rhpv4.HostPrices, error) {
return pf.pFn(), nil
}

func (pf *pricesFetcher) PriceTable(ctx context.Context, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) {
return pf.hptFn(), types.ZeroCurrency, nil
}

func (pf *pricesFetcher) PublicKey() types.PublicKey {
return pf.hk
}

func newTestHostPrices() rhpv4.HostPrices {
var sig types.Signature
frand.Read(sig[:])

return rhpv4.HostPrices{
TipHeight: 100,
ValidUntil: time.Now().Add(time.Minute),
Signature: sig,
}
}

func TestPricesCache(t *testing.T) {
cache := newPricesCache()
hk := types.PublicKey{1}
hostMock := mocks.NewHost(hk)
c := mocks.NewContract(hk, types.FileContractID{1})
cache := NewPricesCache()
hostMock := mocks.NewHost(types.PublicKey{1})

// expire its prices
expiredPT := newTestHostPriceTable()
Expand All @@ -26,31 +54,34 @@ func TestPricesCache(t *testing.T) {
// manage the host, make sure fetching the prices blocks
fetchPTBlockChan := make(chan struct{})
validPrices := newTestHostPrices()
h := newTestHostCustom(hostMock, c, func() api.HostPriceTable {
t.Fatal("shouldn't be called")
return api.HostPriceTable{}
}, func() rhpv4.HostPrices {
<-fetchPTBlockChan
return validPrices
})

h := &pricesFetcher{
hk: types.PublicKey{1},
hptFn: func() api.HostPriceTable {
t.Fatal("shouldn't be called")
return api.HostPriceTable{}
},
pFn: func() rhpv4.HostPrices {
<-fetchPTBlockChan
return validPrices
},
}
// trigger a fetch to make it block
go cache.fetch(context.Background(), h)
go cache.Fetch(context.Background(), h)
time.Sleep(50 * time.Millisecond)

// fetch it again but with a canceled context to avoid blocking
// indefinitely, the error will indicate we were blocking on a prices
// update
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := cache.fetch(ctx, h)
_, err := cache.Fetch(ctx, h)
if !errors.Is(err, errPriceTableUpdateTimedOut) {
t.Fatal("expected errPriceTableUpdateTimedOut, got", err)
}

// unblock and assert we paid for the prices
close(fetchPTBlockChan)
update, err := cache.fetch(context.Background(), h)
update, err := cache.Fetch(context.Background(), h)
if err != nil {
t.Fatal(err)
} else if update.Signature != validPrices.Signature {
Expand All @@ -61,8 +92,7 @@ func TestPricesCache(t *testing.T) {
// same prices as it hasn't expired yet
oldValidPrices := validPrices
validPrices = newTestHostPrices()
h.UpdatePrices(validPrices)
update, err = cache.fetch(context.Background(), h)
update, err = cache.Fetch(context.Background(), h)
if err != nil {
t.Fatal(err)
} else if update.Signature != oldValidPrices.Signature {
Expand All @@ -71,7 +101,7 @@ func TestPricesCache(t *testing.T) {

// manually expire the prices
cache.cache[h.PublicKey()].renewTime = time.Now().Add(-time.Second)
update, err = cache.fetch(context.Background(), h)
update, err = cache.Fetch(context.Background(), h)
if err != nil {
t.Fatal(err)
} else if update.Signature != validPrices.Signature {
Expand Down
17 changes: 6 additions & 11 deletions worker/pricetables.go → internal/prices/pricetables.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package worker
package prices

import (
"context"
Expand All @@ -16,11 +16,6 @@ const (
// for use, we essentially add 30 seconds to the current time when checking
// whether we are still before a pricetable's expiry time
priceTableValidityLeeway = 30 * time.Second

// priceTableBlockHeightLeeway is the amount of blocks before a price table
// is considered gouging on the block height when we renew it even if it is
// still valid
priceTableBlockHeightLeeway = 2
)

var (
Expand All @@ -35,7 +30,7 @@ type (
PublicKey() types.PublicKey
}

priceTables struct {
PriceTables struct {
mu sync.Mutex
priceTables map[types.PublicKey]*priceTable
}
Expand All @@ -56,14 +51,14 @@ type (
}
)

func newPriceTables() *priceTables {
return &priceTables{
func NewPriceTables() *PriceTables {
return &PriceTables{
priceTables: make(map[types.PublicKey]*priceTable),
}
}

// fetch returns a price table for the given host
func (pts *priceTables) fetch(ctx context.Context, h priceTableFetcher, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) {
// Fetch returns a price table for the given host
func (pts *PriceTables) Fetch(ctx context.Context, h priceTableFetcher, rev *types.FileContractRevision) (api.HostPriceTable, types.Currency, error) {
pts.mu.Lock()
pt, exists := pts.priceTables[h.PublicKey()]
if !exists {
Expand Down
46 changes: 30 additions & 16 deletions worker/pricetables_test.go → internal/prices/pricetables_test.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
package worker
package prices

import (
"context"
"errors"
"testing"
"time"

rhpv3 "go.sia.tech/core/rhp/v3"
rhpv4 "go.sia.tech/core/rhp/v4"
"go.sia.tech/core/types"
"go.sia.tech/renterd/api"
"go.sia.tech/renterd/internal/test/mocks"
"lukechampine.com/frand"
)

func newTestHostPriceTable() api.HostPriceTable {
var uid rhpv3.SettingsID
frand.Read(uid[:])

return api.HostPriceTable{
HostPriceTable: rhpv3.HostPriceTable{UID: uid, HostBlockHeight: 100, Validity: time.Minute},
Expiry: time.Now().Add(time.Minute),
}
}

func TestPriceTables(t *testing.T) {
// create host manager & price table
pts := newPriceTables()
pts := NewPriceTables()

// create host & contract mock
hostMock := mocks.NewHost(types.PublicKey{1})
c := mocks.NewContract(hostMock.PublicKey(), types.FileContractID{1})

// expire its price table
expiredPT := newTestHostPriceTable()
Expand All @@ -28,31 +39,35 @@ func TestPriceTables(t *testing.T) {
// manage the host, make sure fetching the price table blocks
fetchPTBlockChan := make(chan struct{})
validPT := newTestHostPriceTable()
h := newTestHostCustom(hostMock, c, func() api.HostPriceTable {
<-fetchPTBlockChan
return validPT
}, func() rhpv4.HostPrices {
t.Fatal("shouldn't be called")
return rhpv4.HostPrices{}
})
h := &pricesFetcher{
hk: types.PublicKey{1},
hptFn: func() api.HostPriceTable {
<-fetchPTBlockChan
return validPT
},
pFn: func() rhpv4.HostPrices {
t.Fatal("shouldn't be called")
return rhpv4.HostPrices{}
},
}

// trigger a fetch to make it block
go pts.fetch(context.Background(), h, nil)
go pts.Fetch(context.Background(), h, nil)
time.Sleep(50 * time.Millisecond)

// fetch it again but with a canceled context to avoid blocking
// indefinitely, the error will indicate we were blocking on a price table
// update
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, _, err := pts.fetch(ctx, h, nil)
_, _, err := pts.Fetch(ctx, h, nil)
if !errors.Is(err, errPriceTableUpdateTimedOut) {
t.Fatal("expected errPriceTableUpdateTimedOut, got", err)
}

// unblock and assert we paid for the price table
close(fetchPTBlockChan)
update, _, err := pts.fetch(context.Background(), h, nil)
update, _, err := pts.Fetch(context.Background(), h, nil)
if err != nil {
t.Fatal(err)
} else if update.UID != validPT.UID {
Expand All @@ -63,8 +78,7 @@ func TestPriceTables(t *testing.T) {
// same price table as it hasn't expired yet
oldValidPT := validPT
validPT = newTestHostPriceTable()
h.UpdatePriceTable(validPT)
update, _, err = pts.fetch(context.Background(), h, nil)
update, _, err = pts.Fetch(context.Background(), h, nil)
if err != nil {
t.Fatal(err)
} else if update.UID != oldValidPT.UID {
Expand All @@ -73,7 +87,7 @@ func TestPriceTables(t *testing.T) {

// manually expire the price table
pts.priceTables[h.PublicKey()].renewTime = time.Now().Add(-time.Second)
update, _, err = pts.fetch(context.Background(), h, nil)
update, _, err = pts.Fetch(context.Background(), h, nil)
if err != nil {
t.Fatal(err)
} else if update.UID != validPT.UID {
Expand Down
15 changes: 8 additions & 7 deletions worker/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"go.sia.tech/renterd/api"
"go.sia.tech/renterd/internal/gouging"
"go.sia.tech/renterd/internal/host"
"go.sia.tech/renterd/internal/prices"
rhp3 "go.sia.tech/renterd/internal/rhp/v3"
"go.sia.tech/renterd/internal/worker"
"go.uber.org/zap"
Expand All @@ -28,15 +29,15 @@ type (
client *rhp3.Client
contractSpendingRecorder ContractSpendingRecorder
logger *zap.SugaredLogger
priceTables *priceTables
priceTables *prices.PriceTables
}

hostDownloadClient struct {
hk types.PublicKey
siamuxAddr string

acc *worker.Account
pts *priceTables
pts *prices.PriceTables
rhp3 *rhp3.Client
}
)
Expand Down Expand Up @@ -74,7 +75,7 @@ func (w *Worker) Downloader(hk types.PublicKey, siamuxAddr string) host.Download
func (h *hostClient) DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) (err error) {
var amount types.Currency
return h.acc.WithWithdrawal(func() (types.Currency, error) {
pt, uptc, err := h.priceTables.fetch(ctx, h, nil)
pt, uptc, err := h.priceTables.Fetch(ctx, h, nil)
if err != nil {
return types.ZeroCurrency, err
}
Expand Down Expand Up @@ -175,7 +176,7 @@ func (h *hostClient) FundAccount(ctx context.Context, desired types.Currency, re
deposit := desired.Sub(balance)

// fetch pricetable directly to bypass the gouging check
pt, _, err := h.priceTables.fetch(ctx, h, rev)
pt, _, err := h.priceTables.Fetch(ctx, h, rev)
if err != nil {
return types.ZeroCurrency, err
}
Expand Down Expand Up @@ -209,7 +210,7 @@ func (h *hostClient) FundAccount(ctx context.Context, desired types.Currency, re

func (h *hostClient) SyncAccount(ctx context.Context, rev *types.FileContractRevision) error {
// fetch pricetable directly to bypass the gouging check
pt, _, err := h.priceTables.fetch(ctx, h, rev)
pt, _, err := h.priceTables.Fetch(ctx, h, rev)
if err != nil {
return err
}
Expand All @@ -228,7 +229,7 @@ func (h *hostClient) SyncAccount(ctx context.Context, rev *types.FileContractRev
// will be used to pay for the price table. The returned price table is
// guaranteed to be safe to use.
func (h *hostClient) priceTable(ctx context.Context, rev *types.FileContractRevision) (rhpv3.HostPriceTable, types.Currency, error) {
pt, cost, err := h.priceTables.fetch(ctx, h, rev)
pt, cost, err := h.priceTables.Fetch(ctx, h, rev)
if err != nil {
return rhpv3.HostPriceTable{}, types.ZeroCurrency, err
}
Expand All @@ -244,7 +245,7 @@ func (h *hostClient) priceTable(ctx context.Context, rev *types.FileContractRevi

func (d *hostDownloadClient) DownloadSector(ctx context.Context, w io.Writer, root types.Hash256, offset, length uint32, overpay bool) (err error) {
return d.acc.WithWithdrawal(func() (types.Currency, error) {
pt, ptc, err := d.pts.fetch(ctx, d, nil)
pt, ptc, err := d.pts.Fetch(ctx, d, nil)
if err != nil {
return types.ZeroCurrency, err
}
Expand Down
Loading

0 comments on commit bc16c75

Please sign in to comment.