From 9a1ec85e51a4d99e31eb2e454da967b2fd9bdd4a Mon Sep 17 00:00:00 2001 From: PJ Date: Mon, 25 Mar 2024 21:23:47 +0100 Subject: [PATCH] stores: update SearchHosts --- autopilot/contractor.go | 2 +- stores/hostdb.go | 51 +++++++++++++---------------------------- stores/hostdb_test.go | 31 +++++++++++++------------ 3 files changed, 33 insertions(+), 51 deletions(-) diff --git a/autopilot/contractor.go b/autopilot/contractor.go index 0e85e43022..2be7fe9e9f 100644 --- a/autopilot/contractor.go +++ b/autopilot/contractor.go @@ -720,7 +720,7 @@ func (c *contractor) runContractChecks(ctx context.Context, contracts []api.Cont } // fetch host checks - check, ok := host.Checks[c.ap.id] + check, ok := hostChecks[hk] if !ok { // this is only possible due to developer error, if there is no // check the host would have been missing, so we treat it the same diff --git a/stores/hostdb.go b/stores/hostdb.go index 822e97f95a..ee02776246 100644 --- a/stores/hostdb.go +++ b/stores/hostdb.go @@ -487,21 +487,14 @@ func (e *dbBlocklistEntry) blocks(h dbHost) bool { // Host returns information about a host. func (ss *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (api.Host, error) { - var h dbHost - - tx := ss.db. - WithContext(ctx). - Where(&dbHost{PublicKey: publicKey(hostKey)}). - Preload("Allowlist"). - Preload("Blocklist"). - Take(&h) - if errors.Is(tx.Error, gorm.ErrRecordNotFound) { + hosts, err := ss.SearchHosts(ctx, "", api.HostFilterModeAll, api.UsabilityFilterModeAll, "", []types.PublicKey{hostKey}, 0, 1) + if err != nil { + return api.Host{}, err + } else if len(hosts) == 0 { return api.Host{}, api.ErrHostNotFound - } else if tx.Error != nil { - return api.Host{}, tx.Error + } else { + return hosts[0], nil } - - return h.convert(ss.isBlocked(h)), nil } func (ss *SQLStore) UpdateHostCheck(ctx context.Context, autopilotID string, hk types.PublicKey, hc api.HostCheck) (err error) { @@ -610,10 +603,6 @@ func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, us return nil, ErrNegativeOffset } - // TODO PJ: use - _ = autopilotID - _ = usabilityMode - // validate filterMode switch filterMode { case api.HostFilterModeAllowed: @@ -631,7 +620,7 @@ func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, us hostFilter(filterMode, ss.hasAllowlist(), ss.hasBlocklist()), hostNetAddress(addressContains), hostPublicKey(keyIn), - usabilityFilter(usabilityMode), + usabilityFilter(autopilotID, usabilityMode), ) // preload allowlist and blocklist @@ -641,23 +630,9 @@ func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, us Preload("Blocklist") } - // filter checks - if autopilotID != "" { - query = query.Preload("Checks.DBAutopilot", "identifier = ?", autopilotID) - } else { - query = query.Preload("Checks.DBAutopilot") - } - // query = query. - // Preload("Checks.DBAutopilot"). - // Scopes( - // autopilotFilter(autopilotID), - // usabilityFilter(usabilityMode), - // ) - var hosts []api.Host var fullHosts []dbHost err := query. - Debug(). Offset(offset). Limit(limit). FindInBatches(&fullHosts, hostRetrievalBatchSize, func(tx *gorm.DB, batch int) error { @@ -1144,13 +1119,19 @@ func hostFilter(filterMode string, hasAllowlist, hasBlocklist bool) func(*gorm.D } } -func usabilityFilter(usabilityMode string) func(*gorm.DB) *gorm.DB { +func usabilityFilter(autopilotID, usabilityMode string) func(*gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { switch usabilityMode { case api.UsabilityFilterModeUsable: - db = db.Preload("Checks", "usability_blocked = ? AND usability_offline = ? AND usability_low_score = ? AND usability_redundant_ip = ? AND usability_gouging = ? AND usability_not_accepting_contracts = ? AND usability_not_announced = ? AND usability_not_completing_scan = ?", false, false, false, false, false, false, false, false) + db = db. + Joins("INNER JOIN host_checks hc on hc.db_host_id = hosts.id"). + Joins("INNER JOIN autopilots a on a.id = hc.db_autopilot_id AND a.identifier = ?", autopilotID). + Where("hc.usability_blocked = ? AND hc.usability_offline = ? AND hc.usability_low_score = ? AND hc.usability_redundant_ip = ? AND hc.usability_gouging = ? AND hc.usability_not_accepting_contracts = ? AND hc.usability_not_announced = ? AND hc.usability_not_completing_scan = ?", false, false, false, false, false, false, false, false) case api.UsabilityFilterModeUnusable: - db = db.Preload("Checks", "usability_blocked = ? OR usability_offline = ? OR usability_low_score = ? OR usability_redundant_ip = ? OR usability_gouging = ? OR usability_not_accepting_contracts = ? OR usability_not_announced = ? OR usability_not_completing_scan = ?", true, true, true, true, true, true, true, true) + db = db. + Joins("INNER JOIN host_checks hc on hc.db_host_id = hosts.id"). + Joins("INNER JOIN autopilots a on a.id = hc.db_autopilot_id AND a.identifier = ?", autopilotID). + Where("hc.usability_blocked = ? OR hc.usability_offline = ? OR hc.usability_low_score = ? OR hc.usability_redundant_ip = ? OR hc.usability_gouging = ? OR hc.usability_not_accepting_contracts = ? OR hc.usability_not_announced = ? OR hc.usability_not_completing_scan = ?", true, true, true, true, true, true, true, true) case api.UsabilityFilterModeAll: // do nothing } diff --git a/stores/hostdb_test.go b/stores/hostdb_test.go index 196170b01e..a6ed0e998a 100644 --- a/stores/hostdb_test.go +++ b/stores/hostdb_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "os" "reflect" "testing" "time" @@ -242,7 +243,11 @@ func TestSQLHosts(t *testing.T) { // TestSearchHosts is a unit test for SearchHosts. func TestSearchHosts(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + cfg := defaultTestSQLStoreConfig + cfg.persistent = true + cfg.dir = "/Users/peterjan/testing2" + os.RemoveAll(cfg.dir) + ss := newTestSQLStore(t, cfg) defer ss.Close() ctx := context.Background() @@ -408,32 +413,28 @@ func TestSearchHosts(t *testing.T) { his, err = ss.SearchHosts(context.Background(), ap1, api.HostFilterModeAll, api.UsabilityFilterModeUsable, "", nil, 0, -1) if err != nil { t.Fatal(err) - } else if cnt != 3 { - t.Fatal("unexpected", cnt) + } else if len(his) != 1 { + t.Fatal("unexpected", len(his)) } - // assert h1 and h2 have the expected checks + // assert h1 has the expected checks if c1, ok := his[0].Checks[ap1]; !ok || c1 != h1c { t.Fatal("unexpected", c1, ok) - } else if _, ok := his[1].Checks[ap1]; ok { - t.Fatal("unexpected", ok) - } else if _, ok := his[1].Checks[ap2]; ok { - t.Fatal("unexpected") } his, err = ss.SearchHosts(context.Background(), ap1, api.HostFilterModeAll, api.UsabilityFilterModeUnusable, "", nil, 0, -1) if err != nil { t.Fatal(err) - } else if cnt != 3 { - t.Fatal("unexpected", cnt) + } else if len(his) != 1 { + t.Fatal("unexpected", len(his)) + } else if his[0].Host.PublicKey != hk2 { + t.Fatal("unexpected") } - // assert h1 and h2 have the expected checks - if _, ok := his[0].Checks[ap1]; ok { + // assert only ap1 check is there + if _, ok := his[0].Checks[ap1]; !ok { t.Fatal("unexpected") - } else if c2, ok := his[1].Checks[ap1]; !ok || c2 != h2c1 { - t.Fatal("unexpected", ok) - } else if _, ok := his[1].Checks[ap2]; ok { + } else if _, ok := his[0].Checks[ap2]; ok { t.Fatal("unexpected") }