diff --git a/stores/metadata_test.go b/stores/metadata_test.go index 205812287a..b39816c1f3 100644 --- a/stores/metadata_test.go +++ b/stores/metadata_test.go @@ -8,7 +8,6 @@ import ( "fmt" "os" "reflect" - "sort" "strings" "testing" "time" @@ -4308,99 +4307,3 @@ func TestUpdateObjectReuseSlab(t *testing.T) { } } } - -func TestTypeCurrency(t *testing.T) { - ss := newTestSQLStore(t, defaultTestSQLStoreConfig) - defer ss.Close() - - // prepare the table - if isSQLite(ss.db) { - if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil { - t.Fatal(err) - } - } else { - if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil { - t.Fatal(err) - } - } - - // insert currencies in random order - if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil { - t.Fatal(err) - } - - // fetch currencies and assert they're sorted - var currencies []bCurrency - if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(¤cies).Error; err != nil { - t.Fatal(err) - } else if !sort.SliceIsSorted(currencies, func(i, j int) bool { - return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0 - }) { - t.Fatal("currencies not sorted", currencies) - } - - // convenience variables - c0 := currencies[0] - c1 := currencies[1] - cM := currencies[2] - - tests := []struct { - a bCurrency - b bCurrency - cmp string - }{ - { - a: c0, - b: c1, - cmp: "<", - }, - { - a: c1, - b: c0, - cmp: ">", - }, - { - a: c0, - b: c1, - cmp: "!=", - }, - { - a: c1, - b: c1, - cmp: "=", - }, - { - a: c0, - b: cM, - cmp: "<", - }, - { - a: cM, - b: c0, - cmp: ">", - }, - { - a: cM, - b: cM, - cmp: "=", - }, - } - for i, test := range tests { - var result bool - query := fmt.Sprintf("SELECT ? %s ?", test.cmp) - if !isSQLite(ss.db) { - query = strings.Replace(query, "?", "HEX(?)", -1) - } - if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil { - t.Fatal(err) - } else if !result { - t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String()) - } else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 { - t.Fatal("invalid result") - } else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 { - t.Fatal("invalid result") - } else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 { - t.Fatal("invalid result") - } - } -} diff --git a/stores/types.go b/stores/types.go index 42a8d29e44..4bfe783e2b 100644 --- a/stores/types.go +++ b/stores/types.go @@ -17,6 +17,7 @@ import ( ) const ( + proofHashSize = 32 secretKeySize = 32 ) @@ -35,6 +36,7 @@ type ( balance big.Int unsigned64 uint64 // used for storing large uint64 values in sqlite secretKey []byte + merkleProof []types.Hash256 ) // GormDataType implements gorm.GormDataTypeInterface. @@ -341,6 +343,42 @@ func (u unsigned64) Value() (driver.Value, error) { return int64(u), nil } +// GormDataType implements gorm.GormDataTypeInterface. +func (mp *merkleProof) GormDataType() string { + return "bytes" +} + +// Scan scans value into mp, implements sql.Scanner interface. +func (mp *merkleProof) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return errors.New(fmt.Sprint("failed to unmarshal merkleProof value:", value)) + } else if len(bytes)%proofHashSize != 0 { + return fmt.Errorf("failed to unmarshal merkleProof value due to invalid number of bytes %v is not a multiple of %v: %v", len(bytes), proofHashSize, value) + } else if len(bytes) == 0 { + return errors.New("failed to unmarshal merkleProof value, no bytes found") + } + + n := len(bytes) / proofHashSize + hashes := make([]types.Hash256, n) + for i := 0; i < n; i++ { + copy(hashes[i][:], bytes[:proofHashSize]) + bytes = bytes[proofHashSize:] + } + *mp = hashes + return nil +} + +// Value returns a merkle proof value, implements driver.Valuer interface. +func (mp merkleProof) Value() (driver.Value, error) { + var i int + out := make([]byte, len(mp)*proofHashSize) + for _, ph := range mp { + i += copy(out[i:], ph[:]) + } + return out, nil +} + func (bCurrency) GormDataType() string { return "bytes" } diff --git a/stores/types_test.go b/stores/types_test.go new file mode 100644 index 0000000000..837e25e1ed --- /dev/null +++ b/stores/types_test.go @@ -0,0 +1,147 @@ +package stores + +import ( + "fmt" + "sort" + "strings" + "testing" + + "go.sia.tech/core/types" +) + +func TestTypeCurrency(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + defer ss.Close() + + // prepare the table + if isSQLite(ss.db) { + if err := ss.db.Exec("CREATE TABLE currencies (id INTEGER PRIMARY KEY AUTOINCREMENT,c BLOB);").Error; err != nil { + t.Fatal(err) + } + } else { + if err := ss.db.Exec("CREATE TABLE currencies (id INT AUTO_INCREMENT PRIMARY KEY, c BLOB);").Error; err != nil { + t.Fatal(err) + } + } + + // insert currencies in random order + if err := ss.db.Exec("INSERT INTO currencies (c) VALUES (?),(?),(?);", bCurrency(types.MaxCurrency), bCurrency(types.NewCurrency64(1)), bCurrency(types.ZeroCurrency)).Error; err != nil { + t.Fatal(err) + } + + // fetch currencies and assert they're sorted + var currencies []bCurrency + if err := ss.db.Raw(`SELECT c FROM currencies ORDER BY c ASC`).Scan(¤cies).Error; err != nil { + t.Fatal(err) + } else if !sort.SliceIsSorted(currencies, func(i, j int) bool { + return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0 + }) { + t.Fatal("currencies not sorted", currencies) + } + + // convenience variables + c0 := currencies[0] + c1 := currencies[1] + cM := currencies[2] + + tests := []struct { + a bCurrency + b bCurrency + cmp string + }{ + { + a: c0, + b: c1, + cmp: "<", + }, + { + a: c1, + b: c0, + cmp: ">", + }, + { + a: c0, + b: c1, + cmp: "!=", + }, + { + a: c1, + b: c1, + cmp: "=", + }, + { + a: c0, + b: cM, + cmp: "<", + }, + { + a: cM, + b: c0, + cmp: ">", + }, + { + a: cM, + b: cM, + cmp: "=", + }, + } + for i, test := range tests { + var result bool + query := fmt.Sprintf("SELECT ? %s ?", test.cmp) + if !isSQLite(ss.db) { + query = strings.Replace(query, "?", "HEX(?)", -1) + } + if err := ss.db.Raw(query, test.a, test.b).Scan(&result).Error; err != nil { + t.Fatal(err) + } else if !result { + t.Errorf("unexpected result in case %d/%d: expected %v %s %v to be true", i+1, len(tests), types.Currency(test.a).String(), test.cmp, types.Currency(test.b).String()) + } else if test.cmp == "<" && types.Currency(test.a).Cmp(types.Currency(test.b)) >= 0 { + t.Fatal("invalid result") + } else if test.cmp == ">" && types.Currency(test.a).Cmp(types.Currency(test.b)) <= 0 { + t.Fatal("invalid result") + } else if test.cmp == "=" && types.Currency(test.a).Cmp(types.Currency(test.b)) != 0 { + t.Fatal("invalid result") + } + } +} + +func TestTypeMerkleProof(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + defer ss.Close() + + // prepare the table + if isSQLite(ss.db) { + if err := ss.db.Exec("CREATE TABLE merkle_proofs (id INTEGER PRIMARY KEY AUTOINCREMENT,merkle_proof BLOB);").Error; err != nil { + t.Fatal(err) + } + } else { + ss.db.Exec("DROP TABLE IF EXISTS merkle_proofs;") + if err := ss.db.Exec("CREATE TABLE merkle_proofs (id INT AUTO_INCREMENT PRIMARY KEY, merkle_proof BLOB);").Error; err != nil { + t.Fatal(err) + } + } + + // insert merkle proof + if err := ss.db.Exec("INSERT INTO merkle_proofs (id, merkle_proof) VALUES (1,?),(2,?);", merkleProof([]types.Hash256{}), merkleProof([]types.Hash256{{2}, {1}, {3}})).Error; err != nil { + t.Fatal(err) + } + + // fetch invalid proof + var proofs []merkleProof + if err := ss.db. + Raw(`SELECT merkle_proof FROM merkle_proofs WHERE id=1`). + Take(&proofs). + Error; err == nil || !strings.Contains(err.Error(), "no bytes found") { + t.Fatalf("expected error 'no bytes found', got '%v'", err) + } + + // fetch valid proof + if err := ss.db. + Raw(`SELECT merkle_proof FROM merkle_proofs WHERE id=2`). + Take(&proofs). + Error; err != nil { + t.Fatalf("unexpected error '%v'", err) + } else if proofs[0][0] != (types.Hash256{2}) || proofs[0][1] != (types.Hash256{1}) || proofs[0][2] != (types.Hash256{3}) { + t.Fatalf("unexpected proof %+v", proofs[0]) + } +}