diff --git a/core/services/keystore/keys/tronkey/account.go b/core/services/keystore/keys/tronkey/account.go index 3a93ba82e9f..9c90422d2a7 100644 --- a/core/services/keystore/keys/tronkey/account.go +++ b/core/services/keystore/keys/tronkey/account.go @@ -23,12 +23,12 @@ const ( prefixMainnet = 0x41 // TronBytePrefix is the hex prefix to address TronBytePrefix = byte(prefixMainnet) - // Tron address should should have 20 bytes + 4 checksum + 1 Prefix - AddressLength = 20 + // Tron address should have 21 bytes (20 bytes + 1 byte prefix) + AddressLength = 21 ) // Address represents the 21 byte address of an Tron account. -type Address []byte +type Address [AddressLength]byte // Bytes get bytes from address func (a Address) Bytes() []byte { @@ -41,22 +41,25 @@ func (a Address) Hex() string { } // HexToAddress returns Address with byte values of s. -// If s is larger than len(h), s will be cropped from the left. -func HexToAddress(s string) Address { +func HexToAddress(s string) (Address, error) { addr, err := FromHex(s) if err != nil { - return nil + return Address{}, err } - return addr + // Check if the address starts with '41' and is 21 characters long + if len(addr) != AddressLength || addr[0] != prefixMainnet { + return Address{}, errors.New("invalid Tron address") + } + return Address(addr), nil } // Base58ToAddress returns Address with byte values of s. func Base58ToAddress(s string) (Address, error) { addr, err := DecodeCheck(s) if err != nil { - return nil, err + return Address{}, err } - return addr, nil + return Address(addr), nil } // String implements fmt.Stringer. @@ -79,7 +82,7 @@ func PubkeyToAddress(p ecdsa.PublicKey) Address { addressTron := make([]byte, 0) addressTron = append(addressTron, TronBytePrefix) addressTron = append(addressTron, address.Bytes()...) - return addressTron + return Address(addressTron) } // BytesToHexString encodes bytes as a hex string. @@ -140,8 +143,8 @@ func DecodeCheck(input string) ([]byte, error) { return nil, errors.New("base58 check error") } - // tron address should should have 20 bytes + 4 checksum + 1 Prefix - if len(decodeCheck) != AddressLength+4+1 { + // tron address should should have 21 bytes (including prefix) + 4 checksum + if len(decodeCheck) != AddressLength+4 { return nil, fmt.Errorf("invalid address length: %d", len(decodeCheck)) } diff --git a/core/services/keystore/keys/tronkey/account_test.go b/core/services/keystore/keys/tronkey/account_test.go index 08ac567859d..6047830a717 100644 --- a/core/services/keystore/keys/tronkey/account_test.go +++ b/core/services/keystore/keys/tronkey/account_test.go @@ -103,6 +103,40 @@ func TestAddress(t *testing.T) { }) } +func TestHexToAddress(t *testing.T) { + t.Run("Valid Hex Addresses", func(t *testing.T) { + validHexAddresses := []string{ + "41a614f803b6fd780986a42c78ec9c7f77e6ded13c", + "41b2a2e1b2e1b2e1b2e1b2e1b2e1b2e1b2e1b2e1b2", + "41c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", + } + + for _, hexStr := range validHexAddresses { + t.Run(hexStr, func(t *testing.T) { + addr, err := HexToAddress(hexStr) + require.Nil(t, err) + require.Equal(t, "0x"+hexStr, addr.Hex()) + }) + } + }) + + t.Run("Invalid Hex Addresses", func(t *testing.T) { + invalidHexAddresses := []string{ + "41a614f803b6fd780986a42c78ec9c7f77e6ded13", // Too short + "41b2a2e1b2e1b2e1b2e1b2e1b2e1b2e1b2e1b2e1b2e1b2", // Too long + "41g3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", // Invalid character 'g' + "c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", // Missing prefix '41' + } + + for _, hexStr := range invalidHexAddresses { + t.Run(hexStr, func(t *testing.T) { + _, err := HexToAddress(hexStr) + require.NotNil(t, err) + }) + } + }) +} + // Helper Functions for testing // isValid checks if the address is a valid TRON address