Skip to content

Commit

Permalink
stores: extend TestProcessChainUpdate
Browse files Browse the repository at this point in the history
  • Loading branch information
peterjan authored and ChrisSchinnerl committed Sep 26, 2024
1 parent ccf2c9e commit fac7b60
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 31 deletions.
165 changes: 138 additions & 27 deletions stores/chain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package stores
import (
"context"
"errors"
"fmt"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -31,7 +32,7 @@ func TestProcessChainUpdate(t *testing.T) {
}
fcid := fcids[0]

// assert contract state returns the correct state
// check current contract state
var state api.ContractState
if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) (err error) {
state, err = tx.ContractState(fcid)
Expand All @@ -49,22 +50,14 @@ func TestProcessChainUpdate(t *testing.T) {
t.Fatalf("unexpected height %v", curr.Height)
}

// assert update chain index is successful
// run chain update
if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error {
return tx.UpdateChainIndex(types.ChainIndex{Height: 1})
}); err != nil {
t.Fatal("unexpected error", err)
}

// check updated index
if curr, err := ss.ChainIndex(context.Background()); err != nil {
t.Fatal(err)
} else if curr.Height != 1 {
t.Fatalf("unexpected height %v", curr.Height)
}
// update chain index
if err := tx.UpdateChainIndex(types.ChainIndex{Height: 1}); err != nil {
return err
}

// assert update contract is successful
if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error {
// update contract
if err := tx.UpdateContract(fcid, 1, 2, 3); err != nil {
return err
} else if err := tx.UpdateContractState(fcid, api.ContractStateActive); err != nil {
Expand All @@ -78,7 +71,14 @@ func TestProcessChainUpdate(t *testing.T) {
t.Fatal("unexpected error", err)
}

// assert contract was updated successfully
// assert updated index
if curr, err := ss.ChainIndex(context.Background()); err != nil {
t.Fatal(err)
} else if curr.Height != 1 {
t.Fatalf("unexpected height %v", curr.Height)
}

// assert updated contract
var we uint64
if c, err := ss.Contract(context.Background(), fcid); err != nil {
t.Fatal("unexpected error", err)
Expand Down Expand Up @@ -210,31 +210,41 @@ func TestProcessChainUpdate(t *testing.T) {
panic("oh no")
}

// assert we can revert spent outputs
now := time.Now().Round(time.Millisecond)
var ses []types.StateElement
if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error {
index3 := types.ChainIndex{Height: 3}
index4 := types.ChainIndex{Height: 4}
created := []types.SiacoinElement{
{
StateElement: types.StateElement{},
SiacoinOutput: types.SiacoinOutput{},
StateElement: types.StateElement{
ID: types.Hash256{1},
LeafIndex: 1,
MerkleProof: []types.Hash256{{1}, {2}},
},
SiacoinOutput: types.SiacoinOutput{
Address: types.Address{1},
Value: types.NewCurrency64(1),
},
MaturityHeight: 100,
},
}
events := []wallet.Event{
{
Type: wallet.EventTypeV2Transaction,
Data: wallet.EventV2Transaction{},
},

// try spending non-existent output
err = tx.WalletApplyIndex(index4, nil, created, nil, time.Now())
if !errors.Is(err, sql.ErrOutputNotFound) {
return fmt.Errorf("expected ErrOutputNotFound, instead got: %w", err)
}

// create some elements
err := tx.WalletApplyIndex(index3, created, nil, events, time.Now())
// create the elements
err = tx.WalletApplyIndex(index3, created, nil, nil, time.Now())
if err != nil {
return err
}

// spend them
err = tx.WalletApplyIndex(index4, nil, created, events, time.Now())
err = tx.WalletApplyIndex(index4, nil, created, nil, time.Now())
if err != nil {
return err
}
Expand All @@ -250,8 +260,109 @@ func TestProcessChainUpdate(t *testing.T) {
if err != nil {
return err
}
return nil

// prepare event
events := []wallet.Event{
{
ID: types.Hash256{1},
Index: types.ChainIndex{Height: 5},
Type: wallet.EventTypeV2Transaction,
Data: wallet.EventV2Transaction{},
Timestamp: now,
},
}

// add them
err = tx.WalletApplyIndex(types.ChainIndex{Height: 5}, nil, nil, events, time.Now())
if err != nil {
return err
}

// fetch elements
ses, err = tx.WalletStateElements()
return err
}); err != nil {
t.Fatal("unexpected error", err)
}

// assert wallet state elements
if len(ses) != 1 {
t.Fatal("unexpected number of state elements", len(ses))
} else if se := ses[0]; se.ID != (types.Hash256{1}) {
t.Fatal("unexpected state element id", se.ID)
} else if se.LeafIndex != 1 {
t.Fatal("unexpected state element leaf index", se.LeafIndex)
} else if len(se.MerkleProof) != 2 {
t.Fatal("unexpected state element merkle proof", len(se.MerkleProof))
} else if se.MerkleProof[0] != (types.Hash256{1}) {
t.Fatal("unexpected state element merkle proof[0]", se.MerkleProof[0])
} else if se.MerkleProof[1] != (types.Hash256{2}) {
t.Fatal("unexpected state element merkle proof[1]", se.MerkleProof[1])
}

// update state elements
if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error {
ses[0].LeafIndex = 2
ses[0].MerkleProof = []types.Hash256{{3}, {4}}
err := tx.UpdateWalletStateElements(ses)
if err != nil {
return err
}

ses, err = tx.WalletStateElements()
return err
}); err != nil {
t.Fatal("unexpected error", err)
}

// assert wallet state elements
if len(ses) != 1 {
t.Fatal("unexpected number of state elements", len(ses))
} else if se := ses[0]; se.LeafIndex != 2 {
t.Fatal("unexpected state element leaf index", se.LeafIndex)
} else if len(se.MerkleProof) != 2 {
t.Fatal("unexpected state element merkle proof length", len(se.MerkleProof))
} else if se.MerkleProof[0] != (types.Hash256{3}) {
t.Fatal("unexpected state element merkle proof[0]", se.MerkleProof[0])
} else if se.MerkleProof[1] != (types.Hash256{4}) {
t.Fatal("unexpected state element merkle proof[1]", se.MerkleProof[1])
}

// assert events
events, err := ss.WalletEvents(0, -1)
if err != nil {
t.Fatal(err)
} else if len(events) != 1 {
t.Fatal("unexpected number of events", len(events))
} else if events[0].Index.Height != 5 {
t.Fatal("unexpected event index height", events[0].Index.Height, events[0])
} else if events[0].Timestamp != now {
t.Fatal("unexpected event timestamp", events[0].Timestamp, now)
}

// revert the index and assert the event got removed
if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error {
return tx.WalletRevertIndex(types.ChainIndex{Height: 5}, nil, nil, time.Now())
}); err != nil {
t.Fatal("expected error")
}
events, err = ss.WalletEvents(0, -1)
if err != nil {
t.Fatal(err)
} else if len(events) != 0 {
t.Fatal("unexpected number of events", len(events))
}

// assert we can't delete non-existing outputs when reverting
if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error {
return tx.WalletRevertIndex(types.ChainIndex{Height: 5}, []types.SiacoinElement{
{
StateElement: types.StateElement{ID: types.Hash256{2}},
SiacoinOutput: types.SiacoinOutput{},
MaturityHeight: 100,
},
}, nil, time.Now())
}); !errors.Is(err, sql.ErrOutputNotFound) {
t.Fatal("expected ErrOutputNotFound", err)
}
}
4 changes: 4 additions & 0 deletions stores/sql/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ import (
"go.uber.org/zap"
)

