From 9c31e346d6fb14eb89b64ca444344016564fca69 Mon Sep 17 00:00:00 2001 From: Christopher Schinnerl Date: Thu, 13 Jun 2024 11:43:33 +0200 Subject: [PATCH] Migrate allowlists and blocklists to raw SQL (#1298) Unfortunately this PR doesn't yet get rid of the `dbAllowlist` and `dbBlocklist` types. That's because `updateBlocklist` is still using them and it probably doesn't make sense to update that before merging `its-happening`. --- stores/hostdb.go | 86 +++++---------------------- stores/sql.go | 63 -------------------- stores/sql/database.go | 15 ++++- stores/sql/main.go | 45 +++++++++++++- stores/sql/mysql/main.go | 119 +++++++++++++++++++++++++++++++++++++- stores/sql/sqlite/main.go | 119 +++++++++++++++++++++++++++++++++++++- 6 files changed, 305 insertions(+), 142 deletions(-) diff --git a/stores/hostdb.go b/stores/hostdb.go index 9cbd24cb7..1ac420685 100644 --- a/stores/hostdb.go +++ b/stores/hostdb.go @@ -402,7 +402,7 @@ func (ss *SQLStore) HostsForScanning(ctx context.Context, maxLastScan time.Time, func (ss *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { var hosts []api.Host err := ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) { - hosts, err = tx.SearchHosts(ctx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit, ss.hasAllowlist(), ss.hasBlocklist()) + hosts, err = tx.SearchHosts(ctx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit) return }) return hosts, err @@ -431,37 +431,8 @@ func (ss *SQLStore) UpdateHostAllowlistEntries(ctx context.Context, add, remove if len(add)+len(remove) == 0 && !clear { return nil } - defer ss.updateHasAllowlist(&err) - - // clear allowlist - if clear { - return ss.retryTransaction(ctx, func(tx *gorm.DB) error { - return tx.Where("TRUE").Delete(&dbAllowlistEntry{}).Error - }) - } - - var toInsert []dbAllowlistEntry - for _, entry := range add { - toInsert = append(toInsert, dbAllowlistEntry{Entry: publicKey(entry)}) - } - - toDelete := make([]publicKey, len(remove)) - for i, entry := range remove { - toDelete[i] = publicKey(entry) - } - - return ss.retryTransaction(ctx, func(tx *gorm.DB) error { - if len(toInsert) > 0 { - if err := tx.Create(&toInsert).Error; err != nil { - return err - } - } - if len(toDelete) > 0 { - if err := tx.Delete(&dbAllowlistEntry{}, "entry IN ?", toDelete).Error; err != nil { - return err - } - } - return nil + return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return tx.UpdateHostAllowlistEntries(ctx, add, remove, clear) }) } @@ -470,55 +441,24 @@ func (ss *SQLStore) UpdateHostBlocklistEntries(ctx context.Context, add, remove if len(add)+len(remove) == 0 && !clear { return nil } - defer ss.updateHasBlocklist(&err) - - // clear blocklist - if clear { - return ss.retryTransaction(ctx, func(tx *gorm.DB) error { - return tx.Where("TRUE").Delete(&dbBlocklistEntry{}).Error - }) - } - - var toInsert []dbBlocklistEntry - for _, entry := range add { - toInsert = append(toInsert, dbBlocklistEntry{Entry: entry}) - } - - return ss.retryTransaction(ctx, func(tx *gorm.DB) error { - if len(toInsert) > 0 { - if err := tx.Create(&toInsert).Error; err != nil { - return err - } - } - if len(remove) > 0 { - if err := tx.Delete(&dbBlocklistEntry{}, "entry IN ?", remove).Error; err != nil { - return err - } - } - return nil + return ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + return tx.UpdateHostBlocklistEntries(ctx, add, remove, clear) }) } func (ss *SQLStore) HostAllowlist(ctx context.Context) (allowlist []types.PublicKey, err error) { - var pubkeys []publicKey - err = ss.db. - WithContext(ctx). - Model(&dbAllowlistEntry{}). - Pluck("entry", &pubkeys). - Error - - for _, pubkey := range pubkeys { - allowlist = append(allowlist, types.PublicKey(pubkey)) - } + err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + allowlist, err = tx.HostAllowlist(ctx) + return err + }) return } func (ss *SQLStore) HostBlocklist(ctx context.Context) (blocklist []string, err error) { - err = ss.db. - WithContext(ctx). - Model(&dbBlocklistEntry{}). - Pluck("entry", &blocklist). - Error + err = ss.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error { + blocklist, err = tx.HostBlocklist(ctx) + return err + }) return } diff --git a/stores/sql.go b/stores/sql.go index 39ebe1558..f58d1504d 100644 --- a/stores/sql.go +++ b/stores/sql.go @@ -108,8 +108,6 @@ type ( wg sync.WaitGroup mu sync.Mutex - allowListCnt uint64 - blockListCnt uint64 lastPrunedAt time.Time closed bool @@ -237,16 +235,6 @@ func NewSQLStore(cfg Config) (*SQLStore, modules.ConsensusChangeID, error) { return nil, modules.ConsensusChangeID{}, err } - // Check allowlist and blocklist counts - allowlistCnt, err := tableCount(db, &dbAllowlistEntry{}) - if err != nil { - return nil, modules.ConsensusChangeID{}, err - } - blocklistCnt, err := tableCount(db, &dbBlocklistEntry{}) - if err != nil { - return nil, modules.ConsensusChangeID{}, err - } - // Fetch contract ids. var activeFCIDs, archivedFCIDs []fileContractID if err := db.Model(&dbContract{}). @@ -276,8 +264,6 @@ func NewSQLStore(cfg Config) (*SQLStore, modules.ConsensusChangeID, error) { knownContracts: isOurContract, lastSave: time.Now(), persistInterval: cfg.PersistInterval, - allowListCnt: uint64(allowlistCnt), - blockListCnt: uint64(blocklistCnt), settings: make(map[string]string), slabPruneSigChan: make(chan struct{}, 1), unappliedContractState: make(map[types.FileContractID]contractState), @@ -321,12 +307,6 @@ func isSQLite(db *gorm.DB) bool { } } -func (ss *SQLStore) hasAllowlist() bool { - ss.mu.Lock() - defer ss.mu.Unlock() - return ss.allowListCnt > 0 -} - func (s *SQLStore) initSlabPruning() error { // start pruning loop s.wg.Add(1) @@ -342,49 +322,6 @@ func (s *SQLStore) initSlabPruning() error { }) } -func (ss *SQLStore) updateHasAllowlist(err *error) { - if *err != nil { - return - } - - cnt, cErr := tableCount(ss.db, &dbAllowlistEntry{}) - if cErr != nil { - *err = cErr - return - } - - ss.mu.Lock() - ss.allowListCnt = uint64(cnt) - ss.mu.Unlock() -} - -func (ss *SQLStore) hasBlocklist() bool { - ss.mu.Lock() - defer ss.mu.Unlock() - return ss.blockListCnt > 0 -} - -func (ss *SQLStore) updateHasBlocklist(err *error) { - if *err != nil { - return - } - - cnt, cErr := tableCount(ss.db, &dbBlocklistEntry{}) - if cErr != nil { - *err = cErr - return - } - - ss.mu.Lock() - ss.blockListCnt = uint64(cnt) - ss.mu.Unlock() -} - -func tableCount(db *gorm.DB, model interface{}) (cnt int64, err error) { - err = db.Model(model).Count(&cnt).Error - return -} - // Close closes the underlying database connection of the store. func (s *SQLStore) Close() error { s.shutdownCtxCancel() diff --git a/stores/sql/database.go b/stores/sql/database.go index c1b0c3745..e29ce33ba 100644 --- a/stores/sql/database.go +++ b/stores/sql/database.go @@ -96,6 +96,13 @@ type ( // prefix and returns 'true' if any object was deleted. DeleteObjects(ctx context.Context, bucket, prefix string, limit int64) (bool, error) + // HostAllowlist returns the list of public keys of hosts on the + // allowlist. + HostAllowlist(ctx context.Context) ([]types.PublicKey, error) + + // HostBlocklist returns the list of host addresses on the blocklist. + HostBlocklist(ctx context.Context) ([]string, error) + // InsertObject inserts a new object into the database. InsertObject(ctx context.Context, bucket, key, contractSet string, dirID int64, o object.Object, mimeType, eTag string, md api.ObjectUserMetadata) error @@ -164,7 +171,7 @@ type ( SaveAccounts(ctx context.Context, accounts []api.Account) error // SearchHosts returns a list of hosts that match the provided filters - SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int, hasAllowList, hasBlocklist bool) ([]api.Host, error) + SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) // SetUncleanShutdown sets the clean shutdown flag on the accounts to // 'false' and also marks them as requiring a resync. @@ -178,6 +185,12 @@ type ( // one, fully overwriting the existing policy. UpdateBucketPolicy(ctx context.Context, bucket string, policy api.BucketPolicy) error + // UpdateHostAllowlistEntries updates the allowlist in the database + UpdateHostAllowlistEntries(ctx context.Context, add, remove []types.PublicKey, clear bool) error + + // UpdateHostBlocklistEntries updates the blocklist in the database + UpdateHostBlocklistEntries(ctx context.Context, add, remove []string, clear bool) error + // UpdateObjectHealth updates the health of all objects to the lowest // health of all its slabs. UpdateObjectHealth(ctx context.Context) error diff --git a/stores/sql/main.go b/stores/sql/main.go index 548b5b347..a605e56df 100644 --- a/stores/sql/main.go +++ b/stores/sql/main.go @@ -305,6 +305,42 @@ func CopyObject(ctx context.Context, tx sql.Tx, srcBucket, dstBucket, srcKey, ds return fetchMetadata(dstObjID) } +func HostAllowlist(ctx context.Context, tx sql.Tx) ([]types.PublicKey, error) { + rows, err := tx.Query(ctx, "SELECT entry FROM host_allowlist_entries") + if err != nil { + return nil, fmt.Errorf("failed to fetch host allowlist: %w", err) + } + defer rows.Close() + + var allowlist []types.PublicKey + for rows.Next() { + var pk PublicKey + if err := rows.Scan(&pk); err != nil { + return nil, fmt.Errorf("failed to scan public key: %w", err) + } + allowlist = append(allowlist, types.PublicKey(pk)) + } + return allowlist, nil +} + +func HostBlocklist(ctx context.Context, tx sql.Tx) ([]string, error) { + rows, err := tx.Query(ctx, "SELECT entry FROM host_blocklist_entries") + if err != nil { + return nil, fmt.Errorf("failed to fetch host blocklist: %w", err) + } + defer rows.Close() + + var blocklist []string + for rows.Next() { + var entry string + if err := rows.Scan(&entry); err != nil { + return nil, fmt.Errorf("failed to scan blocklist entry: %w", err) + } + blocklist = append(blocklist, entry) + } + return blocklist, nil +} + func HostsForScanning(ctx context.Context, tx sql.Tx, maxLastScan time.Time, offset, limit int) ([]api.HostAddress, error) { if offset < 0 { return nil, ErrNegativeOffset @@ -941,11 +977,18 @@ func RemoveOfflineHosts(ctx context.Context, tx sql.Tx, minRecentFailures uint64 return res.RowsAffected() } -func SearchHosts(ctx context.Context, tx sql.Tx, autopilot, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int, hasAllowlist, hasBlocklist bool) ([]api.Host, error) { +func SearchHosts(ctx context.Context, tx sql.Tx, autopilot, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { if offset < 0 { return nil, ErrNegativeOffset } + var hasAllowlist, hasBlocklist bool + if err := tx.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM host_allowlist_entries)").Scan(&hasAllowlist); err != nil { + return nil, fmt.Errorf("failed to check for allowlist: %w", err) + } else if err := tx.QueryRow(ctx, "SELECT EXISTS (SELECT 1 FROM host_blocklist_entries)").Scan(&hasBlocklist); err != nil { + return nil, fmt.Errorf("failed to check for blocklist: %w", err) + } + // validate filterMode switch filterMode { case api.HostFilterModeAllowed: diff --git a/stores/sql/mysql/main.go b/stores/sql/mysql/main.go index 5f41cc77d..c7fe51f99 100644 --- a/stores/sql/mysql/main.go +++ b/stores/sql/mysql/main.go @@ -293,6 +293,14 @@ func (tx *MainDatabaseTx) DeleteObjects(ctx context.Context, bucket string, key } } +func (tx *MainDatabaseTx) HostAllowlist(ctx context.Context) ([]types.PublicKey, error) { + return ssql.HostAllowlist(ctx, tx) +} + +func (tx *MainDatabaseTx) HostBlocklist(ctx context.Context) ([]string, error) { + return ssql.HostBlocklist(ctx, tx) +} + func (tx *MainDatabaseTx) HostsForScanning(ctx context.Context, maxLastScan time.Time, offset, limit int) ([]api.HostAddress, error) { return ssql.HostsForScanning(ctx, tx, maxLastScan, offset, limit) } @@ -569,8 +577,8 @@ func (tx MainDatabaseTx) SaveAccounts(ctx context.Context, accounts []api.Accoun return nil } -func (tx *MainDatabaseTx) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int, hasAllowlist, hasBlocklist bool) ([]api.Host, error) { - return ssql.SearchHosts(ctx, tx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit, hasAllowlist, hasBlocklist) +func (tx *MainDatabaseTx) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { + return ssql.SearchHosts(ctx, tx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit) } func (tx *MainDatabaseTx) SetUncleanShutdown(ctx context.Context) error { @@ -599,6 +607,113 @@ func (tx *MainDatabaseTx) UpdateBucketPolicy(ctx context.Context, bucket string, return ssql.UpdateBucketPolicy(ctx, tx, bucket, bp) } +func (tx *MainDatabaseTx) UpdateHostAllowlistEntries(ctx context.Context, add, remove []types.PublicKey, clear bool) error { + if clear { + if _, err := tx.Exec(ctx, "DELETE FROM host_allowlist_entries"); err != nil { + return fmt.Errorf("failed to clear host allowlist entries: %w", err) + } + } + + if len(add) > 0 { + insertStmt, err := tx.Prepare(ctx, "INSERT INTO host_allowlist_entries (entry) VALUES (?) ON DUPLICATE KEY UPDATE id = last_insert_id(id)") + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer insertStmt.Close() + joinStmt, err := tx.Prepare(ctx, ` + INSERT IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) + SELECT ?, id FROM ( + SELECT id + FROM hosts + WHERE public_key = ? + ) AS _`) + if err != nil { + return fmt.Errorf("failed to prepare join statement: %w", err) + } + defer joinStmt.Close() + + for _, pk := range add { + if res, err := insertStmt.Exec(ctx, ssql.PublicKey(pk)); err != nil { + return fmt.Errorf("failed to insert host allowlist entry: %w", err) + } else if entryID, err := res.LastInsertId(); err != nil { + return fmt.Errorf("failed to fetch host allowlist entry id: %w", err) + } else if _, err := joinStmt.Exec(ctx, entryID, ssql.PublicKey(pk)); err != nil { + return fmt.Errorf("failed to join host allowlist entry: %w", err) + } + } + } + + if !clear && len(remove) > 0 { + deleteStmt, err := tx.Prepare(ctx, "DELETE FROM host_allowlist_entries WHERE entry = ?") + if err != nil { + return fmt.Errorf("failed to prepare delete statement: %w", err) + } + defer deleteStmt.Close() + + for _, pk := range remove { + if _, err := deleteStmt.Exec(ctx, ssql.PublicKey(pk)); err != nil { + return fmt.Errorf("failed to delete host allowlist entry: %w", err) + } + } + } + return nil +} + +func (tx *MainDatabaseTx) UpdateHostBlocklistEntries(ctx context.Context, add, remove []string, clear bool) error { + if clear { + if _, err := tx.Exec(ctx, "DELETE FROM host_blocklist_entries"); err != nil { + return fmt.Errorf("failed to clear host blocklist entries: %w", err) + } + } + + if len(add) > 0 { + insertStmt, err := tx.Prepare(ctx, "INSERT INTO host_blocklist_entries (entry) VALUES (?) ON DUPLICATE KEY UPDATE id = last_insert_id(id)") + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer insertStmt.Close() + joinStmt, err := tx.Prepare(ctx, ` + INSERT IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) + SELECT ?, id FROM ( + SELECT id + FROM hosts + WHERE net_address=? OR + SUBSTRING_INDEX(net_address,':',1) = ? OR + SUBSTRING_INDEX(net_address,':',1) LIKE ? + ) AS _ + `) + if err != nil { + return fmt.Errorf("failed to prepare join statement: %w", err) + } + defer joinStmt.Close() + + for _, entry := range add { + if res, err := insertStmt.Exec(ctx, entry); err != nil { + return fmt.Errorf("failed to insert host blocklist entry: %w", err) + } else if entryID, err := res.LastInsertId(); err != nil { + return fmt.Errorf("failed to fetch host blocklist entry id: %w", err) + } else if _, err := joinStmt.Exec(ctx, entryID, entry, entry, fmt.Sprintf("%%.%s", entry)); err != nil { + return fmt.Errorf("failed to join host blocklist entry: %w", err) + } + } + } + + if !clear && len(remove) > 0 { + deleteStmt, err := tx.Prepare(ctx, "DELETE FROM host_blocklist_entries WHERE entry = ?") + if err != nil { + return fmt.Errorf("failed to prepare delete statement: %w", err) + } + defer deleteStmt.Close() + + for _, entry := range remove { + if _, err := deleteStmt.Exec(ctx, entry); err != nil { + return fmt.Errorf("failed to delete host blocklist entry: %w", err) + } + } + } + return nil +} + func (tx *MainDatabaseTx) UpdateObjectHealth(ctx context.Context) error { return ssql.UpdateObjectHealth(ctx, tx) } diff --git a/stores/sql/sqlite/main.go b/stores/sql/sqlite/main.go index de9476495..c14ee5b05 100644 --- a/stores/sql/sqlite/main.go +++ b/stores/sql/sqlite/main.go @@ -282,6 +282,14 @@ func (tx *MainDatabaseTx) DeleteObjects(ctx context.Context, bucket string, key } } +func (tx *MainDatabaseTx) HostAllowlist(ctx context.Context) ([]types.PublicKey, error) { + return ssql.HostAllowlist(ctx, tx) +} + +func (tx *MainDatabaseTx) HostBlocklist(ctx context.Context) ([]string, error) { + return ssql.HostBlocklist(ctx, tx) +} + func (tx *MainDatabaseTx) HostsForScanning(ctx context.Context, maxLastScan time.Time, offset, limit int) ([]api.HostAddress, error) { return ssql.HostsForScanning(ctx, tx, maxLastScan, offset, limit) } @@ -567,8 +575,8 @@ func (tx *MainDatabaseTx) SaveAccounts(ctx context.Context, accounts []api.Accou return nil } -func (tx *MainDatabaseTx) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int, hasAllowlist, hasBlocklist bool) ([]api.Host, error) { - return ssql.SearchHosts(ctx, tx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit, hasAllowlist, hasBlocklist) +func (tx *MainDatabaseTx) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) { + return ssql.SearchHosts(ctx, tx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit) } func (tx *MainDatabaseTx) SetUncleanShutdown(ctx context.Context) error { @@ -596,6 +604,113 @@ func (tx *MainDatabaseTx) UpdateAutopilot(ctx context.Context, ap api.Autopilot) func (tx *MainDatabaseTx) UpdateBucketPolicy(ctx context.Context, bucket string, policy api.BucketPolicy) error { return ssql.UpdateBucketPolicy(ctx, tx, bucket, policy) } + +func (tx *MainDatabaseTx) UpdateHostAllowlistEntries(ctx context.Context, add, remove []types.PublicKey, clear bool) error { + if clear { + if _, err := tx.Exec(ctx, "DELETE FROM host_allowlist_entries"); err != nil { + return fmt.Errorf("failed to clear host allowlist entries: %w", err) + } + } + + if len(add) > 0 { + insertStmt, err := tx.Prepare(ctx, "INSERT INTO host_allowlist_entries (entry) VALUES (?) ON CONFLICT(entry) DO UPDATE SET id = id RETURNING id") + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer insertStmt.Close() + joinStmt, err := tx.Prepare(ctx, ` + INSERT OR IGNORE INTO host_allowlist_entry_hosts (db_allowlist_entry_id, db_host_id) + SELECT ?, id FROM ( + SELECT id + FROM hosts + WHERE public_key = ? + )`) + if err != nil { + return fmt.Errorf("failed to prepare join statement: %w", err) + } + defer joinStmt.Close() + + for _, pk := range add { + if res, err := insertStmt.Exec(ctx, ssql.PublicKey(pk)); err != nil { + return fmt.Errorf("failed to insert host allowlist entry: %w", err) + } else if entryID, err := res.LastInsertId(); err != nil { + return fmt.Errorf("failed to fetch host allowlist entry id: %w", err) + } else if _, err := joinStmt.Exec(ctx, entryID, ssql.PublicKey(pk)); err != nil { + return fmt.Errorf("failed to join host allowlist entry: %w", err) + } + } + } + + if !clear && len(remove) > 0 { + deleteStmt, err := tx.Prepare(ctx, "DELETE FROM host_allowlist_entries WHERE entry = ?") + if err != nil { + return fmt.Errorf("failed to prepare delete statement: %w", err) + } + defer deleteStmt.Close() + + for _, pk := range remove { + if _, err := deleteStmt.Exec(ctx, ssql.PublicKey(pk)); err != nil { + return fmt.Errorf("failed to delete host allowlist entry: %w", err) + } + } + } + return nil +} + +func (tx *MainDatabaseTx) UpdateHostBlocklistEntries(ctx context.Context, add, remove []string, clear bool) error { + if clear { + if _, err := tx.Exec(ctx, "DELETE FROM host_blocklist_entries"); err != nil { + return fmt.Errorf("failed to clear host blocklist entries: %w", err) + } + } + + if len(add) > 0 { + insertStmt, err := tx.Prepare(ctx, "INSERT INTO host_blocklist_entries (entry) VALUES (?) ON CONFLICT(entry) DO UPDATE SET id = id RETURNING id") + if err != nil { + return fmt.Errorf("failed to prepare insert statement: %w", err) + } + defer insertStmt.Close() + joinStmt, err := tx.Prepare(ctx, ` + INSERT OR IGNORE INTO host_blocklist_entry_hosts (db_blocklist_entry_id, db_host_id) + SELECT ?, id FROM ( + SELECT id + FROM hosts + WHERE net_address == ? OR + rtrim(rtrim(net_address, replace(net_address, ':', '')),':') == ? OR + rtrim(rtrim(net_address, replace(net_address, ':', '')),':') LIKE ? + )`) + if err != nil { + return fmt.Errorf("failed to prepare join statement: %w", err) + } + defer joinStmt.Close() + + for _, entry := range add { + if res, err := insertStmt.Exec(ctx, entry); err != nil { + return fmt.Errorf("failed to insert host blocklist entry: %w", err) + } else if entryID, err := res.LastInsertId(); err != nil { + return fmt.Errorf("failed to fetch host blocklist entry id: %w", err) + } else if _, err := joinStmt.Exec(ctx, entryID, entry, entry, fmt.Sprintf("%%.%s", entry)); err != nil { + return fmt.Errorf("failed to join host blocklist entry: %w", err) + } + } + } + + if !clear && len(remove) > 0 { + deleteStmt, err := tx.Prepare(ctx, "DELETE FROM host_blocklist_entries WHERE entry = ?") + if err != nil { + return fmt.Errorf("failed to prepare delete statement: %w", err) + } + defer deleteStmt.Close() + + for _, entry := range remove { + if _, err := deleteStmt.Exec(ctx, entry); err != nil { + return fmt.Errorf("failed to delete host blocklist entry: %w", err) + } + } + } + return nil +} + func (tx *MainDatabaseTx) UpdateObjectHealth(ctx context.Context) error { return ssql.UpdateObjectHealth(ctx, tx) }