From ea173eb3cc3beac3a1170260995c5f5d59faf99d Mon Sep 17 00:00:00 2001 From: Chris Schinnerl Date: Wed, 14 Feb 2024 16:08:48 +0100 Subject: [PATCH] stores: binary currency type --- stores/metadata_test.go | 61 +++++++++++++++++++++++++++++++++++++++++ stores/types.go | 28 +++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/stores/metadata_test.go b/stores/metadata_test.go index 96b06c4ec..0b0a696a5 100644 --- a/stores/metadata_test.go +++ b/stores/metadata_test.go @@ -6,8 +6,10 @@ import ( "encoding/hex" "errors" "fmt" + "math" "os" "reflect" + "sort" "strings" "testing" "time" @@ -4284,3 +4286,62 @@ func TestUpdateObjectReuseSlab(t *testing.T) { } } } + +func TestTypeCurrency(t *testing.T) { + ss := newTestSQLStore(t, defaultTestSQLStoreConfig) + defer ss.Close() + + tests := []struct { + a types.Currency + b types.Currency + cmp string + }{ + { + a: types.ZeroCurrency, + b: types.NewCurrency64(1), + cmp: "<", + }, + { + a: types.NewCurrency64(1), + b: types.NewCurrency64(1), + cmp: "=", + }, + { + a: types.NewCurrency(0, math.MaxUint64), + b: types.NewCurrency(math.MaxUint64, 0), + cmp: "<", + }, + { + a: types.NewCurrency(math.MaxUint64, 0), + b: types.NewCurrency(0, math.MaxUint64), + cmp: ">", + }, + } + for _, test := range tests { + var result bool + err := ss.db.Raw("SELECT ? "+test.cmp+" ?", bCurrency(test.a), bCurrency(test.b)).Scan(&result).Error + if err != nil { + t.Fatal(err) + } else if !result { + t.Fatal("unexpected result", result) + } + } + + c := func(c uint64) bCurrency { + return bCurrency(types.NewCurrency64(c)) + } + + var currencies []bCurrency + err := ss.db.Raw(` +WITH input(col) as +(values (?),(?),(?)) +SELECT * FROM input ORDER BY col ASC +`, c(3), c(1), c(2)).Scan(¤cies).Error + if err != nil { + t.Fatal(err) + } else if !sort.SliceIsSorted(currencies, func(i, j int) bool { + return types.Currency(currencies[i]).Cmp(types.Currency(currencies[j])) < 0 + }) { + t.Fatal("currencies not sorted", currencies) + } +} diff --git a/stores/types.go b/stores/types.go index 6b74f7563..9a7c72009 100644 --- a/stores/types.go +++ b/stores/types.go @@ -2,6 +2,7 @@ package stores import ( "database/sql/driver" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -25,6 +26,7 @@ type ( unixTimeMS time.Time datetime time.Time currency types.Currency + bCurrency types.Currency fileContractID types.FileContractID hash256 types.Hash256 publicKey types.PublicKey @@ -338,3 +340,29 @@ func (u *unsigned64) Scan(value interface{}) error { func (u unsigned64) Value() (driver.Value, error) { return int64(u), nil } + +func (bCurrency) GormDataType() string { + return "bytes" +} + +// Scan implements the sql.Scanner interface. +func (sc *bCurrency) Scan(src any) error { + buf, ok := src.([]byte) + if !ok { + return fmt.Errorf("cannot scan %T to Currency", src) + } else if len(buf) != 16 { + return fmt.Errorf("cannot scan %d bytes to Currency", len(buf)) + } + + sc.Lo = binary.LittleEndian.Uint64(buf[:8]) + sc.Hi = binary.LittleEndian.Uint64(buf[8:]) + return nil +} + +// Value implements the driver.Valuer interface. +func (sc bCurrency) Value() (driver.Value, error) { + buf := make([]byte, 16) + binary.LittleEndian.PutUint64(buf[:8], sc.Lo) + binary.LittleEndian.PutUint64(buf[8:], sc.Hi) + return buf, nil +}