Skip to content

Commit

Permalink
Merge pull request #964 from SiaFoundation/chris/binary-currency
Browse files Browse the repository at this point in the history
New binary currency type for SQL
  • Loading branch information
ChrisSchinnerl authored Feb 16, 2024
2 parents 73fd775 + ea173eb commit e19ff56
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
61 changes: 61 additions & 0 deletions stores/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"encoding/hex"
"errors"
"fmt"
"math"
"os"
"reflect"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -4300,3 +4302,62 @@ func TestUpdateObjectReuseSlab(t *testing.T) {
}
}
}

func TestTypeCurrency(t *testing.T) {
ss := newTestSQLStore(t, defaultTestSQLStoreConfig)
defer ss.Close()

tests := []struct {
a types.Currency
b types.Currency
cmp string
}{
{
a: types.ZeroCurrency,
b: types.NewCurrency64(1),
cmp: "<",
},
{
a: types.NewCurrency64(1),
b: types.NewCurrency64(1),
cmp: "=",
},
{
a: types.NewCurrency(0, math.MaxUint64),
b: types.NewCurrency(math.MaxUint64, 0),
cmp: "<",
},
{
a: types.NewCurrency(math.MaxUint64, 0),
b: types.NewCurrency(0, math.MaxUint64),
cmp: ">",
},
}
for _, test := range tests {
var result bool
err := ss.db.Raw("SELECT ? "+test.cmp+" ?", bCurrency(test.a), bCurrency(test.b)).Scan(&result).Error
if err != nil {
t.Fatal(err)
} else if !result {
t.Fatal("unexpected result", result)
}
}

c := func(c uint64) bCurrency {
return bCurrency(types.NewCurrency64(c))
}

var currencies []bCurrency
err := ss.db.Raw(`
WITH input(col) as
(values (?),(?),(?))
SELECT * FROM input ORDER BY col ASC
`, c(3), c(1), c(2)).Scan(&currencies).Error
if err != nil {
t.Fatal(err)
} else if !sort.SliceIsSorted(currencies, func(i, j int) bool {
return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0
}) {
t.Fatal("currencies not sorted", currencies)
}
}
28 changes: 28 additions & 0 deletions stores/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package stores

import (
"database/sql/driver"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
Expand All @@ -25,6 +26,7 @@ type (
unixTimeMS time.Time
datetime time.Time
currency types.Currency
bCurrency types.Currency
fileContractID types.FileContractID
hash256 types.Hash256
publicKey types.PublicKey
Expand Down Expand Up @@ -338,3 +340,29 @@ func (u *unsigned64) Scan(value interface{}) error {
func (u unsigned64) Value() (driver.Value, error) {
return int64(u), nil
}

func (bCurrency) GormDataType() string {
return "bytes"
}

// Scan implements the sql.Scanner interface.
func (sc *bCurrency) Scan(src any) error {
buf, ok := src.([]byte)
if !ok {
return fmt.Errorf("cannot scan %T to Currency", src)
} else if len(buf) != 16 {
return fmt.Errorf("cannot scan %d bytes to Currency", len(buf))
}

sc.Lo = binary.LittleEndian.Uint64(buf[:8])
sc.Hi = binary.LittleEndian.Uint64(buf[8:])
return nil
}

// Value implements the driver.Valuer interface.
func (sc bCurrency) Value() (driver.Value, error) {
buf := make([]byte, 16)
binary.LittleEndian.PutUint64(buf[:8], sc.Lo)
binary.LittleEndian.PutUint64(buf[8:], sc.Hi)
return buf, nil
}

0 comments on commit e19ff56

Please sign in to comment.