From 833445e9e2062ac96e4c0d703a3ff8ce0189a436 Mon Sep 17 00:00:00 2001 From: Michael FIG Date: Wed, 9 Oct 2024 14:38:05 -0600 Subject: [PATCH] feat(vtransfer): extract base account from parameterized account --- .../cosmos/x/vtransfer/ibc_middleware_test.go | 7 +- golang/cosmos/x/vtransfer/keeper/keeper.go | 84 +++++++++++++++---- golang/cosmos/x/vtransfer/types/baseaddr.go | 33 ++++++++ .../cosmos/x/vtransfer/types/baseaddr_test.go | 73 ++++++++++++++++ 4 files changed, 177 insertions(+), 20 deletions(-) create mode 100644 golang/cosmos/x/vtransfer/types/baseaddr.go create mode 100644 golang/cosmos/x/vtransfer/types/baseaddr_test.go diff --git a/golang/cosmos/x/vtransfer/ibc_middleware_test.go b/golang/cosmos/x/vtransfer/ibc_middleware_test.go index ab82e14e09db..23279bf8a636 100644 --- a/golang/cosmos/x/vtransfer/ibc_middleware_test.go +++ b/golang/cosmos/x/vtransfer/ibc_middleware_test.go @@ -331,18 +331,19 @@ func (s *IntegrationTestSuite) TestTransferFromAgdToAgd() { s.Run("TransferFromAgdToAgd", func() { // create a transfer packet's data contents + baseReceiver := s.chainB.SenderAccounts[1].SenderAccount.GetAddress().String() transferData := ibctransfertypes.NewFungibleTokenPacketData( "uosmo", "1000000", s.chainA.SenderAccount.GetAddress().String(), - s.chainB.SenderAccounts[1].SenderAccount.GetAddress().String(), + baseReceiver+"?what=arbitrary-data&why=to-test-bridge-targets", `"This is a JSON memo"`, ) // Register the sender and receiver as bridge targets on their specific // chain. s.RegisterBridgeTarget(s.chainA, transferData.Sender) - s.RegisterBridgeTarget(s.chainB, transferData.Receiver) + s.RegisterBridgeTarget(s.chainB, baseReceiver) s.mintToAddress(s.chainA, s.chainA.SenderAccount.GetAddress(), transferData.Denom, transferData.Amount) @@ -384,7 +385,7 @@ func (s *IntegrationTestSuite) TestTransferFromAgdToAgd() { BlockTime: writeAcknowledgementTime, }, Event: "writeAcknowledgement", - Target: transferData.Receiver, + Target: baseReceiver, Packet: packet, Acknowledgement: ack.Acknowledgement(), }, diff --git a/golang/cosmos/x/vtransfer/keeper/keeper.go b/golang/cosmos/x/vtransfer/keeper/keeper.go index 36701a544e7d..f6cc101efe4d 100644 --- a/golang/cosmos/x/vtransfer/keeper/keeper.go +++ b/golang/cosmos/x/vtransfer/keeper/keeper.go @@ -16,7 +16,9 @@ import ( "github.com/Agoric/agoric-sdk/golang/cosmos/vm" "github.com/Agoric/agoric-sdk/golang/cosmos/x/vibc" vibctypes "github.com/Agoric/agoric-sdk/golang/cosmos/x/vibc/types" + "github.com/Agoric/agoric-sdk/golang/cosmos/x/vtransfer/types" transfertypes "github.com/cosmos/ibc-go/v6/modules/apps/transfer/types" + clienttypes "github.com/cosmos/ibc-go/v6/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v6/modules/core/04-channel/types" porttypes "github.com/cosmos/ibc-go/v6/modules/core/05-port/types" host "github.com/cosmos/ibc-go/v6/modules/core/24-host" @@ -34,6 +36,7 @@ var _ vm.PortHandler = (*Keeper)(nil) const ( watchedAddressStoreKeyPrefix = "watchedAddress/" watchedAddressSentinel = "y" + supplementalDataSeparator = "+" ) // Keeper handles the interceptions from the vtransfer IBC middleware, passing @@ -101,12 +104,48 @@ func (k Keeper) GetReceiverImpl() vibctypes.ReceiverImpl { return k } +// Extract the base address from the packet sender (if senderIsLocal) or +// receiver (if !senderIsLocal), since the local ibcModule doesn't understand +// address parameters. +func (k Keeper) packetWithOnlyBaseAddresses(packet channeltypes.Packet, senderIsLocal bool) channeltypes.Packet { + transferData := transfertypes.FungibleTokenPacketData{} + if err := k.cdc.UnmarshalJSON(packet.GetData(), &transferData); err != nil { + return packet + } + if senderIsLocal { + baseSender, err := types.ExtractBaseAddress(transferData.Sender) + if err == nil { + transferData.Sender = baseSender + } + } else { + baseReceiver, err := types.ExtractBaseAddress(transferData.Receiver) + if err == nil { + transferData.Receiver = baseReceiver + } + } + data, _ := k.cdc.MarshalJSON(&transferData) + height := packet.GetTimeoutHeight() + newPacket := channeltypes.NewPacket( + data, + packet.GetSequence(), + packet.GetSourcePort(), + packet.GetSourceChannel(), + packet.GetDestPort(), + packet.GetDestChannel(), + clienttypes.NewHeight(height.GetRevisionNumber(), height.GetRevisionHeight()), + packet.GetTimeoutTimestamp(), + ) + return newPacket +} + // InterceptOnRecvPacket runs the ibcModule and eventually acknowledges a packet. // Many error acknowledgments are sent synchronously, but most cases instead return nil // to tell the IBC system that acknowledgment is async (i.e., that WriteAcknowledgement // will be called later, after the VM has dealt with the packet). func (k Keeper) InterceptOnRecvPacket(ctx sdk.Context, ibcModule porttypes.IBCModule, packet channeltypes.Packet, relayer sdk.AccAddress) ibcexported.Acknowledgement { - ack := ibcModule.OnRecvPacket(ctx, packet, relayer) + // Pass every (stripped-receiver) inbound to the wrapped IBC module. + strippedPacket := k.packetWithOnlyBaseAddresses(packet, false) + ack := ibcModule.OnRecvPacket(ctx, strippedPacket, relayer) if ack == nil { // Already declared to be an async ack. @@ -136,11 +175,12 @@ func (k Keeper) InterceptOnAcknowledgementPacket( acknowledgement []byte, relayer sdk.AccAddress, ) error { - // Pass every acknowledgement to the wrapped IBC module. - modErr := ibcModule.OnAcknowledgementPacket(ctx, packet, acknowledgement, relayer) + // Pass every (stripped-sender) acknowledgement to the wrapped IBC module. + strippedPacket := k.packetWithOnlyBaseAddresses(packet, true) + modErr := ibcModule.OnAcknowledgementPacket(ctx, strippedPacket, acknowledgement, relayer) // If the sender is not a targeted account, we're done. - sender, _, err := k.parseTransfer(ctx, packet) + sender, _, err := k.findTransferTargets(ctx, packet) if err != nil || sender == "" { return modErr } @@ -163,11 +203,12 @@ func (k Keeper) InterceptOnTimeoutPacket( packet channeltypes.Packet, relayer sdk.AccAddress, ) error { - // Pass every timeout to the wrapped IBC module. - modErr := ibcModule.OnTimeoutPacket(ctx, packet, relayer) + // Pass every (stripped-sender) timeout to the wrapped IBC module. + strippedPacket := k.packetWithOnlyBaseAddresses(packet, true) + modErr := ibcModule.OnTimeoutPacket(ctx, strippedPacket, relayer) // If the sender is not a targeted account, we're done. - sender, _, err := k.parseTransfer(ctx, packet) + sender, _, err := k.findTransferTargets(ctx, packet) if err != nil || sender == "" { return modErr } @@ -185,7 +226,7 @@ func (k Keeper) InterceptOnTimeoutPacket( // InterceptWriteAcknowledgement checks to see if the packet's receiver is a // targeted account, and if so, delegates to the VM. func (k Keeper) InterceptWriteAcknowledgement(ctx sdk.Context, chanCap *capabilitytypes.Capability, packet ibcexported.PacketI, ack ibcexported.Acknowledgement) error { - _, receiver, err := k.parseTransfer(ctx, packet) + _, receiver, err := k.findTransferTargets(ctx, packet) if err != nil || receiver == "" { // We can't parse, but that means just to ack directly. return k.WriteAcknowledgement(ctx, chanCap, packet, ack) @@ -200,27 +241,36 @@ func (k Keeper) InterceptWriteAcknowledgement(ctx sdk.Context, chanCap *capabili return nil } -// parseTransfer checks if a packet's sender and/or receiver are targeted accounts. -func (k Keeper) parseTransfer(ctx sdk.Context, packet ibcexported.PacketI) (string, string, error) { +// findTransferTargets checks if a packet's sender and/or receiver correspond to targeted accounts. +func (k Keeper) findTransferTargets(ctx sdk.Context, packet ibcexported.PacketI) (string, string, error) { var transferData transfertypes.FungibleTokenPacketData err := k.cdc.UnmarshalJSON(packet.GetData(), &transferData) if err != nil { return "", "", err } - var sender string - var receiver string + // Extract the base addresses from the transferData. + senderTarget, err := types.ExtractBaseAddress(transferData.Sender) + if err != nil { + senderTarget = transferData.Sender + } + receiverTarget, err := types.ExtractBaseAddress(transferData.Receiver) + if err != nil { + receiverTarget = transferData.Receiver + } prefixStore := prefix.NewStore( ctx.KVStore(k.key), []byte(watchedAddressStoreKeyPrefix), ) - if prefixStore.Has([]byte(transferData.Sender)) { - sender = transferData.Sender + if !prefixStore.Has([]byte(senderTarget)) { + // Not a targeted sender. + senderTarget = "" } - if prefixStore.Has([]byte(transferData.Receiver)) { - receiver = transferData.Receiver + if !prefixStore.Has([]byte(receiverTarget)) { + // Not a targeted receiver. + receiverTarget = "" } - return sender, receiver, nil + return senderTarget, receiverTarget, nil } // GetWatchedAdresses returns the watched addresses from the keeper as a slice diff --git a/golang/cosmos/x/vtransfer/types/baseaddr.go b/golang/cosmos/x/vtransfer/types/baseaddr.go new file mode 100644 index 000000000000..a1516643d6f9 --- /dev/null +++ b/golang/cosmos/x/vtransfer/types/baseaddr.go @@ -0,0 +1,33 @@ +package types + +import ( + "fmt" + "net/url" + "strings" +) + +// ExtractBaseAddress extracts the base address from a parameterized address. +func ExtractBaseAddress(addr string) (string, error) { + rawParsed, err := url.Parse(addr) + if err != nil { + return "", err + } + + parsed := url.URL{ + Path: strings.TrimPrefix(rawParsed.Path, "/"), + RawPath: rawParsed.RawPath, + RawQuery: rawParsed.RawQuery, + Fragment: rawParsed.Fragment, + RawFragment: rawParsed.RawFragment, + } + if *rawParsed != parsed { + return "", fmt.Errorf("address must be relative path with optional query and fragment, got %s", addr) + } + + baseAddr, _, _ := strings.Cut(parsed.Path, "/") + if baseAddr == "" { + return "", fmt.Errorf("base address cannot be empty") + } + + return baseAddr, nil +} diff --git a/golang/cosmos/x/vtransfer/types/baseaddr_test.go b/golang/cosmos/x/vtransfer/types/baseaddr_test.go new file mode 100644 index 000000000000..8767c62afa9a --- /dev/null +++ b/golang/cosmos/x/vtransfer/types/baseaddr_test.go @@ -0,0 +1,73 @@ +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/Agoric/agoric-sdk/golang/cosmos/x/vtransfer/types" +) + +func TestExtractBaseAddress(t *testing.T) { + bases := []struct { + name string + addr string + }{ + {"agoric address", "agoric1abcdefghiteaneas"}, + {"cosmos address", "cosmos1abcdeffiharceuht"}, + {"hex address", "0xabcdef198189818c93839ibia"}, + } + + prefixes := []struct { + prefix string + baseIsWrong bool + isErr bool + }{ + {"", false, false}, + {"/", false, true}, + {"orch:/", false, true}, + {"unexpected", true, false}, + {"norch:/", false, true}, + {"orch:", false, true}, + {"norch:", false, true}, + {"\x01", false, true}, + } + + suffixes := []struct { + suffix string + baseIsWrong bool + isErr bool + }{ + {"", false, false}, + {"/", false, false}, + {"/sub/account", false, false}, + {"?query=something&k=v&k2=v2", false, false}, + {"?query=something&k=v&k2=v2#fragment", false, false}, + {"unexpected", true, false}, + {"\x01", false, true}, + } + + for _, b := range bases { + b := b + for _, p := range prefixes { + p := p + for _, s := range suffixes { + s := s + t.Run(b.name+" "+p.prefix+" "+s.suffix, func(t *testing.T) { + addr := p.prefix + b.addr + s.suffix + addr, err := types.ExtractBaseAddress(addr) + if p.isErr || s.isErr { + require.Error(t, err) + } else { + require.NoError(t, err) + if p.baseIsWrong || s.baseIsWrong { + require.NotEqual(t, b.addr, addr) + } else { + require.Equal(t, b.addr, addr) + } + } + }) + } + } + } +}