diff --git a/consensus/merkle_test.go b/consensus/merkle_test.go new file mode 100644 index 00000000..4fe69626 --- /dev/null +++ b/consensus/merkle_test.go @@ -0,0 +1,39 @@ +package consensus + +import "testing" + +func TestElementAccumulatorEncoding(t *testing.T) { + for _, test := range []struct { + numLeaves uint64 + exp string + }{ + {0, `{"numLeaves":0,"trees":[]}`}, + {1, `{"numLeaves":1,"trees":["h:0000000000000000000000000000000000000000000000000000000000000000"]}`}, + {2, `{"numLeaves":2,"trees":["h:0100000000000000000000000000000000000000000000000000000000000000"]}`}, + {3, `{"numLeaves":3,"trees":["h:0000000000000000000000000000000000000000000000000000000000000000","h:0100000000000000000000000000000000000000000000000000000000000000"]}`}, + {10, `{"numLeaves":10,"trees":["h:0100000000000000000000000000000000000000000000000000000000000000","h:0300000000000000000000000000000000000000000000000000000000000000"]}`}, + {1 << 16, `{"numLeaves":65536,"trees":["h:1000000000000000000000000000000000000000000000000000000000000000"]}`}, + {1 << 32, `{"numLeaves":4294967296,"trees":["h:2000000000000000000000000000000000000000000000000000000000000000"]}`}, + } { + acc := ElementAccumulator{NumLeaves: test.numLeaves} + for i := range acc.Trees { + if acc.hasTreeAtHeight(i) { + acc.Trees[i][0] = byte(i) + } + } + js, err := acc.MarshalJSON() + if err != nil { + t.Fatal(err) + } + if string(js) != test.exp { + t.Errorf("expected %s, got %s", test.exp, js) + } + var acc2 ElementAccumulator + if err := acc2.UnmarshalJSON(js); err != nil { + t.Fatal(err) + } + if acc2 != acc { + t.Fatal("round trip failed: expected", acc, "got", acc2) + } + } +} diff --git a/consensus/update.go b/consensus/update.go index cce9e558..c1af5ad9 100644 --- a/consensus/update.go +++ b/consensus/update.go @@ -50,17 +50,17 @@ func (w *Work) UnmarshalText(b []byte) error { return nil } +// MarshalJSON implements json.Marshaler. +func (w Work) MarshalJSON() ([]byte, error) { + s, err := w.MarshalText() + return []byte(`"` + string(s) + `"`), err +} + // UnmarshalJSON implements json.Unmarshaler. func (w *Work) UnmarshalJSON(b []byte) error { return w.UnmarshalText(bytes.Trim(b, `"`)) } -// MarshalJSON implements json.Marshaler. -func (w Work) MarshalJSON() ([]byte, error) { - js, err := new(big.Int).SetBytes(w.n[:]).MarshalJSON() - return []byte(`"` + string(js) + `"`), err -} - func (w Work) add(v Work) Work { var r Work var sum, c uint64 diff --git a/consensus/update_test.go b/consensus/update_test.go index 948ca8f7..0ec4848b 100644 --- a/consensus/update_test.go +++ b/consensus/update_test.go @@ -3,6 +3,7 @@ package consensus_test import ( "encoding/json" "reflect" + "strings" "testing" "time" @@ -238,4 +239,92 @@ func TestApplyBlock(t *testing.T) { ru := consensus.RevertBlock(prev, b2, bs) dbStore.RevertBlock(cs, ru) checkRevertElements(ru, addedSCEs, spentSCEs, addedSFEs, spentSFEs) + + // reverting a non-child block should trigger a panic + func() { + defer func() { recover() }() + consensus.RevertBlock(cs, b2, bs) + t.Error("did not panic on reverting non-child block") + }() +} + +func TestWorkEncoding(t *testing.T) { + for _, test := range []struct { + val string + err bool + roundtrip string + }{ + {val: "0"}, + {val: "12345"}, + {val: "115792089237316195423570985008687907853269984665640564039457584007913129639935"}, // 1<<256 - 1 + {val: "01", roundtrip: "1"}, + {val: "-0", roundtrip: "0"}, + {err: true, val: ""}, + {err: true, val: "-1"}, + {err: true, val: " 1"}, + {err: true, val: "1 "}, + {err: true, val: "1157920892373161954235709850086879078532699846656405640394575840079131296399366"}, + {err: true, val: "not a number"}, + } { + for _, codec := range []struct { + name string + enc func(consensus.Work) (string, error) + dec func(string) (consensus.Work, error) + }{ + { + name: "String", + enc: func(w consensus.Work) (string, error) { + return w.String(), nil + }, + dec: func(s string) (w consensus.Work, err error) { + err = w.UnmarshalText([]byte(s)) + return + }, + }, + { + name: "MarshalText", + enc: func(w consensus.Work) (string, error) { + v, err := w.MarshalText() + return string(v), err + }, + dec: func(s string) (w consensus.Work, err error) { + err = w.UnmarshalText([]byte(s)) + return + }, + }, + { + name: "MarshalJSON", + enc: func(w consensus.Work) (string, error) { + v, err := w.MarshalJSON() + return strings.Trim(string(v), `"`), err + }, + dec: func(s string) (w consensus.Work, err error) { + err = w.UnmarshalJSON([]byte(strings.Trim(s, `"`))) + return + }, + }, + } { + w, err := codec.dec(test.val) + if err != nil { + if !test.err { + t.Errorf("%v: unexpected error for %v: %v", codec.name, test.val, err) + } + continue + } else if test.err { + t.Errorf("%v: expected error for %v, got nil", codec.name, test.val) + continue + } + exp := test.roundtrip + if exp == "" { + exp = test.val + } + got, err := codec.enc(w) + if err != nil { + t.Fatal(err) + } else if string(got) != exp { + t.Errorf("%v: %q failed roundtrip (got %q)", codec.name, test.val, got) + continue + } + } + } } diff --git a/consensus/validation_test.go b/consensus/validation_test.go index 54fc19d3..4ea35a0d 100644 --- a/consensus/validation_test.go +++ b/consensus/validation_test.go @@ -2,6 +2,7 @@ package consensus_test import ( "bytes" + "errors" "math" "testing" "time" @@ -1270,3 +1271,41 @@ func TestValidateV2Block(t *testing.T) { } } } + +func TestValidateV2Transaction(t *testing.T) { + type args struct { + ms *consensus.MidState + txn types.V2Transaction + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "failure - v2 transactions are not allowed", + args: args{ + ms: consensus.NewMidState(consensus.State{ + Index: types.ChainIndex{Height: uint64(0)}, + Network: &consensus.Network{ + HardforkV2: struct { + AllowHeight uint64 "json:\"allowHeight\"" + RequireHeight uint64 "json:\"requireHeight\"" + }{ + AllowHeight: uint64(2), + }, + }, + }), + txn: types.V2Transaction{}, + }, + wantErr: errors.New("v2 transactions are not allowed until v2 hardfork begins"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := consensus.ValidateV2Transaction(tt.args.ms, tt.args.txn); (err != nil) && err.Error() != tt.wantErr.Error() { + t.Errorf("ValidateV2Transaction() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}