Skip to content

Commit

Permalink
stores: update SearchHosts
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjan committed Mar 25, 2024
1 parent 6983036 commit 419b21e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 50 deletions.
2 changes: 1 addition & 1 deletion autopilot/contractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ func (c *contractor) runContractChecks(ctx context.Context, contracts []api.Cont
}

// fetch host checks
check, ok := host.Checks[c.ap.id]
check, ok := hostChecks[hk]
if !ok {
// this is only possible due to developer error, if there is no
// check the host would have been missing, so we treat it the same
Expand Down
51 changes: 16 additions & 35 deletions stores/hostdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,21 +487,14 @@ func (e *dbBlocklistEntry) blocks(h dbHost) bool {

// Host returns information about a host.
func (ss *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) {
var h dbHost

tx := ss.db.
WithContext(ctx).
Where(&dbHost{PublicKey: publicKey(hostKey)}).
Preload("Allowlist").
Preload("Blocklist").
Take(&h)
if errors.Is(tx.Error, gorm.ErrRecordNotFound) {
hosts, err := ss.SearchHosts(ctx, "", api.HostFilterModeAll, api.UsabilityFilterModeAll, "", []types.PublicKey{hostKey}, 0, 1)
if err != nil {
return api.Host{}, err
} else if len(hosts) == 0 {
return api.Host{}, api.ErrHostNotFound
} else if tx.Error != nil {
return api.Host{}, tx.Error
} else {
return hosts[0], nil
}

return h.convert(ss.isBlocked(h)), nil
}

func (ss *SQLStore) UpdateHostCheck(ctx context.Context, autopilotID string, hk types.PublicKey, hc api.HostCheck) (err error) {
Expand Down Expand Up @@ -610,10 +603,6 @@ func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, us
return nil, ErrNegativeOffset
}

// TODO PJ: use
_ = autopilotID
_ = usabilityMode

// validate filterMode
switch filterMode {
case api.HostFilterModeAllowed:
Expand All @@ -631,7 +620,7 @@ func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, us
hostFilter(filterMode, ss.hasAllowlist(), ss.hasBlocklist()),
hostNetAddress(addressContains),
hostPublicKey(keyIn),
usabilityFilter(usabilityMode),
usabilityFilter(autopilotID, usabilityMode),
)

// preload allowlist and blocklist
Expand All @@ -641,23 +630,9 @@ func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, us
Preload("Blocklist")
}

// filter checks
if autopilotID != "" {
query = query.Preload("Checks.DBAutopilot", "identifier = ?", autopilotID)
} else {
query = query.Preload("Checks.DBAutopilot")
}
// query = query.
// Preload("Checks.DBAutopilot").
// Scopes(
// autopilotFilter(autopilotID),
// usabilityFilter(usabilityMode),
// )

var hosts []api.Host
var fullHosts []dbHost
err := query.
Debug().
Offset(offset).
Limit(limit).
FindInBatches(&fullHosts, hostRetrievalBatchSize, func(tx *gorm.DB, batch int) error {
Expand Down Expand Up @@ -1144,13 +1119,19 @@ func hostFilter(filterMode string, hasAllowlist, hasBlocklist bool) func(*gorm.D
}
}

func usabilityFilter(usabilityMode string) func(*gorm.DB) *gorm.DB {
func usabilityFilter(autopilotID, usabilityMode string) func(*gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
switch usabilityMode {
case api.UsabilityFilterModeUsable:
db = db.Preload("Checks", "usability_blocked = ? AND usability_offline = ? AND usability_low_score = ? AND usability_redundant_ip = ? AND usability_gouging = ? AND usability_not_accepting_contracts = ? AND usability_not_announced = ? AND usability_not_completing_scan = ?", false, false, false, false, false, false, false, false)
db = db.
Joins("INNER JOIN host_checks hc on hc.db_host_id = hosts.id").
Joins("INNER JOIN autopilots a on a.id = hc.db_autopilot_id AND a.identifier = ?", autopilotID).
Where("hc.usability_blocked = ? AND hc.usability_offline = ? AND hc.usability_low_score = ? AND hc.usability_redundant_ip = ? AND hc.usability_gouging = ? AND hc.usability_not_accepting_contracts = ? AND hc.usability_not_announced = ? AND hc.usability_not_completing_scan = ?", false, false, false, false, false, false, false, false)
case api.UsabilityFilterModeUnusable:
db = db.Preload("Checks", "usability_blocked = ? OR usability_offline = ? OR usability_low_score = ? OR usability_redundant_ip = ? OR usability_gouging = ? OR usability_not_accepting_contracts = ? OR usability_not_announced = ? OR usability_not_completing_scan = ?", true, true, true, true, true, true, true, true)
db = db.
Joins("INNER JOIN host_checks hc on hc.db_host_id = hosts.id").
Joins("INNER JOIN autopilots a on a.id = hc.db_autopilot_id AND a.identifier = ?", autopilotID).
Where("hc.usability_blocked = ? OR hc.usability_offline = ? OR hc.usability_low_score = ? OR hc.usability_redundant_ip = ? OR hc.usability_gouging = ? OR hc.usability_not_accepting_contracts = ? OR hc.usability_not_announced = ? OR hc.usability_not_completing_scan = ?", true, true, true, true, true, true, true, true)
case api.UsabilityFilterModeAll:
// do nothing
}
Expand Down
24 changes: 10 additions & 14 deletions stores/hostdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,32 +408,28 @@ func TestSearchHosts(t *testing.T) {
his, err = ss.SearchHosts(context.Background(), ap1, api.HostFilterModeAll, api.UsabilityFilterModeUsable, "", nil, 0, -1)
if err != nil {
t.Fatal(err)
} else if cnt != 3 {
t.Fatal("unexpected", cnt)
} else if len(his) != 1 {
t.Fatal("unexpected", len(his))
}

// assert h1 and h2 have the expected checks
// assert h1 has the expected checks
if c1, ok := his[0].Checks[ap1]; !ok || c1 != h1c {
t.Fatal("unexpected", c1, ok)
} else if _, ok := his[1].Checks[ap1]; ok {
t.Fatal("unexpected", ok)
} else if _, ok := his[1].Checks[ap2]; ok {
t.Fatal("unexpected")
}

his, err = ss.SearchHosts(context.Background(), ap1, api.HostFilterModeAll, api.UsabilityFilterModeUnusable, "", nil, 0, -1)
if err != nil {
t.Fatal(err)
} else if cnt != 3 {
t.Fatal("unexpected", cnt)
} else if len(his) != 1 {
t.Fatal("unexpected", len(his))
} else if his[0].Host.PublicKey != hk2 {
t.Fatal("unexpected")
}

// assert h1 and h2 have the expected checks
if _, ok := his[0].Checks[ap1]; ok {
// assert only ap1 check is there
if _, ok := his[0].Checks[ap1]; !ok {
t.Fatal("unexpected")
} else if c2, ok := his[1].Checks[ap1]; !ok || c2 != h2c1 {
t.Fatal("unexpected", ok)
} else if _, ok := his[1].Checks[ap2]; ok {
} else if _, ok := his[0].Checks[ap2]; ok {
t.Fatal("unexpected")
}

Expand Down

0 comments on commit 419b21e

Please sign in to comment.