diff --git a/changelog.md b/changelog.md index 8bfbae13f7..3f8206a9f1 100644 --- a/changelog.md +++ b/changelog.md @@ -18,6 +18,7 @@ * [2919](https://github.com/zeta-chain/node/pull/2919) - add inbound sender to revert context * [2957](https://github.com/zeta-chain/node/pull/2957) - enable Bitcoin inscription support on testnet * [2896](https://github.com/zeta-chain/node/pull/2896) - add TON inbound observation +* [2987](https://github.com/zeta-chain/node/pull/2987) - add non-EVM standard inbound memo package ### Refactor diff --git a/pkg/chains/conversion.go b/pkg/chains/conversion.go index 3200b76693..5d03b77007 100644 --- a/pkg/chains/conversion.go +++ b/pkg/chains/conversion.go @@ -1,10 +1,8 @@ package chains import ( - "encoding/hex" "fmt" - "cosmossdk.io/errors" "github.com/btcsuite/btcd/chaincfg/chainhash" ethcommon "github.com/ethereum/go-ethereum/common" ) @@ -28,26 +26,3 @@ func StringToHash(chainID int64, hash string, additionalChains []Chain) ([]byte, } return nil, fmt.Errorf("cannot convert hash to bytes for chain %d", chainID) } - -// ParseAddressAndData parses the message string into an address and data -// message is hex encoded byte array -// [ contractAddress calldata ] -// [ 20B, variable] -func ParseAddressAndData(message string) (ethcommon.Address, []byte, error) { - if len(message) == 0 { - return ethcommon.Address{}, nil, nil - } - - data, err := hex.DecodeString(message) - if err != nil { - return ethcommon.Address{}, nil, errors.Wrap(err, "message should be a hex encoded string") - } - - if len(data) < 20 { - return ethcommon.Address{}, data, nil - } - - address := ethcommon.BytesToAddress(data[:20]) - data = data[20:] - return address, data, nil -} diff --git a/pkg/chains/utils_test.go b/pkg/chains/utils_test.go index 915fa2b8ca..7552e10967 100644 --- a/pkg/chains/utils_test.go +++ b/pkg/chains/utils_test.go @@ -1,7 +1,6 @@ package chains import ( - "encoding/hex" "testing" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -66,39 +65,3 @@ func TestStringToHash(t *testing.T) { }) } } - -func TestParseAddressAndData(t *testing.T) { - expectedShortMsgResult, err := hex.DecodeString("1a2b3c4d5e6f708192a3b4c5d6e7f808") - require.NoError(t, err) - tests := []struct { - name string - message string - expectAddr ethcommon.Address - expectData []byte - wantErr bool - }{ - { - "valid msg", - "95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe5", - ethcommon.HexToAddress("95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe5"), - []byte{}, - false, - }, - {"empty msg", "", ethcommon.Address{}, nil, false}, - {"invalid hex", "invalidHex", ethcommon.Address{}, nil, true}, - {"short msg", "1a2b3c4d5e6f708192a3b4c5d6e7f808", ethcommon.Address{}, expectedShortMsgResult, false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - addr, data, err := ParseAddressAndData(tt.message) - if tt.wantErr { - require.Error(t, err) - } else { - require.NoError(t, err) - require.Equal(t, tt.expectAddr, addr) - require.Equal(t, tt.expectData, data) - } - }) - } -} diff --git a/pkg/math/bits/bits.go b/pkg/math/bits/bits.go new file mode 100644 index 0000000000..a7a7865793 --- /dev/null +++ b/pkg/math/bits/bits.go @@ -0,0 +1,56 @@ +package math + +import ( + "math/bits" +) + +// SetBit sets the bit at the given position (0-7) in the byte to 1 +func SetBit(b *byte, position uint8) { + if position > 7 { + return + } + + // Example: given b = 0b00000000 and position = 3 + // step-1: shift value 1 to left by 3 times: 1 << 3 = 0b00001000 + // step-2: make an OR operation with original byte to set the bit: 0b00000000 | 0b00001000 = 0b00001000 + *b |= 1 << position +} + +// IsBitSet returns true if the bit at the given position (0-7) is set in the byte, false otherwise +func IsBitSet(b byte, position uint8) bool { + if position > 7 { + return false + } + bitMask := byte(1 << position) + return b&bitMask != 0 +} + +// GetBits extracts the value of bits for a given mask +// +// Example: given b = 0b11011001 and mask = 0b11100000, the function returns 0b110 +func GetBits(b byte, mask byte) byte { + extracted := b & mask + + // get the number of trailing zero bits + trailingZeros := bits.TrailingZeros8(mask) + + // remove trailing zeros + return extracted >> trailingZeros +} + +// SetBits sets the value to the bits specified in the mask +// +// Example: given b = 0b00100001 and mask = 0b11100000, and value = 0b110, the function returns 0b11000001 +func SetBits(b byte, mask byte, value byte) byte { + // get the number of trailing zero bits in the mask + trailingZeros := bits.TrailingZeros8(mask) + + // shift the value left by the number of trailing zeros + valueShifted := value << trailingZeros + + // clear the bits in 'b' that correspond to the mask + bCleared := b &^ mask + + // Set the bits by ORing the cleared 'b' with the shifted value + return bCleared | valueShifted +} diff --git a/pkg/math/bits/bits_test.go b/pkg/math/bits/bits_test.go new file mode 100644 index 0000000000..65632d0a24 --- /dev/null +++ b/pkg/math/bits/bits_test.go @@ -0,0 +1,171 @@ +package math_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + zetabits "github.com/zeta-chain/node/pkg/math/bits" +) + +func TestSetBit(t *testing.T) { + tests := []struct { + name string + initial byte + position uint8 + expected byte + }{ + { + name: "set bit at position 0", + initial: 0b00001000, + position: 0, + expected: 0b00001001, + }, + { + name: "set bit at position 7", + initial: 0b00001000, + position: 7, + expected: 0b10001000, + }, + { + name: "out of range bit position (no effect)", + initial: 0b00000000, + position: 8, // Out of range + expected: 0b00000000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := tt.initial + zetabits.SetBit(&b, tt.position) + require.Equal(t, tt.expected, b) + }) + } +} + +func TestIsBitSet(t *testing.T) { + tests := []struct { + name string + b byte + position uint8 + expected bool + }{ + { + name: "bit 0 set", + b: 0b00000001, + position: 0, + expected: true, + }, + { + name: "bit 7 set", + b: 0b10000000, + position: 7, + expected: true, + }, + { + name: "bit 2 not set", + b: 0b00000001, + position: 2, + expected: false, + }, + { + name: "bit out of range", + b: 0b00000001, + position: 8, // Position out of range + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := zetabits.IsBitSet(tt.b, tt.position) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestGetBits(t *testing.T) { + tests := []struct { + name string + b byte + mask byte + expected byte + }{ + { + name: "extract upper 3 bits", + b: 0b11011001, + mask: 0b11100000, + expected: 0b110, + }, + { + name: "extract middle 3 bits", + b: 0b11011001, + mask: 0b00011100, + expected: 0b110, + }, + { + name: "extract lower 3 bits", + b: 0b11011001, + mask: 0b00000111, + expected: 0b001, + }, + { + name: "extract no bits", + b: 0b11011001, + mask: 0b00000000, + expected: 0b000, + }, + { + name: "extract all bits", + b: 0b11111111, + mask: 0b11111111, + expected: 0b11111111, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := zetabits.GetBits(tt.b, tt.mask) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestSetBits(t *testing.T) { + tests := []struct { + name string + b byte + mask byte + value byte + expected byte + }{ + { + name: "set upper 3 bits", + b: 0b00100001, + mask: 0b11100000, + value: 0b110, + expected: 0b11000001, + }, + { + name: "set middle 3 bits", + b: 0b00100001, + mask: 0b00011100, + value: 0b101, + expected: 0b00110101, + }, + { + name: "set lower 3 bits", + b: 0b11111100, + mask: 0b00000111, + value: 0b101, + expected: 0b11111101, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := zetabits.SetBits(tt.b, tt.mask, tt.value) + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/memo/arg.go b/pkg/memo/arg.go new file mode 100644 index 0000000000..e29615f562 --- /dev/null +++ b/pkg/memo/arg.go @@ -0,0 +1,52 @@ +package memo + +// ArgType is the enum for types supported by the codec +type ArgType string + +// Define all the types supported by the codec +const ( + ArgTypeBytes ArgType = "bytes" + ArgTypeString ArgType = "string" + ArgTypeAddress ArgType = "address" +) + +// CodecArg represents a codec argument +type CodecArg struct { + Name string + Type ArgType + Arg interface{} +} + +// NewArg create a new codec argument +func NewArg(name string, argType ArgType, arg interface{}) CodecArg { + return CodecArg{ + Name: name, + Type: argType, + Arg: arg, + } +} + +// ArgReceiver wraps the receiver address in a CodecArg +func ArgReceiver(arg interface{}) CodecArg { + return NewArg("receiver", ArgTypeAddress, arg) +} + +// ArgPayload wraps the payload in a CodecArg +func ArgPayload(arg interface{}) CodecArg { + return NewArg("payload", ArgTypeBytes, arg) +} + +// ArgRevertAddress wraps the revert address in a CodecArg +func ArgRevertAddress(arg interface{}) CodecArg { + return NewArg("revertAddress", ArgTypeString, arg) +} + +// ArgAbortAddress wraps the abort address in a CodecArg +func ArgAbortAddress(arg interface{}) CodecArg { + return NewArg("abortAddress", ArgTypeAddress, arg) +} + +// ArgRevertMessage wraps the revert message in a CodecArg +func ArgRevertMessage(arg interface{}) CodecArg { + return NewArg("revertMessage", ArgTypeBytes, arg) +} diff --git a/pkg/memo/arg_test.go b/pkg/memo/arg_test.go new file mode 100644 index 0000000000..ee9373e688 --- /dev/null +++ b/pkg/memo/arg_test.go @@ -0,0 +1,70 @@ +package memo_test + +import ( + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" +) + +func Test_NewArg(t *testing.T) { + argAddress := common.HexToAddress("0x0B85C56e5453e0f4273d1D1BF3091d43B08B38CE") + argString := "some other string argument" + argBytes := []byte("here is a bytes argument") + + tests := []struct { + name string + argType string + arg interface{} + }{ + { + name: "receiver", + argType: "address", + arg: &argAddress, + }, + { + name: "payload", + argType: "bytes", + arg: &argBytes, + }, + { + name: "revertAddress", + argType: "string", + arg: &argString, + }, + { + name: "abortAddress", + argType: "address", + arg: &argAddress, + }, + { + name: "revertMessage", + argType: "bytes", + arg: &argBytes, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + arg := memo.NewArg(tt.name, memo.ArgType(tt.argType), tt.arg) + + require.Equal(t, tt.name, arg.Name) + require.Equal(t, memo.ArgType(tt.argType), arg.Type) + require.Equal(t, tt.arg, arg.Arg) + + switch tt.name { + case "receiver": + require.Equal(t, arg, memo.ArgReceiver(&argAddress)) + case "payload": + require.Equal(t, arg, memo.ArgPayload(&argBytes)) + case "revertAddress": + require.Equal(t, arg, memo.ArgRevertAddress(&argString)) + case "abortAddress": + require.Equal(t, arg, memo.ArgAbortAddress(&argAddress)) + case "revertMessage": + require.Equal(t, arg, memo.ArgRevertMessage(&argBytes)) + } + }) + } +} diff --git a/pkg/memo/codec.go b/pkg/memo/codec.go new file mode 100644 index 0000000000..ea60fa24c1 --- /dev/null +++ b/pkg/memo/codec.go @@ -0,0 +1,64 @@ +package memo + +import ( + "fmt" +) + +type EncodingFormat uint8 + +// Enum for non-EVM chain memo encoding format (2 bits) +const ( + // EncodingFmtABI represents ABI encoding format + EncodingFmtABI EncodingFormat = 0b0000 + + // EncodingFmtCompactShort represents 'compact short' encoding format + EncodingFmtCompactShort EncodingFormat = 0b0001 + + // EncodingFmtCompactLong represents 'compact long' encoding format + EncodingFmtCompactLong EncodingFormat = 0b0010 + + // EncodingFmtInvalid represents invalid encoding format + EncodingFmtInvalid EncodingFormat = 0b0011 +) + +// Enum for length of bytes used to encode compact data +const ( + LenBytesShort = 1 + LenBytesLong = 2 +) + +// Codec is the interface for a codec +type Codec interface { + // AddArguments adds a list of arguments to the codec + AddArguments(args ...CodecArg) + + // PackArguments packs the arguments into the encoded data + PackArguments() ([]byte, error) + + // UnpackArguments unpacks the encoded data into the arguments + UnpackArguments(data []byte) error +} + +// GetLenBytes returns the number of bytes used to encode the length of the data +func GetLenBytes(encodingFmt EncodingFormat) (int, error) { + switch encodingFmt { + case EncodingFmtCompactShort: + return LenBytesShort, nil + case EncodingFmtCompactLong: + return LenBytesLong, nil + default: + return 0, fmt.Errorf("invalid compact encoding format %d", encodingFmt) + } +} + +// GetCodec returns the codec based on the encoding format +func GetCodec(encodingFmt EncodingFormat) (Codec, error) { + switch encodingFmt { + case EncodingFmtABI: + return NewCodecABI(), nil + case EncodingFmtCompactShort, EncodingFmtCompactLong: + return NewCodecCompact(encodingFmt) + default: + return nil, fmt.Errorf("invalid encoding format %d", encodingFmt) + } +} diff --git a/pkg/memo/codec_abi.go b/pkg/memo/codec_abi.go new file mode 100644 index 0000000000..708e1a6398 --- /dev/null +++ b/pkg/memo/codec_abi.go @@ -0,0 +1,94 @@ +package memo + +import ( + "fmt" + "strings" + + "github.com/ethereum/go-ethereum/accounts/abi" + "github.com/pkg/errors" +) + +const ( + // selectorLength is the length of the selector in bytes + selectorLength = 4 + + // codecMethod is the name of the codec method + codecMethod = "codec" + + // codecMethodABIString is the ABI string template for codec method + codecMethodABIString = `[{"name":"codec", "inputs":[%s], "outputs":[%s], "type":"function"}]` +) + +var _ Codec = (*CodecABI)(nil) + +// CodecABI is a coder/decoder for ABI encoded memo fields +type CodecABI struct { + // abiTypes contains the ABI types of the arguments + abiTypes []string + + // abiArgs contains the ABI arguments to be packed or unpacked into + abiArgs []interface{} +} + +// NewCodecABI creates a new ABI codec +func NewCodecABI() *CodecABI { + return &CodecABI{ + abiTypes: make([]string, 0), + abiArgs: make([]interface{}, 0), + } +} + +// AddArguments adds a list of arguments to the codec +func (c *CodecABI) AddArguments(args ...CodecArg) { + for _, arg := range args { + typeJSON := fmt.Sprintf(`{"type":"%s"}`, arg.Type) + c.abiTypes = append(c.abiTypes, typeJSON) + c.abiArgs = append(c.abiArgs, arg.Arg) + } +} + +// PackArguments packs the arguments into the ABI encoded data +func (c *CodecABI) PackArguments() ([]byte, error) { + // get parsed ABI based on the inputs + parsedABI, err := c.parsedABI() + if err != nil { + return nil, errors.Wrap(err, "failed to parse ABI string") + } + + // pack the arguments + data, err := parsedABI.Pack(codecMethod, c.abiArgs...) + if err != nil { + return nil, errors.Wrap(err, "failed to pack ABI arguments") + } + + // this never happens + if len(data) < selectorLength { + return nil, errors.New("packed data less than selector length") + } + + return data[selectorLength:], nil +} + +// UnpackArguments unpacks the ABI encoded data into the output arguments +func (c *CodecABI) UnpackArguments(data []byte) error { + // get parsed ABI based on the inputs + parsedABI, err := c.parsedABI() + if err != nil { + return errors.Wrap(err, "failed to parse ABI string") + } + + // unpack data into outputs + err = parsedABI.UnpackIntoInterface(&c.abiArgs, codecMethod, data) + if err != nil { + return errors.Wrap(err, "failed to unpack ABI encoded data") + } + + return nil +} + +// parsedABI builds a parsed ABI based on the inputs +func (c *CodecABI) parsedABI() (abi.ABI, error) { + typeList := strings.Join(c.abiTypes, ",") + abiString := fmt.Sprintf(codecMethodABIString, typeList, typeList) + return abi.JSON(strings.NewReader(abiString)) +} diff --git a/pkg/memo/codec_abi_test.go b/pkg/memo/codec_abi_test.go new file mode 100644 index 0000000000..c4a2e1decd --- /dev/null +++ b/pkg/memo/codec_abi_test.go @@ -0,0 +1,345 @@ +package memo_test + +import ( + "bytes" + "encoding/binary" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" +) + +const ( + // abiAlignment is the number of bytes used to align the ABI encoded data + abiAlignment = 32 +) + +// ABIPack is a helper function that simulates the abi.Pack function. +// Note: all arguments are assumed to be <= 32 bytes for simplicity. +func ABIPack(t *testing.T, args ...memo.CodecArg) []byte { + packedData := make([]byte, 0) + + // data offset for 1st dynamic-length field + offset := abiAlignment * len(args) + + // 1. pack 32-byte offset for each dynamic-length field (bytes, string) + // 2. pack actual data for each fixed-length field (address) + for _, arg := range args { + switch arg.Type { + case memo.ArgTypeBytes: + // left-pad length as uint16 + buff := make([]byte, 2) + binary.BigEndian.PutUint16(buff, uint16(offset)) + offsetData := abiPad32(t, buff, true) + packedData = append(packedData, offsetData...) + + argLen := len(arg.Arg.([]byte)) + if argLen > 0 { + offset += abiAlignment * 2 // [length + data] + } else { + offset += abiAlignment // only [length] + } + + case memo.ArgTypeString: + // left-pad length as uint16 + buff := make([]byte, 2) + binary.BigEndian.PutUint16(buff, uint16(offset)) + offsetData := abiPad32(t, buff, true) + packedData = append(packedData, offsetData...) + + argLen := len([]byte(arg.Arg.(string))) + if argLen > 0 { + offset += abiAlignment * 2 // [length + data] + } else { + offset += abiAlignment // only [length] + } + + case memo.ArgTypeAddress: // left-pad for address + data := abiPad32(t, arg.Arg.(common.Address).Bytes(), true) + packedData = append(packedData, data...) + } + } + + // pack dynamic-length fields + dynamicData := abiPackDynamicData(t, args...) + packedData = append(packedData, dynamicData...) + + return packedData +} + +// abiPad32 is a helper function to pad a byte slice to 32 bytes +func abiPad32(t *testing.T, data []byte, left bool) []byte { + // nothing needs to be encoded, return empty bytes + if len(data) == 0 { + return []byte{} + } + + require.LessOrEqual(t, len(data), abiAlignment) + padded := make([]byte, 32) + + if left { + // left-pad the data for fixed-size types + copy(padded[32-len(data):], data) + } else { + // right-pad the data for dynamic types + copy(padded, data) + } + return padded +} + +// abiPackDynamicData is a helper function to pack dynamic-length data +func abiPackDynamicData(t *testing.T, args ...memo.CodecArg) []byte { + packedData := make([]byte, 0) + + // pack with ABI format: length + data + for _, arg := range args { + // get length + var length int + switch arg.Type { + case memo.ArgTypeBytes: + length = len(arg.Arg.([]byte)) + case memo.ArgTypeString: + length = len([]byte(arg.Arg.(string))) + default: + continue + } + + // append length in bytes + lengthData := abiPad32(t, []byte{byte(length)}, true) + packedData = append(packedData, lengthData...) + + // append actual data in bytes + switch arg.Type { + case memo.ArgTypeBytes: // right-pad for bytes + data := abiPad32(t, arg.Arg.([]byte), false) + packedData = append(packedData, data...) + case memo.ArgTypeString: // right-pad for string + data := abiPad32(t, []byte(arg.Arg.(string)), false) + packedData = append(packedData, data...) + } + } + + return packedData +} + +// newArgInstance creates a new instance of the given argument type +func newArgInstance(v interface{}) interface{} { + switch v.(type) { + case common.Address: + return new(common.Address) + case []byte: + return &[]byte{} + case string: + return new(string) + } + return nil +} + +// ensureArgEquality ensures the expected argument and actual value are equal +func ensureArgEquality(t *testing.T, expected, actual interface{}) { + switch v := expected.(type) { + case common.Address: + require.Equal(t, v.Hex(), actual.(*common.Address).Hex()) + case []byte: + require.True(t, bytes.Equal(v, *actual.(*[]byte))) + case string: + require.Equal(t, v, *actual.(*string)) + default: + require.FailNow(t, "unexpected argument type", "Type: %T", v) + } +} + +func Test_NewCodecABI(t *testing.T) { + c := memo.NewCodecABI() + require.NotNil(t, c) +} + +func Test_CodecABI_AddArguments(t *testing.T) { + codec := memo.NewCodecABI() + require.NotNil(t, codec) + + address := common.HexToAddress("0xEf221eC80f004E6A2ee4E5F5d800699c1C68cD6F") + codec.AddArguments(memo.ArgReceiver(&address)) + + // attempt to pack the arguments, result should not be nil + packedData, err := codec.PackArguments() + require.NoError(t, err) + require.True(t, len(packedData) > 0) +} + +func Test_CodecABI_PackArgument(t *testing.T) { + // create sample arguments + argAddress := common.HexToAddress("0xEf221eC80f004E6A2ee4E5F5d800699c1C68cD6F") + argBytes := []byte("some test bytes argument") + argString := "some test string argument" + + // test cases + tests := []struct { + name string + args []memo.CodecArg + errMsg string + }{ + { + name: "pack in the order of [address, bytes, string]", + args: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + }, + }, + { + name: "pack in the order of [string, address, bytes]", + args: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + }, + { + name: "pack empty bytes array and string", + args: []memo.CodecArg{ + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + }, + }, + { + name: "unable to parse unsupported ABI type", + args: []memo.CodecArg{ + memo.ArgReceiver(&argAddress), + memo.NewArg("payload", memo.ArgType("unknown"), nil), + }, + errMsg: "failed to parse ABI string", + }, + { + name: "packing should fail on argument type mismatch", + args: []memo.CodecArg{ + memo.ArgReceiver(argBytes), // expect address type, but passed bytes + }, + errMsg: "failed to pack ABI arguments", + }, + } + + // loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // create a new ABI codec and add arguments + codec := memo.NewCodecABI() + codec.AddArguments(tc.args...) + + // pack arguments into ABI-encoded packedData + packedData, err := codec.PackArguments() + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + require.Nil(t, packedData) + return + } + require.NoError(t, err) + + // calc expected data for comparison + expectedData := ABIPack(t, tc.args...) + + // validate the packed data + require.True(t, bytes.Equal(expectedData, packedData), "ABI encoded data mismatch") + }) + } +} + +func Test_CodecABI_UnpackArguments(t *testing.T) { + // create sample arguments + argAddress := common.HexToAddress("0xEf221eC80f004E6A2ee4E5F5d800699c1C68cD6F") + argBytes := []byte("some test bytes argument") + argString := "some test string argument" + + // test cases + tests := []struct { + name string + data []byte + expected []memo.CodecArg + errMsg string + }{ + { + name: "unpack in the order of [address, bytes, string]", + data: ABIPack(t, + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString)), + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + }, + }, + { + name: "unpack in the order of [string, address, bytes]", + data: ABIPack(t, + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes)), + expected: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + }, + { + name: "unpack empty bytes array and string", + data: ABIPack(t, + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress("")), + expected: []memo.CodecArg{ + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + }, + }, + { + name: "unable to parse unsupported ABI type", + data: []byte{}, + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.NewArg("payload", memo.ArgType("unknown"), nil), + }, + errMsg: "failed to parse ABI string", + }, + { + name: "unpacking should fail on argument type mismatch", + data: ABIPack(t, + memo.ArgReceiver(argAddress), + ), + expected: []memo.CodecArg{ + memo.ArgReceiver(argBytes), // expect address type, but passed bytes + }, + errMsg: "failed to unpack ABI encoded data", + }, + } + + // loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // create a new ABI codec + codec := memo.NewCodecABI() + + // add output arguments + output := make([]memo.CodecArg, len(tc.expected)) + for i, arg := range tc.expected { + output[i] = memo.NewArg(arg.Name, arg.Type, newArgInstance(arg.Arg)) + } + codec.AddArguments(output...) + + // unpack arguments from ABI-encoded data + err := codec.UnpackArguments(tc.data) + + // validate the error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + + // validate the unpacked arguments values + require.NoError(t, err) + for i, arg := range tc.expected { + ensureArgEquality(t, arg.Arg, output[i].Arg) + } + }) + } +} diff --git a/pkg/memo/codec_compact.go b/pkg/memo/codec_compact.go new file mode 100644 index 0000000000..bcb49d3f92 --- /dev/null +++ b/pkg/memo/codec_compact.go @@ -0,0 +1,265 @@ +package memo + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" +) + +var _ Codec = (*CodecCompact)(nil) + +// CodecCompact is a coder/decoder for compact encoded memo fields +// +// This encoding format concatenates the memo fields into a single byte array +// with zero padding to minimize the total size of the memo. +type CodecCompact struct { + // lenBytes is the number of bytes used to encode the length of the data + lenBytes int + + // args contains the list of arguments + args []CodecArg +} + +// NewCodecCompact creates a new compact codec +func NewCodecCompact(encodingFmt EncodingFormat) (*CodecCompact, error) { + lenBytes, err := GetLenBytes(encodingFmt) + if err != nil { + return nil, err + } + + return &CodecCompact{ + lenBytes: lenBytes, + args: make([]CodecArg, 0), + }, nil +} + +// AddArguments adds a list of arguments to the codec +func (c *CodecCompact) AddArguments(args ...CodecArg) { + c.args = append(c.args, args...) +} + +// PackArguments packs the arguments into the compact encoded data +func (c *CodecCompact) PackArguments() ([]byte, error) { + data := make([]byte, 0) + + // pack according to argument type + for _, arg := range c.args { + switch arg.Type { + case ArgTypeBytes: + dataBytes, err := c.packBytes(arg.Arg) + if err != nil { + return nil, errors.Wrapf(err, "failed to pack bytes argument: %s", arg.Name) + } + data = append(data, dataBytes...) + case ArgTypeAddress: + dataAddress, err := c.packAddress(arg.Arg) + if err != nil { + return nil, errors.Wrapf(err, "failed to pack address argument: %s", arg.Name) + } + data = append(data, dataAddress...) + case ArgTypeString: + dataString, err := c.packString(arg.Arg) + if err != nil { + return nil, errors.Wrapf(err, "failed to pack string argument: %s", arg.Name) + } + data = append(data, dataString...) + default: + return nil, fmt.Errorf("unsupported argument (%s) type: %s", arg.Name, arg.Type) + } + } + + return data, nil +} + +// UnpackArguments unpacks the compact encoded data into the output arguments +func (c *CodecCompact) UnpackArguments(data []byte) error { + // unpack according to argument type + offset := 0 + for _, arg := range c.args { + switch arg.Type { + case ArgTypeBytes: + bytesRead, err := c.unpackBytes(data[offset:], arg.Arg) + if err != nil { + return errors.Wrapf(err, "failed to unpack bytes argument: %s", arg.Name) + } + offset += bytesRead + case ArgTypeAddress: + bytesRead, err := c.unpackAddress(data[offset:], arg.Arg) + if err != nil { + return errors.Wrapf(err, "failed to unpack address argument: %s", arg.Name) + } + offset += bytesRead + case ArgTypeString: + bytesRead, err := c.unpackString(data[offset:], arg.Arg) + if err != nil { + return errors.Wrapf(err, "failed to unpack string argument: %s", arg.Name) + } + offset += bytesRead + default: + return fmt.Errorf("unsupported argument (%s) type: %s", arg.Name, arg.Type) + } + } + + // ensure all data is consumed + if offset != len(data) { + return fmt.Errorf("consumed bytes (%d) != total bytes (%d)", offset, len(data)) + } + + return nil +} + +// packLength packs the length of the data into the compact format +func (c *CodecCompact) packLength(length int) ([]byte, error) { + data := make([]byte, c.lenBytes) + + switch c.lenBytes { + case LenBytesShort: + if length > math.MaxUint8 { + return nil, fmt.Errorf("data length %d exceeds %d bytes", length, math.MaxUint8) + } + data[0] = uint8(length) + case LenBytesLong: + if length > math.MaxUint16 { + return nil, fmt.Errorf("data length %d exceeds %d bytes", length, math.MaxUint16) + } + // #nosec G115 range checked + binary.LittleEndian.PutUint16(data, uint16(length)) + } + return data, nil +} + +// packAddress packs argument of type 'address'. +func (c *CodecCompact) packAddress(arg interface{}) ([]byte, error) { + // type assertion + address, ok := arg.(common.Address) + if !ok { + return nil, fmt.Errorf("argument is not of type common.Address") + } + + return address.Bytes(), nil +} + +// packBytes packs argument of type 'bytes'. +func (c *CodecCompact) packBytes(arg interface{}) ([]byte, error) { + // type assertion + bytes, ok := arg.([]byte) + if !ok { + return nil, fmt.Errorf("argument is not of type []byte") + } + + // pack length of the data + data, err := c.packLength(len(bytes)) + if err != nil { + return nil, errors.Wrap(err, "failed to pack length of bytes") + } + + // append the data + data = append(data, bytes...) + return data, nil +} + +// packString packs argument of type 'string'. +func (c *CodecCompact) packString(arg interface{}) ([]byte, error) { + // type assertion + str, ok := arg.(string) + if !ok { + return nil, fmt.Errorf("argument is not of type string") + } + + // pack length of the data + data, err := c.packLength(len([]byte(str))) + if err != nil { + return nil, errors.Wrap(err, "failed to pack length of string") + } + + // append the string + data = append(data, []byte(str)...) + return data, nil +} + +// unpackLength returns the length of the data encoded in the compact format +func (c *CodecCompact) unpackLength(data []byte) (int, error) { + if len(data) < c.lenBytes { + return 0, fmt.Errorf("expected %d bytes to decode length, got %d", c.lenBytes, len(data)) + } + + // decode length of the data + length := 0 + switch c.lenBytes { + case LenBytesShort: + length = int(data[0]) + case LenBytesLong: + // convert little-endian bytes to integer + length = int(binary.LittleEndian.Uint16(data[:2])) + } + + // ensure remaining data is long enough + if len(data) < c.lenBytes+length { + return 0, fmt.Errorf("expected %d bytes, got %d", length, len(data)-c.lenBytes) + } + + return length, nil +} + +// unpackAddress unpacks argument of type 'address'. +func (c *CodecCompact) unpackAddress(data []byte, output interface{}) (int, error) { + // type assertion + pAddress, ok := output.(*common.Address) + if !ok { + return 0, fmt.Errorf("argument is not of type *common.Address") + } + + // ensure remaining data >= 20 bytes + if len(data) < common.AddressLength { + return 0, fmt.Errorf("expected address, got %d bytes", len(data)) + } + *pAddress = common.BytesToAddress((data[:20])) + + return common.AddressLength, nil +} + +// unpackBytes unpacks argument of type 'bytes' and returns the number of bytes read. +func (c *CodecCompact) unpackBytes(data []byte, output interface{}) (int, error) { + // type assertion + pSlice, ok := output.(*[]byte) + if !ok { + return 0, fmt.Errorf("argument is not of type *[]byte") + } + + // unpack length + dataLen, err := c.unpackLength(data) + if err != nil { + return 0, errors.Wrap(err, "failed to unpack length of bytes") + } + + // make a copy of the data + if dataLen > 0 { + *pSlice = make([]byte, dataLen) + copy(*pSlice, data[c.lenBytes:c.lenBytes+dataLen]) + } + + return c.lenBytes + dataLen, nil +} + +// unpackString unpacks argument of type 'string' and returns the number of bytes read. +func (c *CodecCompact) unpackString(data []byte, output interface{}) (int, error) { + // type assertion + pString, ok := output.(*string) + if !ok { + return 0, fmt.Errorf("argument is not of type *string") + } + + // unpack length + strLen, err := c.unpackLength(data) + if err != nil { + return 0, errors.Wrap(err, "failed to unpack length of string") + } + + // make a copy of the string + *pString = string(data[c.lenBytes : c.lenBytes+strLen]) + + return c.lenBytes + strLen, nil +} diff --git a/pkg/memo/codec_compact_test.go b/pkg/memo/codec_compact_test.go new file mode 100644 index 0000000000..328beeb7e4 --- /dev/null +++ b/pkg/memo/codec_compact_test.go @@ -0,0 +1,413 @@ +package memo_test + +import ( + "bytes" + "encoding/binary" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" +) + +// CompactPack is a helper function to pack arguments into compact encoded data +// Note: all arguments are assumed to be <= 65535 bytes for simplicity. +func CompactPack(encodingFmt memo.EncodingFormat, args ...memo.CodecArg) []byte { + var ( + length int + packedData []byte + ) + + for _, arg := range args { + // get length of argument + switch arg.Type { + case memo.ArgTypeBytes: + length = len(arg.Arg.([]byte)) + case memo.ArgTypeString: + length = len([]byte(arg.Arg.(string))) + default: + // skip length for other types + length = -1 + } + + // append length in bytes + if length != -1 { + switch encodingFmt { + case memo.EncodingFmtCompactShort: + packedData = append(packedData, byte(length)) + case memo.EncodingFmtCompactLong: + buff := make([]byte, 2) + binary.LittleEndian.PutUint16(buff, uint16(length)) + packedData = append(packedData, buff...) + } + } + + // append actual data in bytes + switch arg.Type { + case memo.ArgTypeBytes: + packedData = append(packedData, arg.Arg.([]byte)...) + case memo.ArgTypeAddress: + packedData = append(packedData, arg.Arg.(common.Address).Bytes()...) + case memo.ArgTypeString: + packedData = append(packedData, []byte(arg.Arg.(string))...) + } + } + + return packedData +} + +func Test_NewCodecCompact(t *testing.T) { + tests := []struct { + name string + encodeFmt memo.EncodingFormat + fail bool + }{ + { + name: "create codec compact successfully", + encodeFmt: memo.EncodingFmtCompactShort, + }, + { + name: "create codec compact failed on invalid encoding format", + encodeFmt: 0b11, + fail: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + codec, err := memo.NewCodecCompact(tc.encodeFmt) + if tc.fail { + require.Error(t, err) + require.Nil(t, codec) + } else { + require.NoError(t, err) + require.NotNil(t, codec) + } + }) + } +} + +func Test_CodecCompact_AddArguments(t *testing.T) { + codec, err := memo.NewCodecCompact(memo.EncodingFmtCompactLong) + require.NoError(t, err) + require.NotNil(t, codec) + + address := common.HexToAddress("0x855EfD3C54F9Ed106C6c3FB343539c89Df042e0B") + codec.AddArguments(memo.ArgReceiver(address)) + + // attempt to pack the arguments, result should not be nil + packedData, err := codec.PackArguments() + require.NoError(t, err) + require.True(t, len(packedData) == common.AddressLength) +} + +func Test_CodecCompact_PackArguments(t *testing.T) { + // create sample arguments + argAddress := common.HexToAddress("0x855EfD3C54F9Ed106C6c3FB343539c89Df042e0B") + argBytes := []byte("here is a bytes argument") + argString := "some other string argument" + + // test cases + tests := []struct { + name string + encodeFmt memo.EncodingFormat + args []memo.CodecArg + expectedLen int + errMsg string + }{ + { + name: "pack arguments of [address, bytes, string] in compact-short format", + encodeFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + }, + expectedLen: 20 + 1 + len(argBytes) + 1 + len([]byte(argString)), + }, + { + name: "pack arguments of [string, address, bytes] in compact-long format", + encodeFmt: memo.EncodingFmtCompactLong, + args: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + expectedLen: 2 + len([]byte(argString)) + 20 + 2 + len(argBytes), + }, + { + name: "pack long string (> 255 bytes) with compact-long format", + encodeFmt: memo.EncodingFmtCompactLong, + args: []memo.CodecArg{ + memo.ArgPayload([]byte(strings.Repeat("a", 256))), + }, + expectedLen: 2 + 256, + }, + { + name: "pack long string (> 255 bytes) with compact-short format should fail", + encodeFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgPayload([]byte(strings.Repeat("b", 256))), + }, + errMsg: "exceeds 255 bytes", + }, + { + name: "pack long string (> 65535 bytes) with compact-long format should fail", + encodeFmt: memo.EncodingFmtCompactLong, + args: []memo.CodecArg{ + memo.ArgPayload([]byte(strings.Repeat("c", 65536))), + }, + errMsg: "exceeds 65535 bytes", + }, + { + name: "pack empty byte array and string arguments", + encodeFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + }, + expectedLen: 2, + }, + { + name: "failed to pack bytes argument if string is passed", + encodeFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgPayload(argString), // expect bytes type, but passed string + }, + errMsg: "argument is not of type []byte", + }, + { + name: "failed to pack address argument if bytes is passed", + encodeFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgReceiver(argBytes), // expect address type, but passed bytes + }, + errMsg: "argument is not of type common.Address", + }, + { + name: "failed to pack string argument if bytes is passed", + encodeFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.ArgRevertAddress(argBytes), // expect string type, but passed bytes + }, + errMsg: "argument is not of type string", + }, + { + name: "failed to pack unsupported argument type", + encodeFmt: memo.EncodingFmtCompactShort, + args: []memo.CodecArg{ + memo.NewArg("receiver", memo.ArgType("unknown"), nil), + }, + errMsg: "unsupported argument (receiver) type", + }, + } + + // loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // create a new compact codec and add arguments + codec, err := memo.NewCodecCompact(tc.encodeFmt) + require.NoError(t, err) + codec.AddArguments(tc.args...) + + // pack arguments + packedData, err := codec.PackArguments() + + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + require.Nil(t, packedData) + return + } + require.NoError(t, err) + require.Equal(t, tc.expectedLen, len(packedData)) + + // calc expected data for comparison + expectedData := CompactPack(tc.encodeFmt, tc.args...) + + // validate the packed data + require.True(t, bytes.Equal(expectedData, packedData), "compact encoded data mismatch") + }) + } +} + +func Test_CodecCompact_UnpackArguments(t *testing.T) { + // create sample arguments + argAddress := common.HexToAddress("0x855EfD3C54F9Ed106C6c3FB343539c89Df042e0B") + argBytes := []byte("some test bytes argument") + argString := "some other string argument" + + // test cases + tests := []struct { + name string + encodeFmt memo.EncodingFormat + data []byte + expected []memo.CodecArg + errMsg string + }{ + { + name: "unpack arguments of [address, bytes, string] in compact-short format", + encodeFmt: memo.EncodingFmtCompactShort, + data: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + ), + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + memo.ArgRevertAddress(argString), + }, + }, + { + name: "unpack arguments of [string, address, bytes] in compact-long format", + encodeFmt: memo.EncodingFmtCompactLong, + data: CompactPack( + memo.EncodingFmtCompactLong, + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + ), + expected: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + }, + { + name: "unpack empty byte array and string argument", + encodeFmt: memo.EncodingFmtCompactShort, + data: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + ), + expected: []memo.CodecArg{ + memo.ArgPayload([]byte{}), + memo.ArgRevertAddress(""), + }, + }, + { + name: "failed to unpack address if data length < 20 bytes", + encodeFmt: memo.EncodingFmtCompactShort, + data: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + }, + errMsg: "expected address, got 5 bytes", + }, + { + name: "failed to unpack string if data length < 1 byte", + encodeFmt: memo.EncodingFmtCompactShort, + data: []byte{}, + expected: []memo.CodecArg{ + memo.ArgRevertAddress(argString), + }, + errMsg: "expected 1 bytes to decode length", + }, + { + name: "failed to unpack string if actual data is less than decoded length", + encodeFmt: memo.EncodingFmtCompactShort, + data: []byte{0x05, 0x0a, 0x0b, 0x0c, 0x0d}, // length = 5, but only 4 bytes provided + expected: []memo.CodecArg{ + memo.ArgPayload(argBytes), + }, + errMsg: "expected 5 bytes, got 4", + }, + { + name: "failed to unpack bytes argument if string is passed", + encodeFmt: memo.EncodingFmtCompactShort, + data: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgPayload(argBytes), + ), + expected: []memo.CodecArg{ + memo.ArgPayload(argString), // expect bytes type, but passed string + }, + errMsg: "argument is not of type *[]byte", + }, + { + name: "failed to unpack address argument if bytes is passed", + encodeFmt: memo.EncodingFmtCompactShort, + data: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(argAddress), + ), + expected: []memo.CodecArg{ + memo.ArgReceiver(argBytes), // expect address type, but passed bytes + }, + errMsg: "argument is not of type *common.Address", + }, + { + name: "failed to unpack string argument if address is passed", + encodeFmt: memo.EncodingFmtCompactShort, + data: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgRevertAddress(argString), + ), + expected: []memo.CodecArg{ + memo.ArgRevertAddress(argAddress), // expect string type, but passed address + }, + errMsg: "argument is not of type *string", + }, + { + name: "failed to unpack unsupported argument type", + encodeFmt: memo.EncodingFmtCompactShort, + data: []byte{}, + expected: []memo.CodecArg{ + memo.NewArg("payload", memo.ArgType("unknown"), nil), + }, + errMsg: "unsupported argument (payload) type", + }, + { + name: "unpacking should fail if not all data is consumed", + encodeFmt: memo.EncodingFmtCompactShort, + data: func() []byte { + data := CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + ) + // append 1 extra byte + return append(data, 0x00) + }(), + expected: []memo.CodecArg{ + memo.ArgReceiver(argAddress), + memo.ArgPayload(argBytes), + }, + errMsg: "consumed bytes (45) != total bytes (46)", + }, + } + + // loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // create a new compact codec and add arguments + codec, err := memo.NewCodecCompact(tc.encodeFmt) + require.NoError(t, err) + + // add output arguments + output := make([]memo.CodecArg, len(tc.expected)) + for i, arg := range tc.expected { + output[i] = memo.NewArg(arg.Name, arg.Type, newArgInstance(arg.Arg)) + } + codec.AddArguments(output...) + + // unpack arguments from compact-encoded data + err = codec.UnpackArguments(tc.data) + + // validate error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + + // validate the unpacked arguments values + require.NoError(t, err) + for i, arg := range tc.expected { + ensureArgEquality(t, arg.Arg, output[i].Arg) + } + }) + } +} diff --git a/pkg/memo/codec_test.go b/pkg/memo/codec_test.go new file mode 100644 index 0000000000..aa27667967 --- /dev/null +++ b/pkg/memo/codec_test.go @@ -0,0 +1,92 @@ +package memo_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" +) + +func Test_GetLenBytes(t *testing.T) { + // Define table-driven test cases + tests := []struct { + name string + encodeFmt memo.EncodingFormat + expectedLen int + expectErr bool + }{ + { + name: "compact short", + encodeFmt: memo.EncodingFmtCompactShort, + expectedLen: 1, + }, + { + name: "compact long", + encodeFmt: memo.EncodingFmtCompactLong, + expectedLen: 2, + }, + { + name: "non-compact encoding format", + encodeFmt: memo.EncodingFmtABI, + expectedLen: 0, + expectErr: true, + }, + } + + // Loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + length, err := memo.GetLenBytes(tc.encodeFmt) + + // Check if error is expected + if tc.expectErr { + require.Error(t, err) + require.Equal(t, 0, length) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedLen, length) + } + }) + } +} + +func Test_GetCodec(t *testing.T) { + // Define table-driven test cases + tests := []struct { + name string + encodeFmt memo.EncodingFormat + errMsg string + }{ + { + name: "should get ABI codec", + encodeFmt: memo.EncodingFmtABI, + }, + { + name: "should get compact codec", + encodeFmt: memo.EncodingFmtCompactShort, + }, + { + name: "should get compact codec", + encodeFmt: memo.EncodingFmtCompactLong, + }, + { + name: "should fail to get codec", + encodeFmt: 0b0011, + errMsg: "invalid encoding format", + }, + } + + // Loop through each test case + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + codec, err := memo.GetCodec(tc.encodeFmt) + if tc.errMsg != "" { + require.Error(t, err) + require.Nil(t, codec) + } else { + require.NoError(t, err) + require.NotNil(t, codec) + } + }) + } +} diff --git a/pkg/memo/fields.go b/pkg/memo/fields.go new file mode 100644 index 0000000000..fff853f955 --- /dev/null +++ b/pkg/memo/fields.go @@ -0,0 +1,16 @@ +package memo + +// Fields is the interface for memo fields +type Fields interface { + // Pack encodes the memo fields + Pack(opCode OpCode, encodingFmt EncodingFormat, dataFlags uint8) ([]byte, error) + + // Unpack decodes the memo fields + Unpack(opCode OpCode, encodingFmt EncodingFormat, dataFlags uint8, data []byte) error + + // Validate checks if the fields are valid + Validate(opCode OpCode, dataFlags uint8) error + + // DataFlags build the data flags for the fields + DataFlags() uint8 +} diff --git a/pkg/memo/fields_v0.go b/pkg/memo/fields_v0.go new file mode 100644 index 0000000000..a8f79d99ba --- /dev/null +++ b/pkg/memo/fields_v0.go @@ -0,0 +1,208 @@ +package memo + +import ( + "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" + + "github.com/zeta-chain/node/pkg/crypto" + zetabits "github.com/zeta-chain/node/pkg/math/bits" + crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" +) + +// Enum of the bit position of each memo fields +const ( + bitPosReceiver uint8 = 0 // receiver + bitPosPayload uint8 = 1 // payload + bitPosRevertAddress uint8 = 2 // revertAddress + bitPosAbortAddress uint8 = 3 // abortAddress + bitPosRevertMessage uint8 = 4 // revertMessage +) + +const ( + // MaskFlagsReserved is the mask for reserved data flags + MaskFlagsReserved = 0b11100000 +) + +var _ Fields = (*FieldsV0)(nil) + +// FieldsV0 contains the data fields of the inbound memo V0 +type FieldsV0 struct { + // Receiver is the ZEVM receiver address + Receiver common.Address + + // Payload is the calldata passed to ZEVM contract call + Payload []byte + + // RevertOptions is the options for cctx revert handling + RevertOptions crosschaintypes.RevertOptions +} + +// Pack encodes the memo fields +func (f *FieldsV0) Pack(opCode OpCode, encodingFmt EncodingFormat, dataFlags uint8) ([]byte, error) { + // validate fields + err := f.Validate(opCode, dataFlags) + if err != nil { + return nil, err + } + + codec, err := GetCodec(encodingFmt) + if err != nil { + return nil, errors.Wrap(err, "unable to get codec") + } + + return f.packFields(codec, dataFlags) +} + +// Unpack decodes the memo fields +func (f *FieldsV0) Unpack(opCode OpCode, encodingFmt EncodingFormat, dataFlags uint8, data []byte) error { + codec, err := GetCodec(encodingFmt) + if err != nil { + return errors.Wrap(err, "unable to get codec") + } + + err = f.unpackFields(codec, dataFlags, data) + if err != nil { + return err + } + + return f.Validate(opCode, dataFlags) +} + +// Validate checks if the fields are valid +func (f *FieldsV0) Validate(opCode OpCode, dataFlags uint8) error { + // receiver address must be a valid address + if zetabits.IsBitSet(dataFlags, bitPosReceiver) && crypto.IsEmptyAddress(f.Receiver) { + return errors.New("receiver address is empty") + } + + // payload is not allowed for deposit operation + if opCode == OpCodeDeposit && len(f.Payload) > 0 { + return errors.New("payload is not allowed for deposit operation") + } + + // abort address must be a valid address + if zetabits.IsBitSet(dataFlags, bitPosAbortAddress) && !common.IsHexAddress(f.RevertOptions.AbortAddress) { + return errors.New("invalid abort address") + } + + // revert message is not allowed when CallOnRevert is false + // 1. it's a good-to-have check to make the fields semantically correct. + // 2. unpacking won't hit this error as the codec will catch it earlier. + if !f.RevertOptions.CallOnRevert && len(f.RevertOptions.RevertMessage) > 0 { + return errors.New("revert message is not allowed when CallOnRevert is false") + } + + return nil +} + +// DataFlags build the data flags from actual fields +func (f *FieldsV0) DataFlags() uint8 { + var dataFlags uint8 + + // set 'receiver' flag if provided + if !crypto.IsEmptyAddress(f.Receiver) { + zetabits.SetBit(&dataFlags, bitPosReceiver) + } + + // set 'payload' flag if provided + if len(f.Payload) > 0 { + zetabits.SetBit(&dataFlags, bitPosPayload) + } + + // set 'revertAddress' flag if provided + if f.RevertOptions.RevertAddress != "" { + zetabits.SetBit(&dataFlags, bitPosRevertAddress) + } + + // set 'abortAddress' flag if provided + if f.RevertOptions.AbortAddress != "" { + zetabits.SetBit(&dataFlags, bitPosAbortAddress) + } + + // set 'revertMessage' flag if provided + if f.RevertOptions.CallOnRevert { + zetabits.SetBit(&dataFlags, bitPosRevertMessage) + } + + return dataFlags +} + +// packFieldsV0 packs the memo fields for version 0 +func (f *FieldsV0) packFields(codec Codec, dataFlags uint8) ([]byte, error) { + // add 'receiver' argument optionally + if zetabits.IsBitSet(dataFlags, bitPosReceiver) { + codec.AddArguments(ArgReceiver(f.Receiver)) + } + + // add 'payload' argument optionally + if zetabits.IsBitSet(dataFlags, bitPosPayload) { + codec.AddArguments(ArgPayload(f.Payload)) + } + + // add 'revertAddress' argument optionally + if zetabits.IsBitSet(dataFlags, bitPosRevertAddress) { + codec.AddArguments(ArgRevertAddress(f.RevertOptions.RevertAddress)) + } + + // add 'abortAddress' argument optionally + abortAddress := common.HexToAddress(f.RevertOptions.AbortAddress) + if zetabits.IsBitSet(dataFlags, bitPosAbortAddress) { + codec.AddArguments(ArgAbortAddress(abortAddress)) + } + + // add 'revertMessage' argument optionally + if f.RevertOptions.CallOnRevert { + codec.AddArguments(ArgRevertMessage(f.RevertOptions.RevertMessage)) + } + + // pack the codec arguments into data + data, err := codec.PackArguments() + if err != nil { // never happens + return nil, errors.Wrap(err, "failed to pack arguments") + } + + return data, nil +} + +// unpackFields unpacks the memo fields for version 0 +func (f *FieldsV0) unpackFields(codec Codec, dataFlags byte, data []byte) error { + // add 'receiver' argument optionally + if zetabits.IsBitSet(dataFlags, bitPosReceiver) { + codec.AddArguments(ArgReceiver(&f.Receiver)) + } + + // add 'payload' argument optionally + if zetabits.IsBitSet(dataFlags, bitPosPayload) { + codec.AddArguments(ArgPayload(&f.Payload)) + } + + // add 'revertAddress' argument optionally + if zetabits.IsBitSet(dataFlags, bitPosRevertAddress) { + codec.AddArguments(ArgRevertAddress(&f.RevertOptions.RevertAddress)) + } + + // add 'abortAddress' argument optionally + var abortAddress common.Address + if zetabits.IsBitSet(dataFlags, bitPosAbortAddress) { + codec.AddArguments(ArgAbortAddress(&abortAddress)) + } + + // add 'revertMessage' argument optionally + f.RevertOptions.CallOnRevert = zetabits.IsBitSet(dataFlags, bitPosRevertMessage) + if f.RevertOptions.CallOnRevert { + codec.AddArguments(ArgRevertMessage(&f.RevertOptions.RevertMessage)) + } + + // unpack the data (after flags) into codec arguments + err := codec.UnpackArguments(data) + if err != nil { + return errors.Wrap(err, "failed to unpack arguments") + } + + // convert abort address to string + if !crypto.IsEmptyAddress(abortAddress) { + f.RevertOptions.AbortAddress = abortAddress.Hex() + } + + return nil +} diff --git a/pkg/memo/fields_v0_test.go b/pkg/memo/fields_v0_test.go new file mode 100644 index 0000000000..13742422c2 --- /dev/null +++ b/pkg/memo/fields_v0_test.go @@ -0,0 +1,388 @@ +package memo_test + +import ( + "bytes" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" + "github.com/zeta-chain/node/testutil/sample" + crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" +) + +const ( + // flagsAllFieldsSet sets all fields: [receiver, payload, revert address, abort address, CallOnRevert] + flagsAllFieldsSet = 0b00011111 +) + +func Test_V0_Pack(t *testing.T) { + // create sample fields + fAddress := sample.EthAddress() + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + opCode memo.OpCode + encodeFmt memo.EncodingFormat + dataFlags uint8 + fields memo.FieldsV0 + expectedFlags byte + expectedData []byte + errMsg string + }{ + { + name: "pack all fields with ABI encoding", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: flagsAllFieldsSet, // all fields are set + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + expectedData: ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + }, + { + name: "pack all fields with compact encoding", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtCompactShort, + dataFlags: flagsAllFieldsSet, // all fields are set + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + expectedData: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + }, + { + name: "fields validation failed due to empty receiver address", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: 0b00000001, // receiver flag is set + fields: memo.FieldsV0{ + Receiver: common.Address{}, + }, + errMsg: "receiver address is empty", + }, + { + name: "unable to get codec on invalid encoding format", + opCode: memo.OpCodeDepositAndCall, + fields: memo.FieldsV0{ + Receiver: fAddress, + }, + encodeFmt: 0x0F, + errMsg: "unable to get codec", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // pack the fields + data, err := tc.fields.Pack(tc.opCode, tc.encodeFmt, tc.dataFlags) + + // validate the error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + require.Nil(t, data) + return + } + + // compare the fields + require.NoError(t, err) + require.True(t, bytes.Equal(tc.expectedData, data)) + }) + } +} + +func Test_V0_Unpack(t *testing.T) { + // create sample fields + fAddress := common.HexToAddress("0xA029D053E13223E2442E28be80b3CeDA27ecbE31") + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + opCode memo.OpCode + encodeFmt memo.EncodingFormat + dataFlags byte + data []byte + expected memo.FieldsV0 + errMsg string + }{ + { + name: "unpack all fields with ABI encoding", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: flagsAllFieldsSet, // all fields are set + data: ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + expected: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), + RevertMessage: fBytes, + }, + }, + }, + { + name: "unpack all fields with compact encoding", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtCompactShort, + dataFlags: flagsAllFieldsSet, // all fields are set + data: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + expected: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), + RevertMessage: fBytes, + }, + }, + }, + { + name: "unpack empty ABI encoded payload if flag is set", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: 0b00000010, // payload flags are set + data: ABIPack(t, + memo.ArgPayload([]byte{})), // empty payload + expected: memo.FieldsV0{}, + }, + { + name: "unpack empty compact encoded payload if flag is set", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtCompactShort, + dataFlags: 0b00000010, // payload flag is set + data: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgPayload([]byte{})), // empty payload + expected: memo.FieldsV0{}, + }, + { + name: "unable to get codec on invalid encoding format", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: 0x0F, + dataFlags: 0b00000001, + data: []byte{}, + errMsg: "unable to get codec", + }, + { + name: "failed to unpack ABI encoded data with compact encoding format", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtCompactShort, + dataFlags: 0b00000011, // receiver and payload flags are set + data: ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes)), + errMsg: "failed to unpack arguments", + }, + { + name: "fields validation failed due to empty receiver address", + opCode: memo.OpCodeDepositAndCall, + encodeFmt: memo.EncodingFmtABI, + dataFlags: 0b00000011, // receiver and payload flags are set + data: ABIPack(t, + memo.ArgReceiver(common.Address{}), + memo.ArgPayload(fBytes)), + errMsg: "receiver address is empty", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // unpack the fields + fields := memo.FieldsV0{} + err := fields.Unpack(tc.opCode, tc.encodeFmt, tc.dataFlags, tc.data) + + // validate the error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + + // compare the fields + require.NoError(t, err) + require.Equal(t, tc.expected, fields) + }) + } +} + +func Test_V0_Validate(t *testing.T) { + // create sample fields + fAddress := common.HexToAddress("0xA029D053E13223E2442E28be80b3CeDA27ecbE31") + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + opCode memo.OpCode + dataFlags uint8 + fields memo.FieldsV0 + errMsg string + }{ + { + name: "valid fields", + opCode: memo.OpCodeDepositAndCall, + dataFlags: flagsAllFieldsSet, // all fields are set + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), + RevertMessage: fBytes, + }, + }, + }, + { + name: "invalid receiver address", + opCode: memo.OpCodeCall, + dataFlags: 0b00000001, // receiver flag is set + fields: memo.FieldsV0{ + Receiver: common.Address{}, // provide empty receiver address + }, + errMsg: "receiver address is empty", + }, + { + name: "payload is not allowed when opCode is deposit", + opCode: memo.OpCodeDeposit, + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, // payload is mistakenly set + }, + errMsg: "payload is not allowed for deposit operation", + }, + { + name: "abort address is invalid", + opCode: memo.OpCodeDeposit, + dataFlags: 0b00001000, // abort address flag is set + fields: memo.FieldsV0{ + RevertOptions: crosschaintypes.RevertOptions{ + AbortAddress: "invalid abort address", + }, + }, + errMsg: "invalid abort address", + }, + { + name: "revert message is not allowed when CallOnRevert is false", + opCode: memo.OpCodeDeposit, + fields: memo.FieldsV0{ + Receiver: fAddress, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: false, // CallOnRevert is false + RevertMessage: []byte("revert message"), // revert message is mistakenly set + }, + }, + errMsg: "revert message is not allowed when CallOnRevert is false", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // validate the fields + err := tc.fields.Validate(tc.opCode, tc.dataFlags) + + // validate the error message + if tc.errMsg != "" { + require.ErrorContains(t, err, tc.errMsg) + return + } + require.NoError(t, err) + }) + } +} + +func Test_V0_DataFlags(t *testing.T) { + // create sample fields + fAddress := common.HexToAddress("0xA029D053E13223E2442E28be80b3CeDA27ecbE31") + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + fields memo.FieldsV0 + expectedFlags uint8 + }{ + { + name: "all fields set", + fields: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), + RevertMessage: fBytes, + }, + }, + expectedFlags: flagsAllFieldsSet, + }, + { + name: "no fields set", + fields: memo.FieldsV0{}, + expectedFlags: 0b00000000, + }, + { + name: "a few fields set", + fields: memo.FieldsV0{ + Receiver: fAddress, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + RevertMessage: fBytes, + }, + }, + expectedFlags: 0b00010101, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // get the data flags + flags := tc.fields.DataFlags() + + // compare the flags + require.Equal(t, tc.expectedFlags, flags) + }) + } +} diff --git a/pkg/memo/header.go b/pkg/memo/header.go new file mode 100644 index 0000000000..5f9bcd1e14 --- /dev/null +++ b/pkg/memo/header.go @@ -0,0 +1,139 @@ +package memo + +import ( + "fmt" + + "github.com/pkg/errors" + + zetabits "github.com/zeta-chain/node/pkg/math/bits" +) + +type OpCode uint8 + +const ( + // Identifier is the ASCII code of 'Z' (0x5A) + Identifier byte = 0x5A + + // HeaderSize is the size of the memo header: [identifier + ctrlByte1+ ctrlByte2 + dataFlags] + HeaderSize = 4 + + // maskVersion is the mask for the version bits(upper 4 bits) + maskVersion byte = 0b11110000 + + // maskEncodingFormat is the mask for the encoding format bits(lower 4 bits) + maskEncodingFormat byte = 0b00001111 + + // maskOpCode is the mask for the operation code bits(upper 4 bits) + maskOpCode byte = 0b11110000 + + // maskCtrlReserved is the mask for reserved control bits (lower 4 bits) + maskCtrlReserved byte = 0b00001111 +) + +// Enum for non-EVM chain inbound operation code (4 bits) +const ( + OpCodeDeposit OpCode = 0b0000 // operation 'deposit' + OpCodeDepositAndCall OpCode = 0b0001 // operation 'deposit_and_call' + OpCodeCall OpCode = 0b0010 // operation 'call' + OpCodeInvalid OpCode = 0b0011 // invalid operation code +) + +// Header represent the memo header +type Header struct { + // Version is the memo Version + Version uint8 + + // EncodingFmt is the memo encoding format + EncodingFmt EncodingFormat + + // OpCode is the inbound operation code + OpCode OpCode + + // Reserved is the reserved control bits + Reserved uint8 + + // DataFlags is the data flags + DataFlags uint8 +} + +// EncodeToBytes encodes the memo header to raw bytes +func (h *Header) EncodeToBytes() ([]byte, error) { + // validate header + if err := h.Validate(); err != nil { + return nil, err + } + + // create buffer for the header + data := make([]byte, HeaderSize) + + // set byte-0 as memo identifier + data[0] = Identifier + + // set version #, encoding format + var ctrlByte1 byte + ctrlByte1 = zetabits.SetBits(ctrlByte1, maskVersion, h.Version) + ctrlByte1 = zetabits.SetBits(ctrlByte1, maskEncodingFormat, byte(h.EncodingFmt)) + data[1] = ctrlByte1 + + // set operation code, reserved bits + var ctrlByte2 byte + ctrlByte2 = zetabits.SetBits(ctrlByte2, maskOpCode, byte(h.OpCode)) + ctrlByte2 = zetabits.SetBits(ctrlByte2, maskCtrlReserved, h.Reserved) + data[2] = ctrlByte2 + + // set data flags + data[3] = h.DataFlags + + return data, nil +} + +// DecodeFromBytes decodes the memo header from the given data +func (h *Header) DecodeFromBytes(data []byte) error { + // memo data must be longer than the header size + if len(data) < HeaderSize { + return errors.New("memo is too short") + } + + // byte-0 is the memo identifier + if data[0] != Identifier { + return fmt.Errorf("invalid memo identifier: %d", data[0]) + } + + // extract version #, encoding format + ctrlByte1 := data[1] + h.Version = zetabits.GetBits(ctrlByte1, maskVersion) + h.EncodingFmt = EncodingFormat(zetabits.GetBits(ctrlByte1, maskEncodingFormat)) + + // extract operation code, reserved bits + ctrlByte2 := data[2] + h.OpCode = OpCode(zetabits.GetBits(ctrlByte2, maskOpCode)) + h.Reserved = zetabits.GetBits(ctrlByte2, maskCtrlReserved) + + // extract data flags + h.DataFlags = data[3] + + // validate header + return h.Validate() +} + +// Validate checks if the memo header is valid +func (h *Header) Validate() error { + if h.Version != 0 { + return fmt.Errorf("invalid memo version: %d", h.Version) + } + + if h.EncodingFmt >= EncodingFmtInvalid { + return fmt.Errorf("invalid encoding format: %d", h.EncodingFmt) + } + + if h.OpCode >= OpCodeInvalid { + return fmt.Errorf("invalid operation code: %d", h.OpCode) + } + + // reserved control bits must be zero + if h.Reserved != 0 { + return fmt.Errorf("reserved control bits are not zero: %d", h.Reserved) + } + + return nil +} diff --git a/pkg/memo/header_test.go b/pkg/memo/header_test.go new file mode 100644 index 0000000000..9d89e1f1d6 --- /dev/null +++ b/pkg/memo/header_test.go @@ -0,0 +1,167 @@ +package memo_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" +) + +// MakeHead is a helper function to create a memo head +// Note: all arguments are assumed to be <= 0b1111 for simplicity. +func MakeHead(version, encodingFmt, opCode, reserved, flags uint8) []byte { + head := make([]byte, memo.HeaderSize) + head[0] = memo.Identifier + head[1] = version<<4 | encodingFmt + head[2] = opCode<<4 | reserved + head[3] = flags + return head +} + +func Test_Header_EncodeToBytes(t *testing.T) { + tests := []struct { + name string + header memo.Header + expected []byte + errMsg string + }{ + { + name: "it works", + header: memo.Header{ + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeCall, + DataFlags: 0b00011111, + }, + expected: []byte{memo.Identifier, 0b00000000, 0b00100000, 0b00011111}, + }, + { + name: "header validation failed", + header: memo.Header{ + Version: 1, // invalid version + }, + errMsg: "invalid memo version", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header, err := tt.header.EncodeToBytes() + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + require.Nil(t, header) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, header) + }) + } +} + +func Test_Header_DecodeFromBytes(t *testing.T) { + tests := []struct { + name string + data []byte + expected memo.Header + errMsg string + }{ + { + name: "it works", + data: append( + MakeHead(0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeCall), 0, 0), + []byte{0x01, 0x02}...), + expected: memo.Header{ + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeCall, + Reserved: 0, + }, + }, + { + name: "memo is too short", + data: []byte{0x01, 0x02, 0x03}, + errMsg: "memo is too short", + }, + { + name: "invalid memo identifier", + data: []byte{'M', 0x02, 0x03, 0x04, 0x05}, + errMsg: "invalid memo identifier", + }, + { + name: "header validation failed", + data: append( + MakeHead(0, uint8(memo.EncodingFmtInvalid), uint8(memo.OpCodeCall), 0, 0), + []byte{0x01, 0x02}...), // invalid encoding format + errMsg: "invalid encoding format", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + memo := &memo.Header{} + err := memo.DecodeFromBytes(tt.data) + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, *memo) + }) + } +} + +func Test_Header_Validate(t *testing.T) { + tests := []struct { + name string + header memo.Header + errMsg string + }{ + { + name: "valid header", + header: memo.Header{ + Version: 0, + EncodingFmt: memo.EncodingFmtCompactShort, + OpCode: memo.OpCodeDepositAndCall, + }, + }, + { + name: "invalid version", + header: memo.Header{ + Version: 1, + }, + errMsg: "invalid memo version", + }, + { + name: "invalid encoding format", + header: memo.Header{ + EncodingFmt: memo.EncodingFmtInvalid, + }, + errMsg: "invalid encoding format", + }, + { + name: "invalid operation code", + header: memo.Header{ + OpCode: memo.OpCodeInvalid, + }, + errMsg: "invalid operation code", + }, + { + name: "reserved field is not zero", + header: memo.Header{ + Reserved: 1, + }, + errMsg: "reserved control bits are not zero", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.header.Validate() + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + }) + } +} diff --git a/pkg/memo/memo.go b/pkg/memo/memo.go new file mode 100644 index 0000000000..30670952b0 --- /dev/null +++ b/pkg/memo/memo.go @@ -0,0 +1,98 @@ +package memo + +import ( + "encoding/hex" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/pkg/errors" +) + +// InboundMemo represents the inbound memo structure for non-EVM chains +type InboundMemo struct { + // Header contains the memo header + Header + + // FieldsV0 contains the memo fields V0 + // Note: add a FieldsV1 if breaking change is needed in the future + FieldsV0 +} + +// EncodeToBytes encodes a InboundMemo struct to raw bytes +// +// Note: +// - Any provided 'DataFlags' is ignored as they are calculated based on the fields set in the memo. +// - The 'RevertGasLimit' is not used for now for non-EVM chains. +func (m *InboundMemo) EncodeToBytes() ([]byte, error) { + // build fields flags + dataFlags := m.FieldsV0.DataFlags() + m.Header.DataFlags = dataFlags + + // encode head + head, err := m.Header.EncodeToBytes() + if err != nil { + return nil, errors.Wrap(err, "failed to encode memo header") + } + + // encode fields based on version + var data []byte + switch m.Version { + case 0: + data, err = m.FieldsV0.Pack(m.OpCode, m.EncodingFmt, dataFlags) + default: + return nil, fmt.Errorf("invalid memo version: %d", m.Version) + } + if err != nil { + return nil, errors.Wrapf(err, "failed to pack memo fields version: %d", m.Version) + } + + return append(head, data...), nil +} + +// DecodeFromBytes decodes a InboundMemo struct from raw bytes +// +// Returns an error if given data is not a valid memo +func DecodeFromBytes(data []byte) (*InboundMemo, error) { + memo := &InboundMemo{} + + // decode header + err := memo.Header.DecodeFromBytes(data) + if err != nil { + return nil, errors.Wrap(err, "failed to decode memo header") + } + + // decode fields based on version + switch memo.Version { + case 0: + err = memo.FieldsV0.Unpack(memo.OpCode, memo.EncodingFmt, memo.Header.DataFlags, data[HeaderSize:]) + default: + return nil, fmt.Errorf("invalid memo version: %d", memo.Version) + } + if err != nil { + return nil, errors.Wrapf(err, "failed to unpack memo fields version: %d", memo.Version) + } + + return memo, nil +} + +// DecodeLegacyMemoHex decodes hex encoded memo message into address and calldata +// +// The layout of legacy memo is: [20-byte address, variable calldata] +func DecodeLegacyMemoHex(message string) (common.Address, []byte, error) { + if len(message) == 0 { + return common.Address{}, nil, nil + } + + data, err := hex.DecodeString(message) + if err != nil { + return common.Address{}, nil, errors.Wrap(err, "message should be a hex encoded string") + } + + if len(data) < common.AddressLength { + return common.Address{}, data, nil + } + + address := common.BytesToAddress(data[:common.AddressLength]) + data = data[common.AddressLength:] + return address, data, nil +} diff --git a/pkg/memo/memo_test.go b/pkg/memo/memo_test.go new file mode 100644 index 0000000000..e6cb067793 --- /dev/null +++ b/pkg/memo/memo_test.go @@ -0,0 +1,306 @@ +package memo_test + +import ( + "encoding/hex" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" + "github.com/zeta-chain/node/pkg/memo" + crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" +) + +func Test_Memo_EncodeToBytes(t *testing.T) { + // create sample fields + fAddress := common.HexToAddress("0xEA9808f0Ac504d1F521B5BbdfC33e6f1953757a7") + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + memo *memo.InboundMemo + expectedHead []byte + expectedData []byte + errMsg string + }{ + { + name: "encode memo with ABI encoding", + memo: &memo.InboundMemo{ + Header: memo.Header{ + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeDepositAndCall, + }, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + }, + expectedHead: MakeHead( + 0, + uint8(memo.EncodingFmtABI), + uint8(memo.OpCodeDepositAndCall), + 0, + flagsAllFieldsSet, // all fields are set + ), + expectedData: ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + }, + { + name: "encode memo with compact encoding", + memo: &memo.InboundMemo{ + Header: memo.Header{ + Version: 0, + EncodingFmt: memo.EncodingFmtCompactShort, + OpCode: memo.OpCodeDepositAndCall, + }, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + }, + expectedHead: MakeHead( + 0, + uint8(memo.EncodingFmtCompactShort), + uint8(memo.OpCodeDepositAndCall), + 0, + flagsAllFieldsSet, // all fields are set + ), + expectedData: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + }, + { + name: "failed to encode memo header", + memo: &memo.InboundMemo{ + Header: memo.Header{ + OpCode: memo.OpCodeInvalid, // invalid operation code + }, + }, + errMsg: "failed to encode memo header", + }, + { + name: "failed to encode if version is invalid", + memo: &memo.InboundMemo{ + Header: memo.Header{ + Version: 1, + }, + }, + errMsg: "invalid memo version", + }, + { + name: "failed to pack memo fields", + memo: &memo.InboundMemo{ + Header: memo.Header{ + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeDeposit, + }, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, // payload is not allowed for deposit + }, + }, + errMsg: "failed to pack memo fields version: 0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := tt.memo.EncodeToBytes() + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + require.Nil(t, data) + return + } + require.NoError(t, err) + require.Equal(t, append(tt.expectedHead, tt.expectedData...), data) + + // decode the memo and compare with the original + decodedMemo, err := memo.DecodeFromBytes(data) + require.NoError(t, err) + require.Equal(t, tt.memo, decodedMemo) + }) + } +} + +func Test_Memo_DecodeFromBytes(t *testing.T) { + // create sample fields + fAddress := common.HexToAddress("0xEA9808f0Ac504d1F521B5BbdfC33e6f1953757a7") + fBytes := []byte("here_s_some_bytes_field") + fString := "this_is_a_string_field" + + tests := []struct { + name string + head []byte + data []byte + expectedMemo memo.InboundMemo + errMsg string + }{ + { + name: "decode memo with ABI encoding", + head: MakeHead( + 0, + uint8(memo.EncodingFmtABI), + uint8(memo.OpCodeDepositAndCall), + 0, + flagsAllFieldsSet, // all fields are set + ), + data: ABIPack(t, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + expectedMemo: memo.InboundMemo{ + Header: memo.Header{ + Version: 0, + EncodingFmt: memo.EncodingFmtABI, + OpCode: memo.OpCodeDepositAndCall, + DataFlags: 0b00011111, + }, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + }, + }, + { + name: "decode memo with compact encoding", + head: MakeHead( + 0, + uint8(memo.EncodingFmtCompactLong), + uint8(memo.OpCodeDepositAndCall), + 0, + flagsAllFieldsSet, // all fields are set + ), + data: CompactPack( + memo.EncodingFmtCompactLong, + memo.ArgReceiver(fAddress), + memo.ArgPayload(fBytes), + memo.ArgRevertAddress(fString), + memo.ArgAbortAddress(fAddress), + memo.ArgRevertMessage(fBytes)), + expectedMemo: memo.InboundMemo{ + Header: memo.Header{ + Version: 0, + EncodingFmt: memo.EncodingFmtCompactLong, + OpCode: memo.OpCodeDepositAndCall, + DataFlags: 0b00011111, + }, + FieldsV0: memo.FieldsV0{ + Receiver: fAddress, + Payload: fBytes, + RevertOptions: crosschaintypes.RevertOptions{ + RevertAddress: fString, + CallOnRevert: true, + AbortAddress: fAddress.String(), // it's a ZEVM address + RevertMessage: fBytes, + }, + }, + }, + }, + { + name: "failed to decode memo header", + head: MakeHead(0, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeInvalid), 0, 0), + data: ABIPack(t, memo.ArgReceiver(fAddress)), + errMsg: "failed to decode memo header", + }, + { + name: "failed to decode if version is invalid", + head: MakeHead(1, uint8(memo.EncodingFmtABI), uint8(memo.OpCodeDeposit), 0, 0), + data: ABIPack(t, memo.ArgReceiver(fAddress)), + errMsg: "invalid memo version", + }, + { + name: "failed to decode compact encoded data with ABI encoding format", + head: MakeHead( + 0, + uint8(memo.EncodingFmtABI), + uint8(memo.OpCodeDepositAndCall), + 0, + 0, + ), // header says ABI encoding + data: CompactPack( + memo.EncodingFmtCompactShort, + memo.ArgReceiver(fAddress), + ), // but data is compact encoded + errMsg: "failed to unpack memo fields", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data := append(tt.head, tt.data...) + memo, err := memo.DecodeFromBytes(data) + if tt.errMsg != "" { + require.ErrorContains(t, err, tt.errMsg) + return + } + require.NoError(t, err) + require.Equal(t, tt.expectedMemo, *memo) + }) + } +} + +func Test_DecodeLegacyMemoHex(t *testing.T) { + expectedShortMsgResult, err := hex.DecodeString("1a2b3c4d5e6f708192a3b4c5d6e7f808") + require.NoError(t, err) + tests := []struct { + name string + message string + expectAddr common.Address + expectData []byte + wantErr bool + }{ + { + "valid msg", + "95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe5", + common.HexToAddress("95222290DD7278Aa3Ddd389Cc1E1d165CC4BAfe5"), + []byte{}, + false, + }, + {"empty msg", "", common.Address{}, nil, false}, + {"invalid hex", "invalidHex", common.Address{}, nil, true}, + {"short msg", "1a2b3c4d5e6f708192a3b4c5d6e7f808", common.Address{}, expectedShortMsgResult, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, data, err := memo.DecodeLegacyMemoHex(tt.message) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectAddr, addr) + require.Equal(t, tt.expectData, data) + } + }) + } +} diff --git a/x/crosschain/keeper/evm_deposit.go b/x/crosschain/keeper/evm_deposit.go index c672686c1a..be08873bba 100644 --- a/x/crosschain/keeper/evm_deposit.go +++ b/x/crosschain/keeper/evm_deposit.go @@ -14,6 +14,7 @@ import ( "github.com/zeta-chain/node/pkg/chains" "github.com/zeta-chain/node/pkg/coin" + "github.com/zeta-chain/node/pkg/memo" "github.com/zeta-chain/node/x/crosschain/types" fungibletypes "github.com/zeta-chain/node/x/fungible/types" ) @@ -78,7 +79,7 @@ func (k Keeper) HandleEVMDeposit(ctx sdk.Context, cctx *types.CrossChainTx) (boo // in protocol version 2, the destination of the deposit is always the to address, the message is the data to be sent to the contract if cctx.ProtocolContractVersion == types.ProtocolContractVersion_V1 { var parsedAddress ethcommon.Address - parsedAddress, message, err = chains.ParseAddressAndData(cctx.RelayedMessage) + parsedAddress, message, err = memo.DecodeLegacyMemoHex(cctx.RelayedMessage) if err != nil { return false, errors.Wrap(types.ErrUnableToParseAddress, err.Error()) } diff --git a/zetaclient/chains/bitcoin/observer/inbound.go b/zetaclient/chains/bitcoin/observer/inbound.go index fad3c87230..1461096763 100644 --- a/zetaclient/chains/bitcoin/observer/inbound.go +++ b/zetaclient/chains/bitcoin/observer/inbound.go @@ -14,8 +14,8 @@ import ( "github.com/pkg/errors" "github.com/rs/zerolog" - "github.com/zeta-chain/node/pkg/chains" "github.com/zeta-chain/node/pkg/coin" + "github.com/zeta-chain/node/pkg/memo" crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" "github.com/zeta-chain/node/zetaclient/chains/bitcoin" "github.com/zeta-chain/node/zetaclient/chains/interfaces" @@ -419,7 +419,7 @@ func (ob *Observer) GetInboundVoteMessageFromBtcEvent(inbound *BTCInboundEvent) // TODO(revamp): move all compliance related functions in a specific file func (ob *Observer) DoesInboundContainsRestrictedAddress(inTx *BTCInboundEvent) bool { receiver := "" - parsedAddress, _, err := chains.ParseAddressAndData(hex.EncodeToString(inTx.MemoBytes)) + parsedAddress, _, err := memo.DecodeLegacyMemoHex(hex.EncodeToString(inTx.MemoBytes)) if err == nil && parsedAddress != (ethcommon.Address{}) { receiver = parsedAddress.Hex() } diff --git a/zetaclient/chains/evm/observer/inbound.go b/zetaclient/chains/evm/observer/inbound.go index c662fc2d60..aebd661ae3 100644 --- a/zetaclient/chains/evm/observer/inbound.go +++ b/zetaclient/chains/evm/observer/inbound.go @@ -20,9 +20,9 @@ import ( "github.com/zeta-chain/protocol-contracts/v1/pkg/contracts/evm/erc20custody.sol" "github.com/zeta-chain/protocol-contracts/v1/pkg/contracts/evm/zetaconnector.non-eth.sol" - "github.com/zeta-chain/node/pkg/chains" "github.com/zeta-chain/node/pkg/coin" "github.com/zeta-chain/node/pkg/constant" + "github.com/zeta-chain/node/pkg/memo" "github.com/zeta-chain/node/pkg/ticker" "github.com/zeta-chain/node/x/crosschain/types" "github.com/zeta-chain/node/zetaclient/chains/evm" @@ -622,7 +622,7 @@ func (ob *Observer) BuildInboundVoteMsgForDepositedEvent( ) *types.MsgVoteInbound { // compliance check maybeReceiver := "" - parsedAddress, _, err := chains.ParseAddressAndData(hex.EncodeToString(event.Message)) + parsedAddress, _, err := memo.DecodeLegacyMemoHex(hex.EncodeToString(event.Message)) if err == nil && parsedAddress != (ethcommon.Address{}) { maybeReceiver = parsedAddress.Hex() } @@ -732,7 +732,7 @@ func (ob *Observer) BuildInboundVoteMsgForTokenSentToTSS( // compliance check maybeReceiver := "" - parsedAddress, _, err := chains.ParseAddressAndData(message) + parsedAddress, _, err := memo.DecodeLegacyMemoHex(message) if err == nil && parsedAddress != (ethcommon.Address{}) { maybeReceiver = parsedAddress.Hex() } diff --git a/zetaclient/compliance/compliance.go b/zetaclient/compliance/compliance.go index 163cca65e4..f0135c3ad9 100644 --- a/zetaclient/compliance/compliance.go +++ b/zetaclient/compliance/compliance.go @@ -7,7 +7,7 @@ import ( ethcommon "github.com/ethereum/go-ethereum/common" "github.com/rs/zerolog" - "github.com/zeta-chain/node/pkg/chains" + "github.com/zeta-chain/node/pkg/memo" crosschaintypes "github.com/zeta-chain/node/x/crosschain/types" "github.com/zeta-chain/node/zetaclient/chains/base" "github.com/zeta-chain/node/zetaclient/config" @@ -66,7 +66,7 @@ func PrintComplianceLog( func DoesInboundContainsRestrictedAddress(event *clienttypes.InboundEvent, logger *base.ObserverLogger) bool { // parse memo-specified receiver receiver := "" - parsedAddress, _, err := chains.ParseAddressAndData(hex.EncodeToString(event.Memo)) + parsedAddress, _, err := memo.DecodeLegacyMemoHex(hex.EncodeToString(event.Memo)) if err == nil && parsedAddress != (ethcommon.Address{}) { receiver = parsedAddress.Hex() }