Skip to content

Commit

Permalink
Merge pull request #982 from SiaFoundation/pj/mysql-stores-tests
Browse files Browse the repository at this point in the history
Run stores unit tests against MySQL
  • Loading branch information
ChrisSchinnerl authored Feb 22, 2024
2 parents 414090b + 7fb1349 commit 9edb677
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 97 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ jobs:
uses: n8maninger/action-golang-test@v1
with:
args: "-race;-short"
- name: Test Stores - MySQL
if: matrix.os == 'ubuntu-latest'
uses: n8maninger/action-golang-test@v1
env:
RENTERD_DB_URI: 127.0.0.1:3800
RENTERD_DB_USER: root
RENTERD_DB_PASSWORD: test
with:
package: "./stores"
args: "-race;-short"
- name: Test Integration
uses: n8maninger/action-golang-test@v1
with:
Expand Down
41 changes: 19 additions & 22 deletions stores/hostdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,23 +63,16 @@ func TestSQLHostDB(t *testing.T) {

// Insert an announcement for the host and another one for an unknown
// host.
a := hostdb.Announcement{
Index: types.ChainIndex{
Height: 42,
ID: types.BlockID{1, 2, 3},
},
Timestamp: time.Now().UTC().Round(time.Second),
NetAddress: "address",
}
err = ss.insertTestAnnouncement(hk, a)
ann := newTestHostDBAnnouncement("address")
err = ss.insertTestAnnouncement(hk, ann)
if err != nil {
t.Fatal(err)
}

// Read the host and verify that the announcement related fields were
// set.
var h dbHost
tx := ss.db.Where("last_announcement = ? AND net_address = ?", a.Timestamp, a.NetAddress).Find(&h)
tx := ss.db.Where("last_announcement = ? AND net_address = ?", ann.Timestamp, ann.NetAddress).Find(&h)
if tx.Error != nil {
t.Fatal(tx.Error)
}
Expand Down Expand Up @@ -116,15 +109,15 @@ func TestSQLHostDB(t *testing.T) {

// Insert another announcement for an unknown host.
unknownKey := types.PublicKey{1, 4, 7}
err = ss.insertTestAnnouncement(unknownKey, a)
err = ss.insertTestAnnouncement(unknownKey, ann)
if err != nil {
t.Fatal(err)
}
h3, err := ss.Host(ctx, unknownKey)
if err != nil {
t.Fatal(err)
}
if h3.NetAddress != a.NetAddress {
if h3.NetAddress != ann.NetAddress {
t.Fatal("wrong net address")
}
if h3.KnownSince.IsZero() {
Expand Down Expand Up @@ -510,22 +503,18 @@ func TestInsertAnnouncements(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

// Create announcements for 2 hosts.
// Create announcements for 3 hosts.
ann1 := announcement{
hostKey: publicKey(types.GeneratePrivateKey().PublicKey()),
announcement: hostdb.Announcement{
Index: types.ChainIndex{Height: 1, ID: types.BlockID{1}},
Timestamp: time.Now(),
NetAddress: "foo.bar:1000",
},
hostKey: publicKey(types.GeneratePrivateKey().PublicKey()),
announcement: newTestHostDBAnnouncement("foo.bar:1000"),
}
ann2 := announcement{
hostKey: publicKey(types.GeneratePrivateKey().PublicKey()),
announcement: hostdb.Announcement{},
announcement: newTestHostDBAnnouncement("bar.baz:1000"),
}
ann3 := announcement{
hostKey: publicKey(types.GeneratePrivateKey().PublicKey()),
announcement: hostdb.Announcement{},
announcement: newTestHostDBAnnouncement("quz.qux:1000"),
}

// Insert the first one and check that all fields are set.
Expand Down Expand Up @@ -1101,7 +1090,7 @@ func (s *SQLStore) addCustomTestHost(hk types.PublicKey, na string) error {
s.unappliedHostKeys[hk] = struct{}{}
s.unappliedAnnouncements = append(s.unappliedAnnouncements, []announcement{{
hostKey: publicKey(hk),
announcement: hostdb.Announcement{NetAddress: na},
announcement: newTestHostDBAnnouncement(na),
}}...)
s.lastSave = time.Now().Add(s.persistInterval * -2)
return s.applyUpdates(false)
Expand Down Expand Up @@ -1153,6 +1142,14 @@ func newTestHostAnnouncement(na modules.NetAddress) (modules.HostAnnouncement, t
}, sk
}

func newTestHostDBAnnouncement(addr string) hostdb.Announcement {
return hostdb.Announcement{
Index: types.ChainIndex{Height: 1, ID: types.BlockID{1}},
Timestamp: time.Now().UTC().Round(time.Second),
NetAddress: addr,
}
}

func newTestTransaction(ha modules.HostAnnouncement, sk types.PrivateKey) stypes.Transaction {
var buf bytes.Buffer
buf.Write(encoding.Marshal(ha))
Expand Down
4 changes: 4 additions & 0 deletions stores/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -1499,6 +1499,10 @@ func (s *SQLStore) RenameObjects(ctx context.Context, bucket, prefixOld, prefixN
gorm.Expr(sqlConcat(tx, "?", "SUBSTR(object_id, ?)")), prefixNew,
utf8.RuneCountInString(prefixOld)+1, prefixOld+"%",
utf8.RuneCountInString(prefixOld), prefixOld, sqlWhereBucket("objects", bucket))

if !isSQLite(tx) {
inner = tx.Raw("SELECT * FROM (?) as i", inner)
}
resp := tx.Model(&dbObject{}).
Where("object_id IN (?)", inner).
Delete(&dbObject{})
Expand Down
129 changes: 79 additions & 50 deletions stores/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/hex"
"errors"
"fmt"
"math"
"os"
"reflect"
"sort"
Expand All @@ -18,7 +17,6 @@ import (
rhpv2 "go.sia.tech/core/rhp/v2"
"go.sia.tech/core/types"
"go.sia.tech/renterd/api"
"go.sia.tech/renterd/hostdb"
"go.sia.tech/renterd/object"
"gorm.io/gorm"
"gorm.io/gorm/schema"
Expand Down Expand Up @@ -220,7 +218,7 @@ func TestSQLContractStore(t *testing.T) {
}

// Add an announcement.
err = ss.insertTestAnnouncement(hk, hostdb.Announcement{NetAddress: "address"})
err = ss.insertTestAnnouncement(hk, newTestHostDBAnnouncement("address"))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -511,11 +509,11 @@ func TestRenewedContract(t *testing.T) {
hk, hk2 := hks[0], hks[1]

// Add announcements.
err = ss.insertTestAnnouncement(hk, hostdb.Announcement{NetAddress: "address"})
err = ss.insertTestAnnouncement(hk, newTestHostDBAnnouncement("address"))
if err != nil {
t.Fatal(err)
}
err = ss.insertTestAnnouncement(hk2, hostdb.Announcement{NetAddress: "address2"})
err = ss.insertTestAnnouncement(hk2, newTestHostDBAnnouncement("address2"))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1008,7 +1006,7 @@ func TestSQLMetadataStore(t *testing.T) {

one := uint(1)
expectedObj := dbObject{
DBBucketID: 1,
DBBucketID: ss.DefaultBucketID(),
Health: 1,
ObjectID: objID,
Key: obj1Key,
Expand Down Expand Up @@ -1169,6 +1167,7 @@ func TestSQLMetadataStore(t *testing.T) {
slabs[i].Shards[0].Model = Model{}
slabs[i].Shards[0].Contracts[0].Model = Model{}
slabs[i].Shards[0].Contracts[0].Host.Model = Model{}
slabs[i].Shards[0].Contracts[0].Host.LastAnnouncement = time.Time{}
slabs[i].HealthValidUntil = 0
}
if !reflect.DeepEqual(slab1, expectedObjSlab1) {
Expand Down Expand Up @@ -2213,10 +2212,9 @@ func TestUpdateSlab(t *testing.T) {
t.Fatal(err)
}
var s dbSlab
if err := ss.db.Model(&dbSlab{}).
if err := ss.db.Where(&dbSlab{Key: key}).
Joins("DBContractSet").
Preload("Shards").
Where("key = ?", key).
Take(&s).
Error; err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -2265,7 +2263,7 @@ func TestRecordContractSpending(t *testing.T) {
}

// Add an announcement.
err = ss.insertTestAnnouncement(hk, hostdb.Announcement{NetAddress: "address"})
err = ss.insertTestAnnouncement(hk, newTestHostDBAnnouncement("address"))
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -3900,15 +3898,15 @@ func TestSlabCleanupTrigger(t *testing.T) {
// create objects
obj1 := dbObject{
ObjectID: "1",
DBBucketID: 1,
DBBucketID: ss.DefaultBucketID(),
Health: 1,
}
if err := ss.db.Create(&obj1).Error; err != nil {
t.Fatal(err)
}
obj2 := dbObject{
ObjectID: "2",
DBBucketID: 1,
DBBucketID: ss.DefaultBucketID(),
Health: 1,
}
if err := ss.db.Create(&obj2).Error; err != nil {
Expand Down Expand Up @@ -3981,7 +3979,7 @@ func TestSlabCleanupTrigger(t *testing.T) {
}
obj3 := dbObject{
ObjectID: "3",
DBBucketID: 1,
DBBucketID: ss.DefaultBucketID(),
Health: 1,
}
if err := ss.db.Create(&obj3).Error; err != nil {
Expand Down Expand Up @@ -4120,11 +4118,11 @@ func TestUpdateObjectReuseSlab(t *testing.T) {

// fetch the object
var dbObj dbObject
if err := ss.db.Where("db_bucket_id", 1).Take(&dbObj).Error; err != nil {
if err := ss.db.Where("db_bucket_id", ss.DefaultBucketID()).Take(&dbObj).Error; err != nil {
t.Fatal(err)
} else if dbObj.ID != 1 {
t.Fatal("unexpected id", dbObj.ID)
} else if dbObj.DBBucketID != 1 {
} else if dbObj.DBBucketID != ss.DefaultBucketID() {
t.Fatal("bucket id mismatch", dbObj.DBBucketID)
} else if dbObj.ObjectID != "1" {
t.Fatal("object id mismatch", dbObj.ObjectID)
Expand Down Expand Up @@ -4226,7 +4224,7 @@ func TestUpdateObjectReuseSlab(t *testing.T) {

// fetch the object
var dbObj2 dbObject
if err := ss.db.Where("db_bucket_id", 1).
if err := ss.db.Where("db_bucket_id", ss.DefaultBucketID()).
Where("object_id", "2").
Take(&dbObj2).Error; err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -4310,63 +4308,94 @@ func TestTypeCurrency(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

// prepare the table
if isSQLite(ss.db) {
if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil {
t.Fatal(err)
}
} else {
if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil {
t.Fatal(err)
}
}

// insert currencies in random order
if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil {
t.Fatal(err)
}

// fetch currencies and assert they're sorted
var currencies []bCurrency
if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(&currencies).Error; err != nil {
t.Fatal(err)
} else if !sort.SliceIsSorted(currencies, func(i, j int) bool {
return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0
}) {
t.Fatal("currencies not sorted", currencies)
}

// convenience variables
c0 := currencies[0]
c1 := currencies[1]
cM := currencies[2]

tests := []struct {
a types.Currency
b types.Currency
a bCurrency
b bCurrency
cmp string
}{
{
a: types.ZeroCurrency,
b: types.NewCurrency64(1),
a: c0,
b: c1,
cmp: "<",
},
{
a: types.NewCurrency64(1),
b: types.NewCurrency64(1),
a: c1,
b: c0,
cmp: ">",
},
{
a: c0,
b: c1,
cmp: "!=",
},
{
a: c1,
b: c1,
cmp: "=",
},
{
a: types.NewCurrency(math.MaxUint64, 0),
b: types.NewCurrency(0, math.MaxUint64),
a: c0,
b: cM,
cmp: "<",
},
{
a: types.NewCurrency(0, math.MaxUint64),
b: types.NewCurrency(math.MaxUint64, 0),
a: cM,
b: c0,
cmp: ">",
},
{
a: cM,
b: cM,
cmp: "=",
},
}
for _, test := range tests {
for i, test := range tests {
var result bool
err := ss.db.Raw("SELECT ? "+test.cmp+" ?", bCurrency(test.a), bCurrency(test.b)).Scan(&result).Error
if err != nil {
query := fmt.Sprintf("SELECT ? %s ?", test.cmp)
if !isSQLite(ss.db) {
query = strings.Replace(query, "?", "HEX(?)", -1)
}
if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil {
t.Fatal(err)
} else if !result {
t.Fatalf("unexpected result %v for %v %v %v", result, test.a, test.cmp, test.b)
} else if test.cmp == "<" && test.a.Cmp(test.b) >= 0 {
t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String())
} else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 {
t.Fatal("invalid result")
} else if test.cmp == ">" && test.a.Cmp(test.b) <= 0 {
} else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 {
t.Fatal("invalid result")
} else if test.cmp == "=" && test.a.Cmp(test.b) != 0 {
} else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 {
t.Fatal("invalid result")
}
}

c := func(c uint64) bCurrency {
return bCurrency(types.NewCurrency64(c))
}

var currencies []bCurrency
err := ss.db.Raw(`
WITH input(col) as
(values (?),(?),(?))
SELECT * FROM input ORDER BY col ASC
`, c(3), c(1), c(2)).Scan(&currencies).Error
if err != nil {
t.Fatal(err)
} else if !sort.SliceIsSorted(currencies, func(i, j int) bool {
return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0
}) {
t.Fatal("currencies not sorted", currencies)
}
}
4 changes: 1 addition & 3 deletions stores/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -605,9 +605,7 @@ func (s *SQLStore) findPeriods(table string, dst interface{}, start time.Time, n
WHERE ?
GROUP BY
p.period_start
ORDER BY
p.period_start ASC
) i ON %s.id = i.id
) i ON %s.id = i.id ORDER BY Period ASC
`, table, table, table, table),
unixTimeMS(start),
interval.Milliseconds(),
Expand Down
Loading

0 comments on commit 9edb677

Please sign in to comment.