Skip to content

Commit

Permalink
store: rename bMain to db and db to gormDB
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisSchinnerl authored and n8maninger committed Jul 15, 2024
1 parent 4baa8ba commit f1cee0d
Show file tree
Hide file tree
Showing 18 changed files with 237 additions and 236 deletions.
6 changes: 3 additions & 3 deletions stores/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// Accounts returns all accounts from the db.
func (s *SQLStore) Accounts(ctx context.Context) (accounts []api.Account, err error) {
err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
accounts, err = tx.Accounts(ctx)
return err
})
Expand All @@ -21,15 +21,15 @@ func (s *SQLStore) Accounts(ctx context.Context) (accounts []api.Account, err er
// sync all accounts after an unclean shutdown and the bus will know not to
// apply drift.
func (s *SQLStore) SetUncleanShutdown(ctx context.Context) error {
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.SetUncleanShutdown(ctx)
})
}

// SaveAccounts saves the given accounts in the db, overwriting any existing
// ones.
func (s *SQLStore) SaveAccounts(ctx context.Context, accounts []api.Account) error {
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.SaveAccounts(ctx, accounts)
})
}
6 changes: 3 additions & 3 deletions stores/autopilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
)

func (s *SQLStore) Autopilots(ctx context.Context) (aps []api.Autopilot, _ error) {
err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) {
err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) {
aps, err = tx.Autopilots(ctx)
return
})
return aps, err
}

func (s *SQLStore) Autopilot(ctx context.Context, id string) (ap api.Autopilot, _ error) {
err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) {
err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) {
ap, err = tx.Autopilot(ctx, id)
return
})
Expand All @@ -32,7 +32,7 @@ func (s *SQLStore) UpdateAutopilot(ctx context.Context, ap api.Autopilot) error
if err := ap.Config.Validate(); err != nil {
return err
}
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.UpdateAutopilot(ctx, ap)
})
}
6 changes: 3 additions & 3 deletions stores/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var (

// ChainIndex returns the last stored chain index.
func (s *SQLStore) ChainIndex(ctx context.Context) (ci types.ChainIndex, err error) {
err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
ci, err = tx.Tip(ctx)
return err
})
Expand All @@ -25,7 +25,7 @@ func (s *SQLStore) ChainIndex(ctx context.Context) (ci types.ChainIndex, err err
// ProcessChainUpdate returns a callback function that process a chain update
// inside a transaction.
func (s *SQLStore) ProcessChainUpdate(ctx context.Context, applyFn chain.ApplyChainUpdateFn) error {
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.ProcessChainUpdate(ctx, applyFn)
})
}
Expand All @@ -39,7 +39,7 @@ func (s *SQLStore) UpdateChainState(reverted []chain.RevertUpdate, applied []cha

// ResetChainState deletes all chain data in the database.
func (s *SQLStore) ResetChainState(ctx context.Context) error {
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.ResetChainState(ctx)
})
}
20 changes: 10 additions & 10 deletions stores/hostdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,14 @@ func (s *SQLStore) Host(ctx context.Context, hostKey types.PublicKey) (api.Host,
}

func (s *SQLStore) UpdateHostCheck(ctx context.Context, autopilotID string, hk types.PublicKey, hc api.HostCheck) (err error) {
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.UpdateHostCheck(ctx, autopilotID, hk, hc)
})
}

// HostsForScanning returns the address of hosts for scanning.
func (s *SQLStore) HostsForScanning(ctx context.Context, maxLastScan time.Time, offset, limit int) (hosts []api.HostAddress, err error) {
err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
hosts, err = tx.HostsForScanning(ctx, maxLastScan, offset, limit)
return err
})
Expand All @@ -293,7 +293,7 @@ func (s *SQLStore) ResetLostSectors(ctx context.Context, hk types.PublicKey) err

