diff --git a/stores/chain_test.go b/stores/chain_test.go index 615e27098..08b959ca9 100644 --- a/stores/chain_test.go +++ b/stores/chain_test.go @@ -3,6 +3,7 @@ package stores import ( "context" "errors" + "fmt" "strings" "testing" "time" @@ -31,7 +32,7 @@ func TestProcessChainUpdate(t *testing.T) { } fcid := fcids[0] - // assert contract state returns the correct state + // check current contract state var state api.ContractState if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) (err error) { state, err = tx.ContractState(fcid) @@ -49,22 +50,14 @@ func TestProcessChainUpdate(t *testing.T) { t.Fatalf("unexpected height %v", curr.Height) } - // assert update chain index is successful + // run chain update if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { - return tx.UpdateChainIndex(types.ChainIndex{Height: 1}) - }); err != nil { - t.Fatal("unexpected error", err) - } - - // check updated index - if curr, err := ss.ChainIndex(context.Background()); err != nil { - t.Fatal(err) - } else if curr.Height != 1 { - t.Fatalf("unexpected height %v", curr.Height) - } + // update chain index + if err := tx.UpdateChainIndex(types.ChainIndex{Height: 1}); err != nil { + return err + } - // assert update contract is successful - if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + // update contract if err := tx.UpdateContract(fcid, 1, 2, 3); err != nil { return err } else if err := tx.UpdateContractState(fcid, api.ContractStateActive); err != nil { @@ -78,7 +71,14 @@ func TestProcessChainUpdate(t *testing.T) { t.Fatal("unexpected error", err) } - // assert contract was updated successfully + // assert updated index + if curr, err := ss.ChainIndex(context.Background()); err != nil { + t.Fatal(err) + } else if curr.Height != 1 { + t.Fatalf("unexpected height %v", curr.Height) + } + + // assert updated contract var we uint64 if c, err := ss.Contract(context.Background(), fcid); err != nil { t.Fatal("unexpected error", err) @@ -210,31 +210,41 @@ func TestProcessChainUpdate(t *testing.T) { panic("oh no") } + // assert we can revert spent outputs + now := time.Now().Round(time.Millisecond) + var ses []types.StateElement if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { index3 := types.ChainIndex{Height: 3} index4 := types.ChainIndex{Height: 4} created := []types.SiacoinElement{ { - StateElement: types.StateElement{}, - SiacoinOutput: types.SiacoinOutput{}, + StateElement: types.StateElement{ + ID: types.Hash256{1}, + LeafIndex: 1, + MerkleProof: []types.Hash256{{1}, {2}}, + }, + SiacoinOutput: types.SiacoinOutput{ + Address: types.Address{1}, + Value: types.NewCurrency64(1), + }, MaturityHeight: 100, }, } - events := []wallet.Event{ - { - Type: wallet.EventTypeV2Transaction, - Data: wallet.EventV2Transaction{}, - }, + + // try spending non-existent output + err = tx.WalletApplyIndex(index4, nil, created, nil, time.Now()) + if !errors.Is(err, sql.ErrOutputNotFound) { + return fmt.Errorf("expected ErrOutputNotFound, instead got: %w", err) } - // create some elements - err := tx.WalletApplyIndex(index3, created, nil, events, time.Now()) + // create the elements + err = tx.WalletApplyIndex(index3, created, nil, nil, time.Now()) if err != nil { return err } // spend them - err = tx.WalletApplyIndex(index4, nil, created, events, time.Now()) + err = tx.WalletApplyIndex(index4, nil, created, nil, time.Now()) if err != nil { return err } @@ -250,8 +260,109 @@ func TestProcessChainUpdate(t *testing.T) { if err != nil { return err } - return nil + + // prepare event + events := []wallet.Event{ + { + ID: types.Hash256{1}, + Index: types.ChainIndex{Height: 5}, + Type: wallet.EventTypeV2Transaction, + Data: wallet.EventV2Transaction{}, + Timestamp: now, + }, + } + + // add them + err = tx.WalletApplyIndex(types.ChainIndex{Height: 5}, nil, nil, events, time.Now()) + if err != nil { + return err + } + + // fetch elements + ses, err = tx.WalletStateElements() + return err + }); err != nil { + t.Fatal("unexpected error", err) + } + + // assert wallet state elements + if len(ses) != 1 { + t.Fatal("unexpected number of state elements", len(ses)) + } else if se := ses[0]; se.ID != (types.Hash256{1}) { + t.Fatal("unexpected state element id", se.ID) + } else if se.LeafIndex != 1 { + t.Fatal("unexpected state element leaf index", se.LeafIndex) + } else if len(se.MerkleProof) != 2 { + t.Fatal("unexpected state element merkle proof", len(se.MerkleProof)) + } else if se.MerkleProof[0] != (types.Hash256{1}) { + t.Fatal("unexpected state element merkle proof[0]", se.MerkleProof[0]) + } else if se.MerkleProof[1] != (types.Hash256{2}) { + t.Fatal("unexpected state element merkle proof[1]", se.MerkleProof[1]) + } + + // update state elements + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + ses[0].LeafIndex = 2 + ses[0].MerkleProof = []types.Hash256{{3}, {4}} + err := tx.UpdateWalletStateElements(ses) + if err != nil { + return err + } + + ses, err = tx.WalletStateElements() + return err }); err != nil { t.Fatal("unexpected error", err) } + + // assert wallet state elements + if len(ses) != 1 { + t.Fatal("unexpected number of state elements", len(ses)) + } else if se := ses[0]; se.LeafIndex != 2 { + t.Fatal("unexpected state element leaf index", se.LeafIndex) + } else if len(se.MerkleProof) != 2 { + t.Fatal("unexpected state element merkle proof length", len(se.MerkleProof)) + } else if se.MerkleProof[0] != (types.Hash256{3}) { + t.Fatal("unexpected state element merkle proof[0]", se.MerkleProof[0]) + } else if se.MerkleProof[1] != (types.Hash256{4}) { + t.Fatal("unexpected state element merkle proof[1]", se.MerkleProof[1]) + } + + // assert events + events, err := ss.WalletEvents(0, -1) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatal("unexpected number of events", len(events)) + } else if events[0].Index.Height != 5 { + t.Fatal("unexpected event index height", events[0].Index.Height, events[0]) + } else if events[0].Timestamp != now { + t.Fatal("unexpected event timestamp", events[0].Timestamp, now) + } + + // revert the index and assert the event got removed + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.WalletRevertIndex(types.ChainIndex{Height: 5}, nil, nil, time.Now()) + }); err != nil { + t.Fatal("expected error") + } + events, err = ss.WalletEvents(0, -1) + if err != nil { + t.Fatal(err) + } else if len(events) != 0 { + t.Fatal("unexpected number of events", len(events)) + } + + // assert we can't delete non-existing outputs when reverting + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.WalletRevertIndex(types.ChainIndex{Height: 5}, []types.SiacoinElement{ + { + StateElement: types.StateElement{ID: types.Hash256{2}}, + SiacoinOutput: types.SiacoinOutput{}, + MaturityHeight: 100, + }, + }, nil, time.Now()) + }); !errors.Is(err, sql.ErrOutputNotFound) { + t.Fatal("expected ErrOutputNotFound", err) + } } diff --git a/stores/sql/chain.go b/stores/sql/chain.go index c9c991d4e..625ce5d7e 100644 --- a/stores/sql/chain.go +++ b/stores/sql/chain.go @@ -12,6 +12,10 @@ import ( "go.uber.org/zap" ) +var ( + ErrOutputNotFound = errors.New("output not found") +) + var contractTables = []string{ "contracts", "archived_contracts", diff --git a/stores/sql/mysql/chain.go b/stores/sql/mysql/chain.go index 4e5720c9e..73cd05482 100644 --- a/stores/sql/mysql/chain.go +++ b/stores/sql/mysql/chain.go @@ -46,7 +46,7 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [ } else if n, err := res.RowsAffected(); err != nil { return fmt.Errorf("failed to get rows affected: %w", err) } else if n != 1 { - return fmt.Errorf("failed to delete spent output: no rows affected") + return fmt.Errorf("failed to delete spent output: %w", ssql.ErrOutputNotFound) } } } @@ -134,7 +134,7 @@ func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspen } else if n, err := res.RowsAffected(); err != nil { return fmt.Errorf("failed to get rows affected: %w", err) } else if n != 1 { - return fmt.Errorf("failed to delete removed output: no rows affected") + return fmt.Errorf("failed to delete removed output: %w", ssql.ErrOutputNotFound) } } } diff --git a/stores/sql/sqlite/chain.go b/stores/sql/sqlite/chain.go index 0ec937b73..95d16434c 100644 --- a/stores/sql/sqlite/chain.go +++ b/stores/sql/sqlite/chain.go @@ -49,7 +49,7 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [ } else if n, err := res.RowsAffected(); err != nil { return fmt.Errorf("failed to get rows affected: %w", err) } else if n != 1 { - return fmt.Errorf("failed to delete spent output: no rows affected") + return fmt.Errorf("failed to delete spent output: %w", ssql.ErrOutputNotFound) } } } @@ -89,6 +89,12 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [ // insert new events for _, e := range events { + if e.Index != index { + return fmt.Errorf("event index %v doesn't match index being applied %v", e.Index, index) + } else if e.ID == (types.Hash256{}) { + return fmt.Errorf("event id is required") + } + c.l.Debugw(fmt.Sprintf("create event %v", e.ID), "height", index.Height, "block_id", index.ID) data, err := json.Marshal(e.Data) if err != nil { @@ -137,7 +143,7 @@ func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspen } else if n, err := res.RowsAffected(); err != nil { return fmt.Errorf("failed to get rows affected: %w", err) } else if n != 1 { - return fmt.Errorf("failed to delete removed output: no rows affected") + return fmt.Errorf("failed to delete removed output: %w", ssql.ErrOutputNotFound) } } }