Skip to content

Commit

Permalink
Fix pricetable from host scan overwriting valid price table (#1347)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisSchinnerl authored Jun 27, 2024
1 parent e12dc6f commit 74ed4bf
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 16 deletions.
46 changes: 33 additions & 13 deletions stores/hostdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/google/go-cmp/cmp"
"gitlab.com/NebulousLabs/encoding"
rhpv2 "go.sia.tech/core/rhp/v2"
rhpv3 "go.sia.tech/core/rhp/v3"
"go.sia.tech/core/types"
"go.sia.tech/renterd/api"
"go.sia.tech/renterd/hostdb"
Expand Down Expand Up @@ -496,10 +497,23 @@ func TestRecordScan(t *testing.T) {
// Record a scan.
firstScanTime := time.Now().UTC()
settings := rhpv2.HostSettings{NetAddress: "host.com"}
if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, firstScanTime, settings, true)}); err != nil {
pt := rhpv3.HostPriceTable{
HostBlockHeight: 123,
}
if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, firstScanTime, settings, pt, true)}); err != nil {
t.Fatal(err)
}
host, err = ss.Host(ctx, hk)
if err != nil {
t.Fatal(err)
} else if time.Now().Before(host.PriceTable.Expiry) {
t.Fatal("invalid expiry")
} else if host.PriceTable.HostBlockHeight != pt.HostBlockHeight {
t.Fatalf("mismatch %v %v", host.PriceTable.HostBlockHeight, pt.HostBlockHeight)
}