func (s *SQLStore) SearchHosts(ctx context.Context, autopilotID, filterMode, usabilityMode, addressContains string, keyIn []types.PublicKey, offset, limit int) ([]api.Host, error) {
var hosts []api.Host
err := s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) (err error) {
err := s.db.Transaction(ctx, func(tx sql.DatabaseTx) (err error) {
hosts, err = tx.SearchHosts(ctx, autopilotID, filterMode, usabilityMode, addressContains, keyIn, offset, limit)
return
})
Expand All @@ -310,7 +310,7 @@ func (s *SQLStore) RemoveOfflineHosts(ctx context.Context, minRecentFailures uin
if maxDowntime < 0 {
return 0, ErrNegativeMaxDowntime
}
err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
n, err := tx.RemoveOfflineHosts(ctx, minRecentFailures, maxDowntime)
removed = uint64(n)
return err
Expand All @@ -323,7 +323,7 @@ func (s *SQLStore) UpdateHostAllowlistEntries(ctx context.Context, add, remove [
if len(add)+len(remove) == 0 && !clear {
return nil
}
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.UpdateHostAllowlistEntries(ctx, add, remove, clear)
})
}
Expand All @@ -333,35 +333,35 @@ func (s *SQLStore) UpdateHostBlocklistEntries(ctx context.Context, add, remove [
if len(add)+len(remove) == 0 && !clear {
return nil
}
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.UpdateHostBlocklistEntries(ctx, add, remove, clear)
})
}

func (s *SQLStore) HostAllowlist(ctx context.Context) (allowlist []types.PublicKey, err error) {
err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
allowlist, err = tx.HostAllowlist(ctx)
return err
})
return
}

func (s *SQLStore) HostBlocklist(ctx context.Context) (blocklist []string, err error) {
err = s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
err = s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
blocklist, err = tx.HostBlocklist(ctx)
return err
})
return
}

func (s *SQLStore) RecordHostScans(ctx context.Context, scans []api.HostScan) error {
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.RecordHostScans(ctx, scans)
})
}

