diff --git a/stores/chain_test.go b/stores/chain_test.go index 08b959ca90..3770887733 100644 --- a/stores/chain_test.go +++ b/stores/chain_test.go @@ -365,4 +365,19 @@ func TestProcessChainUpdate(t *testing.T) { }); !errors.Is(err, sql.ErrOutputNotFound) { t.Fatal("expected ErrOutputNotFound", err) } + + // assert we can't apply an index and pass events with mismatching index + if err := ss.ProcessChainUpdate(context.Background(), func(tx sql.ChainUpdateTx) error { + return tx.WalletApplyIndex(types.ChainIndex{Height: 5}, nil, nil, []wallet.Event{ + { + ID: types.Hash256{1}, + Index: types.ChainIndex{Height: 6}, + Type: wallet.EventTypeV2Transaction, + Data: wallet.EventV2Transaction{}, + Timestamp: now, + }, + }, now) + }); !errors.Is(err, sql.ErrIndexMissmatch) { + t.Fatal("expected ErrIndexMissmatch", err) + } } diff --git a/stores/sql/chain.go b/stores/sql/chain.go index 625ce5d7ea..bf697dce52 100644 --- a/stores/sql/chain.go +++ b/stores/sql/chain.go @@ -13,6 +13,7 @@ import ( ) var ( + ErrIndexMissmatch = errors.New("index missmatch") ErrOutputNotFound = errors.New("output not found") ) diff --git a/stores/sql/mysql/chain.go b/stores/sql/mysql/chain.go index 73cd05482a..5cfe536068 100644 --- a/stores/sql/mysql/chain.go +++ b/stores/sql/mysql/chain.go @@ -44,7 +44,7 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [ if res, err := deleteSpentStmt.Exec(c.ctx, ssql.Hash256(e.ID)); err != nil { return fmt.Errorf("failed to delete spent output: %w", err) } else if n, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) + return fmt.Errorf("failed to delete spent output: %w", err) } else if n != 1 { return fmt.Errorf("failed to delete spent output: %w", ssql.ErrOutputNotFound) } @@ -86,6 +86,14 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [ // insert new events for _, e := range events { + if e.Index != index { + return fmt.Errorf("%w, event index %v != applied index %v", ssql.ErrIndexMissmatch, e.Index, index) + } else if e.ID == (types.Hash256{}) { + return fmt.Errorf("event id is required") + } else if e.Timestamp.IsZero() { + return fmt.Errorf("event timestamp is required") + } + c.l.Debugw(fmt.Sprintf("create event %v", e.ID), "height", index.Height, "block_id", index.ID) data, err := json.Marshal(e.Data) if err != nil { @@ -132,7 +140,7 @@ func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspen if res, err := deleteRemovedStmt.Exec(c.ctx, ssql.Hash256(e.ID)); err != nil { return fmt.Errorf("failed to delete removed output: %w", err) } else if n, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) + return fmt.Errorf("failed to delete removed output: %w", err) } else if n != 1 { return fmt.Errorf("failed to delete removed output: %w", ssql.ErrOutputNotFound) } @@ -169,7 +177,7 @@ func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspen if err != nil { return fmt.Errorf("failed to delete events: %w", err) } else if n, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) + return fmt.Errorf("failed to delete events: %w", err) } else if n > 0 { c.l.Debugw(fmt.Sprintf("removed %d events", n), "height", index.Height, "block_id", index.ID) } diff --git a/stores/sql/sqlite/chain.go b/stores/sql/sqlite/chain.go index 95d16434c5..4fa869ad66 100644 --- a/stores/sql/sqlite/chain.go +++ b/stores/sql/sqlite/chain.go @@ -47,7 +47,7 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [ if res, err := deleteSpentStmt.Exec(c.ctx, ssql.Hash256(e.ID)); err != nil { return fmt.Errorf("failed to delete spent output: %w", err) } else if n, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) + return fmt.Errorf("failed to delete spent output: %w", err) } else if n != 1 { return fmt.Errorf("failed to delete spent output: %w", ssql.ErrOutputNotFound) } @@ -90,9 +90,11 @@ func (c chainUpdateTx) WalletApplyIndex(index types.ChainIndex, created, spent [ // insert new events for _, e := range events { if e.Index != index { - return fmt.Errorf("event index %v doesn't match index being applied %v", e.Index, index) + return fmt.Errorf("%w, event index %v != applied index %v", ssql.ErrIndexMissmatch, e.Index, index) } else if e.ID == (types.Hash256{}) { return fmt.Errorf("event id is required") + } else if e.Timestamp.IsZero() { + return fmt.Errorf("event timestamp is required") } c.l.Debugw(fmt.Sprintf("create event %v", e.ID), "height", index.Height, "block_id", index.ID) @@ -141,7 +143,7 @@ func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspen if res, err := deleteRemovedStmt.Exec(c.ctx, ssql.Hash256(e.ID)); err != nil { return fmt.Errorf("failed to delete removed output: %w", err) } else if n, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) + return fmt.Errorf("failed to delete removed output: %w", err) } else if n != 1 { return fmt.Errorf("failed to delete removed output: %w", ssql.ErrOutputNotFound) } @@ -178,7 +180,7 @@ func (c chainUpdateTx) WalletRevertIndex(index types.ChainIndex, removed, unspen if err != nil { return fmt.Errorf("failed to delete events: %w", err) } else if n, err := res.RowsAffected(); err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) + return fmt.Errorf("failed to delete events: %w", err) } else if n > 0 { c.l.Debugw(fmt.Sprintf("removed %d events", n), "height", index.Height, "block_id", index.ID) }