// Update the price table expiry to be in the future.
_, err = ss.DB().Exec(ctx, "UPDATE hosts SET price_table_expiry = ? WHERE public_key = ?", time.Now().Add(time.Hour), sql.PublicKey(hk))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -529,15 +543,19 @@ func TestRecordScan(t *testing.T) {

// Record another scan 1 hour after the previous one.
secondScanTime := firstScanTime.Add(time.Hour)
if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, secondScanTime, settings, true)}); err != nil {
pt.HostBlockHeight = 456
if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, secondScanTime, settings, pt, true)}); err != nil {
t.Fatal(err)
}
host, err = ss.Host(ctx, hk)
if err != nil {
t.Fatal(err)
}
if host.Interactions.LastScan.UnixNano() != secondScanTime.UnixNano() {
} else if host.Interactions.LastScan.UnixNano() != secondScanTime.UnixNano() {
t.Fatal("wrong time")
} else if time.Now().After(host.PriceTable.Expiry) {
t.Fatal("invalid expiry")
} else if host.PriceTable.HostBlockHeight != 123 {
t.Fatal("price table was updated")
}
host.Interactions.LastScan = time.Time{}
uptime += secondScanTime.Sub(firstScanTime)
Expand All @@ -556,7 +574,7 @@ func TestRecordScan(t *testing.T) {

// Record another scan 2 hours after the second one. This time it fails.
thirdScanTime := secondScanTime.Add(2 * time.Hour)
if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, thirdScanTime, settings, false)}); err != nil {
if err := ss.RecordHostScans(ctx, []api.HostScan{newTestScan(hk, thirdScanTime, settings, pt, false)}); err != nil {
t.Fatal(err)
}
host, err = ss.Host(ctx, hk)
Expand Down Expand Up @@ -612,10 +630,11 @@ func TestRemoveHosts(t *testing.T) {
}

now := time.Now().UTC()
pt := rhpv3.HostPriceTable{}
t1 := now.Add(-time.Minute * 120) // 2 hours ago
t2 := now.Add(-time.Minute * 90) // 1.5 hours ago (30min downtime)
hi1 := newTestScan(hk, t1, rhpv2.HostSettings{NetAddress: "host.com"}, false)
hi2 := newTestScan(hk, t2, rhpv2.HostSettings{NetAddress: "host.com"}, false)
hi1 := newTestScan(hk, t1, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false)
hi2 := newTestScan(hk, t2, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false)

// record interactions
if err := ss.RecordHostScans(context.Background(), []api.HostScan{hi1, hi2}); err != nil {
Expand Down Expand Up @@ -645,7 +664,7 @@ func TestRemoveHosts(t *testing.T) {

// record interactions
t3 := now.Add(-time.Minute * 60) // 1 hour ago (60min downtime)
hi3 := newTestScan(hk, t3, rhpv2.HostSettings{NetAddress: "host.com"}, false)
hi3 := newTestScan(hk, t3, rhpv2.HostSettings{NetAddress: "host.com"}, pt, false)
if err := ss.RecordHostScans(context.Background(), []api.HostScan{hi3}); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1300,12 +1319,13 @@ func hostByPubKey(tx *gorm.DB, hostKey types.PublicKey) (dbHost, error) {
}

// newTestScan returns a host interaction with given parameters.
func newTestScan(hk types.PublicKey, scanTime time.Time, settings rhpv2.HostSettings, success bool) api.HostScan {
func newTestScan(hk types.PublicKey, scanTime time.Time, settings rhpv2.HostSettings, pt rhpv3.HostPriceTable, success bool) api.HostScan {
return api.HostScan{
HostKey: hk,
Success: success,
Timestamp: scanTime,
Settings: settings,
HostKey: hk,
Success: success,
Timestamp: scanTime,
Settings: settings,
PriceTable: pt,
}
}

Expand Down
6 changes: 3 additions & 3 deletions stores/sql/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -1183,8 +1183,8 @@ func RecordHostScans(ctx context.Context, tx sql.Tx, scans []api.HostScan) error
uptime = CASE WHEN ? AND last_scan > 0 AND last_scan < ? THEN uptime + ? - last_scan ELSE uptime END,
last_scan = ?,
settings = CASE WHEN ? THEN ? ELSE settings END,
price_table = CASE WHEN ? THEN ? ELSE price_table END,
price_table_expiry = CASE WHEN ? AND price_table_expiry IS NOT NULL AND ? > price_table_expiry THEN ? ELSE price_table_expiry END,
price_table = CASE WHEN ? AND (price_table_expiry IS NULL OR ? > price_table_expiry) THEN ? ELSE price_table END,
price_table_expiry = CASE WHEN ? AND (price_table_expiry IS NULL OR ? > price_table_expiry) THEN ? ELSE price_table_expiry END,
successful_interactions = CASE WHEN ? THEN successful_interactions + 1 ELSE successful_interactions END,
failed_interactions = CASE WHEN ? THEN failed_interactions + 1 ELSE failed_interactions END
WHERE public_key = ?
Expand All @@ -1206,7 +1206,7 @@ func RecordHostScans(ctx context.Context, tx sql.Tx, scans []api.HostScan) error
scan.Success, scanTime, scanTime, // uptime
scanTime, // last_scan
scan.Success, HostSettings(scan.Settings), // settings
scan.Success, PriceTable(scan.PriceTable), // price_table
scan.Success, now, PriceTable(scan.PriceTable), // price_table
scan.Success, now, now, // price_table_expiry
scan.Success, // successful_interactions
!scan.Success, // failed_interactions
Expand Down
12 changes: 12 additions & 0 deletions stores/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,18 @@ func newTestSQLStore(t *testing.T, cfg testSQLStoreConfig) *testSQLStore {
}
}

func (s *testSQLStore) DB() *isql.DB {
switch db := s.bMain.(type) {
case *sqlite.MainDatabase:
return db.DB()
case *mysql.MainDatabase:
return db.DB()
default:
s.t.Fatal("unknown db type", db)
}
panic("unreachable")
}

func (s *testSQLStore) DBMetrics() *isql.DB {
switch db := s.bMetrics.(type) {
case *sqlite.MetricsDatabase:
Expand Down

0 comments on commit 74ed4bf

Please sign in to comment.