func (s *SQLStore) RecordPriceTables(ctx context.Context, priceTableUpdate []api.HostPriceTableUpdate) error {
return s.bMain.Transaction(ctx, func(tx sql.DatabaseTx) error {
return s.db.Transaction(ctx, func(tx sql.DatabaseTx) error {
return tx.RecordPriceTables(ctx, priceTableUpdate)
})
}
Expand Down
56 changes: 28 additions & 28 deletions stores/hostdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestSQLHostDB(t *testing.T) {

// Fetch the host
var h dbHost
tx := ss.db.Where("net_address = ?", "address").Find(&h)
tx := ss.gormDB.Where("net_address = ?", "address").Find(&h)
if tx.Error != nil {
t.Fatal(tx.Error)
} else if types.PublicKey(h.PublicKey) != hk {
Expand Down Expand Up @@ -321,7 +321,7 @@ func TestSearchHosts(t *testing.T) {

// assert there are currently 3 checks
var cnt int64
err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error
err = ss.gormDB.Model(&dbHostCheck{}).Count(&cnt).Error
if err != nil {
t.Fatal(err)
} else if cnt != 3 {
Expand Down Expand Up @@ -397,23 +397,23 @@ func TestSearchHosts(t *testing.T) {
}

// assert cascade delete on host
err = ss.db.Exec("DELETE FROM hosts WHERE public_key = ?", publicKey(types.PublicKey{1})).Error
err = ss.gormDB.Exec("DELETE FROM hosts WHERE public_key = ?", publicKey(types.PublicKey{1})).Error
if err != nil {
t.Fatal(err)
}
err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error
err = ss.gormDB.Model(&dbHostCheck{}).Count(&cnt).Error
if err != nil {
t.Fatal(err)
} else if cnt != 2 {
t.Fatal("unexpected", cnt)
}

// assert cascade delete on autopilot
err = ss.db.Exec("DELETE FROM autopilots WHERE identifier IN (?,?)", ap1, ap2).Error
err = ss.gormDB.Exec("DELETE FROM autopilots WHERE identifier IN (?,?)", ap1, ap2).Error
if err != nil {
t.Fatal(err)
}
err = ss.db.Model(&dbHostCheck{}).Count(&cnt).Error
err = ss.gormDB.Model(&dbHostCheck{}).Count(&cnt).Error
if err != nil {
t.Fatal(err)
} else if cnt != 0 {
Expand Down Expand Up @@ -452,7 +452,7 @@ func TestRecordScan(t *testing.T) {
}

// Fetch the host directly to get the creation time.
h, err := hostByPubKey(ss.db, hk)
h, err := hostByPubKey(ss.gormDB, hk)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -589,11 +589,11 @@ func TestInsertAnnouncements(t *testing.T) {
ann3 := newTestAnnouncement(types.GeneratePrivateKey().PublicKey(), "")

// Insert the first one and check that all fields are set.
if err := insertAnnouncements(ss.db, []announcement{ann1}); err != nil {
if err := insertAnnouncements(ss.gormDB, []announcement{ann1}); err != nil {
t.Fatal(err)
}
var ann dbAnnouncement
if err := ss.db.Find(&ann).Error; err != nil {
if err := ss.gormDB.Find(&ann).Error; err != nil {
t.Fatal(err)
}
ann.Model = Model{} // ignore
Expand All @@ -607,12 +607,12 @@ func TestInsertAnnouncements(t *testing.T) {
t.Fatal("mismatch", cmp.Diff(ann, expectedAnn))
}
// Insert the first and second one.
if err := insertAnnouncements(ss.db, []announcement{ann1, ann2}); err != nil {
if err := insertAnnouncements(ss.gormDB, []announcement{ann1, ann2}); err != nil {
t.Fatal(err)
}

// Insert the first one twice. The second one again and the third one.
if err := insertAnnouncements(ss.db, []announcement{ann1, ann2, ann1, ann3}); err != nil {
if err := insertAnnouncements(ss.gormDB, []announcement{ann1, ann2, ann1, ann3}); err != nil {
t.Fatal(err)
}

Expand All @@ -627,7 +627,7 @@ func TestInsertAnnouncements(t *testing.T) {

// There should be 7 announcements total.
var announcements []dbAnnouncement
if err := ss.db.Find(&announcements).Error; err != nil {
if err := ss.gormDB.Find(&announcements).Error; err != nil {
t.Fatal(err)
}
if len(announcements) != 7 {
Expand All @@ -644,7 +644,7 @@ func TestInsertAnnouncements(t *testing.T) {
// Insert multiple announcements for host 1 - this asserts that the UNIQUE
// constraint on the blocklist table isn't triggered when inserting multiple
// announcements for a host that's on the blocklist
if err := insertAnnouncements(ss.db, []announcement{ann1, ann1}); err != nil {
if err := insertAnnouncements(ss.gormDB, []announcement{ann1, ann1}); err != nil {
t.Fatal(err)
}
}
Expand All @@ -661,7 +661,7 @@ func TestRemoveHosts(t *testing.T) {
}

// fetch the host and assert the recent downtime is zero
h, err := hostByPubKey(ss.db, hk)
h, err := hostByPubKey(ss.gormDB, hk)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -691,7 +691,7 @@ func TestRemoveHosts(t *testing.T) {
}

// fetch the host and assert the recent downtime is 30 minutes and he has 2 recent scan failures
h, err = hostByPubKey(ss.db, hk)
h, err = hostByPubKey(ss.gormDB, hk)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -746,7 +746,7 @@ func TestRemoveHosts(t *testing.T) {
}

// assert host is removed from the database
if _, err = hostByPubKey(ss.db, hk); err != gorm.ErrRecordNotFound {
if _, err = hostByPubKey(ss.gormDB, hk); err != gorm.ErrRecordNotFound {
t.Fatal("expected record not found error")
}
}
Expand Down Expand Up @@ -777,7 +777,7 @@ func TestSQLHostAllowlist(t *testing.T) {

numRelations := func() (cnt int64) {
t.Helper()
err := ss.db.Table("host_allowlist_entry_hosts").Count(&cnt).Error
err := ss.gormDB.Table("host_allowlist_entry_hosts").Count(&cnt).Error
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -883,7 +883,7 @@ func TestSQLHostAllowlist(t *testing.T) {
}

// remove host 1
if err = ss.db.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk1)}).Delete(&dbHost{}).Error; err != nil {
if err = ss.gormDB.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk1)}).Delete(&dbHost{}).Error; err != nil {
t.Fatal(err)
}
if numHosts() != 0 {
Expand Down Expand Up @@ -949,7 +949,7 @@ func TestSQLHostBlocklist(t *testing.T) {

numAllowlistRelations := func() (cnt int64) {
t.Helper()
err := ss.db.Table("host_allowlist_entry_hosts").Count(&cnt).Error
err := ss.gormDB.Table("host_allowlist_entry_hosts").Count(&cnt).Error
if err != nil {
t.Fatal(err)
}
Expand All @@ -958,7 +958,7 @@ func TestSQLHostBlocklist(t *testing.T) {

numBlocklistRelations := func() (cnt int64) {
t.Helper()
err := ss.db.Table("host_blocklist_entry_hosts").Count(&cnt).Error
err := ss.gormDB.Table("host_blocklist_entry_hosts").Count(&cnt).Error
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1067,7 +1067,7 @@ func TestSQLHostBlocklist(t *testing.T) {
}

// delete host 2 and assert the delete cascaded properly
if err = ss.db.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk2)}).Delete(&dbHost{}).Error; err != nil {
if err = ss.gormDB.Model(&dbHost{}).Where(&dbHost{PublicKey: publicKey(hk2)}).Delete(&dbHost{}).Error; err != nil {
t.Fatal(err)
}
if numHosts() != 2 {
Expand Down Expand Up @@ -1234,7 +1234,7 @@ func (s *SQLStore) addCustomTestHost(hk types.PublicKey, na string) error {
}

// fetch blocklists
allowlist, blocklist, err := getBlocklists(s.db)
allowlist, blocklist, err := getBlocklists(s.gormDB)
if err != nil {
return err
}
Expand All @@ -1246,7 +1246,7 @@ func (s *SQLStore) addCustomTestHost(hk types.PublicKey, na string) error {
dbAllowlist = append(dbAllowlist, entry)
}
}
if err := s.db.Model(&host).Association("Allowlist").Replace(&dbAllowlist); err != nil {
if err := s.gormDB.Model(&host).Association("Allowlist").Replace(&dbAllowlist); err != nil {
return err
}

Expand All @@ -1257,21 +1257,21 @@ func (s *SQLStore) addCustomTestHost(hk types.PublicKey, na string) error {
dbBlocklist = append(dbBlocklist, entry)
}
}
return s.db.Model(&host).Association("Blocklist").Replace(&dbBlocklist)
return s.gormDB.Model(&host).Association("Blocklist").Replace(&dbBlocklist)
}

// announceHost adds a host announcement to the database.
func (s *SQLStore) announceHost(hk types.PublicKey, na string) (host dbHost, err error) {
err = s.db.Transaction(func(tx *gorm.DB) error {
err = s.gormDB.Transaction(func(tx *gorm.DB) error {
host = dbHost{
PublicKey: publicKey(hk),
LastAnnouncement: time.Now().UTC().Round(time.Second),
NetAddress: na,
}
if err := s.db.Create(&host).Error; err != nil {
if err := s.gormDB.Create(&host).Error; err != nil {
return err
}
return s.db.Create(&dbAnnouncement{
return s.gormDB.Create(&dbAnnouncement{
HostKey: publicKey(hk),
BlockHeight: 42,
BlockID: types.BlockID{1, 2, 3}.String(),
Expand All @@ -1285,7 +1285,7 @@ func (s *SQLStore) announceHost(hk types.PublicKey, na string) (host dbHost, err
// interactions for all hosts is expensive in production.
func (db *SQLStore) hosts() ([]dbHost, error) {
var hosts []dbHost
tx := db.db.Find(&hosts)
tx := db.gormDB.Find(&hosts)
if tx.Error != nil {
return nil, tx.Error
}
Expand Down
Loading

0 comments on commit f1cee0d

Please sign in to comment.