var (
ErrOutputNotFound = errors.New("output not found")
)

var contractTables = []string{
"contracts",
"archived_contracts",
Expand Down
4 changes: 2 additions & 2 deletions stores/sql/mysql/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [
} else if n, err := res.RowsAffected(); err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
} else if n != 1 {
return fmt.Errorf("failed to delete spent output: no rows affected")
return fmt.Errorf("failed to delete spent output: %w", ssql.ErrOutputNotFound)
}
}
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspen
} else if n, err := res.RowsAffected(); err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
} else if n != 1 {
return fmt.Errorf("failed to delete removed output: no rows affected")
return fmt.Errorf("failed to delete removed output: %w", ssql.ErrOutputNotFound)
}
}
}
Expand Down
10 changes: 8 additions & 2 deletions stores/sql/sqlite/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [
} else if n, err := res.RowsAffected(); err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
} else if n != 1 {
return fmt.Errorf("failed to delete spent output: no rows affected")
return fmt.Errorf("failed to delete spent output: %w", ssql.ErrOutputNotFound)
}
}
}
Expand Down Expand Up @@ -89,6 +89,12 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [

// insert new events
for _, e := range events {
if e.Index != index {
return fmt.Errorf("event index %v doesn't match index being applied %v", e.Index, index)
} else if e.ID == (types.Hash256{}) {
return fmt.Errorf("event id is required")
}

c.l.Debugw(fmt.Sprintf("create event %v", e.ID), "height", index.Height, "block_id", index.ID)
data, err := json.Marshal(e.Data)
if err != nil {
Expand Down Expand Up @@ -137,7 +143,7 @@ func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspen
} else if n, err := res.RowsAffected(); err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
} else if n != 1 {
return fmt.Errorf("failed to delete removed output: no rows affected")
return fmt.Errorf("failed to delete removed output: %w", ssql.ErrOutputNotFound)
}
}
}
Expand Down

0 comments on commit fac7b60

Please sign in to comment.