diff --git a/Makefile b/Makefile index f4055e9f0a..e26233681a 100644 --- a/Makefile +++ b/Makefile @@ -120,7 +120,7 @@ test-short: ## test-race: Run unit tests in race mode. test-race: @echo "--> Running tests in race mode" - @go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestQGBRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestQGBCLI|TestUpgrade|TestMaliciousTestNode|TestMaxTotalBlobSizeSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext|TestConcurrentTxSubmission" + @go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestQGBRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestQGBCLI|TestUpgrade|TestMaliciousTestNode|TestMaxTotalBlobSizeSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext|TestConcurrentTxSubmission|TestTxClientTestSuite" .PHONY: test-race ## test-bench: Run unit tests in bench mode. diff --git a/pkg/user/account.go b/pkg/user/account.go new file mode 100644 index 0000000000..8c52d195a0 --- /dev/null +++ b/pkg/user/account.go @@ -0,0 +1,79 @@ +package user + +import ( + "context" + + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" + "github.com/cosmos/cosmos-sdk/types" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + "google.golang.org/grpc" +) + +type Account struct { + name string + address types.AccAddress + pubKey cryptotypes.PubKey + accountNumber uint64 + + // the signers local view of the sequence number + sequence uint64 +} + +func NewAccount(keyName string, accountNumber, sequenceNumber uint64) *Account { + return &Account{ + name: keyName, + accountNumber: accountNumber, + sequence: sequenceNumber, + } +} + +func (a Account) Name() string { + return a.name +} + +func (a Account) Address() types.AccAddress { + return a.address +} + +func (a Account) PubKey() cryptotypes.PubKey { + return a.pubKey +} + +// Sequence returns the sequence number of the account. +// This is locally tracked +func (a Account) Sequence() uint64 { + return a.sequence +} + +func (a *Account) Copy() *Account { + return &Account{ + name: a.name, + address: a.address, + pubKey: a.pubKey, + accountNumber: a.accountNumber, + sequence: a.sequence, + } +} + +// QueryAccountInfo fetches the account number and sequence number from the celestia-app node. +func QueryAccountInfo(ctx context.Context, conn *grpc.ClientConn, registry codectypes.InterfaceRegistry, address types.AccAddress) (accNum uint64, seqNum uint64, err error) { + qclient := authtypes.NewQueryClient(conn) + // TODO: ideally we add a way to prove that the accounts rather than simply trusting the full node we are connected with + resp, err := qclient.Account( + ctx, + &authtypes.QueryAccountRequest{Address: address.String()}, + ) + if err != nil { + return accNum, seqNum, err + } + + var acc authtypes.AccountI + err = registry.UnpackAny(resp.Account, &acc) + if err != nil { + return accNum, seqNum, err + } + + accNum, seqNum = acc.GetAccountNumber(), acc.GetSequence() + return accNum, seqNum, nil +} diff --git a/pkg/user/legacy_signer.go b/pkg/user/legacy_signer.go new file mode 100644 index 0000000000..618344dc5d --- /dev/null +++ b/pkg/user/legacy_signer.go @@ -0,0 +1,615 @@ +package user + +import ( + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/celestiaorg/celestia-app/app/encoding" + apperrors "github.com/celestiaorg/celestia-app/app/errors" + blob "github.com/celestiaorg/celestia-app/x/blob/types" + "github.com/cosmos/cosmos-sdk/client" + "github.com/cosmos/cosmos-sdk/client/grpc/tmservice" + "github.com/cosmos/cosmos-sdk/crypto/keyring" + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" + sdktypes "github.com/cosmos/cosmos-sdk/types" + sdktx "github.com/cosmos/cosmos-sdk/types/tx" + "github.com/cosmos/cosmos-sdk/types/tx/signing" + authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + abci "github.com/tendermint/tendermint/abci/types" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + tmtypes "github.com/tendermint/tendermint/types" + "google.golang.org/grpc" +) + +// Signer is an abstraction for building, signing, and broadcasting Celestia transactions +type Signer struct { + keys keyring.Keyring + address sdktypes.AccAddress + enc client.TxConfig + grpc *grpc.ClientConn + pk cryptotypes.PubKey + chainID string + accountNumber uint64 + + mtx sync.RWMutex + // how often to poll the network for confirmation of a transaction + pollTime time.Duration + // the signers local view of the sequence number + localSequence uint64 + // the chains last known sequence number + networkSequence uint64 + // lookup map of all pending and yet to be confirmed outbound transactions + outboundSequences map[uint64]struct{} + // a reverse map for confirming which sequence numbers have been committed + reverseTxHashSequenceMap map[string]uint64 +} + +// NewSigner returns a new signer using the provided keyring +func NewSigner( + keys keyring.Keyring, + conn *grpc.ClientConn, + address sdktypes.AccAddress, + enc client.TxConfig, + chainID string, + accountNumber uint64, + sequence uint64, +) (*Signer, error) { + // check that the address exists + record, err := keys.KeyByAddress(address) + if err != nil { + return nil, err + } + + pk, err := record.GetPubKey() + if err != nil { + return nil, err + } + + return &Signer{ + keys: keys, + address: address, + grpc: conn, + enc: enc, + pk: pk, + chainID: chainID, + accountNumber: accountNumber, + localSequence: sequence, + networkSequence: sequence, + pollTime: DefaultPollTime, + outboundSequences: make(map[uint64]struct{}), + reverseTxHashSequenceMap: make(map[string]uint64), + }, nil +} + +// SetupSingleSigner sets up a signer based on the provided keyring. The keyring +// must contain exactly one key. It extracts the address from the key and uses +// the grpc connection to populate the chainID, account number, and sequence +// number. +func SetupSingleSigner(ctx context.Context, keys keyring.Keyring, conn *grpc.ClientConn, encCfg encoding.Config) (*Signer, error) { + records, err := keys.List() + if err != nil { + return nil, err + } + + if len(records) != 1 { + return nil, errors.New("keyring must contain exactly one key") + } + + address, err := records[0].GetAddress() + if err != nil { + return nil, err + } + + return SetupSigner(ctx, keys, conn, address, encCfg) +} + +// SetupSigner uses the underlying grpc connection to populate the chainID, accountNumber and sequence number of the +// account. +func SetupSigner( + ctx context.Context, + keys keyring.Keyring, + conn *grpc.ClientConn, + address sdktypes.AccAddress, + encCfg encoding.Config, +) (*Signer, error) { + resp, err := tmservice.NewServiceClient(conn).GetLatestBlock(ctx, &tmservice.GetLatestBlockRequest{}) + if err != nil { + return nil, err + } + + chainID := resp.SdkBlock.Header.ChainID + accNum, seqNum, err := QueryAccount(ctx, conn, encCfg, address.String()) + if err != nil { + return nil, err + } + + return NewSigner(keys, conn, address, encCfg.TxConfig, chainID, accNum, seqNum) +} + +// SubmitTx forms a transaction from the provided messages, signs it, and submits it to the chain. TxOptions +// may be provided to set the fee and gas limit. +func (s *Signer) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) { + tx, err := s.CreateTx(msgs, opts...) + if err != nil { + return nil, err + } + + resp, err := s.BroadcastTx(ctx, tx) + if err != nil { + return resp, err + } + + return s.ConfirmTx(ctx, resp.TxHash) +} + +// SubmitPayForBlob forms a transaction from the provided blobs, signs it, and submits it to the chain. +// TxOptions may be provided to set the fee and gas limit. +func (s *Signer) SubmitPayForBlob(ctx context.Context, blobs []*tmproto.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + resp, err := s.broadcastPayForBlob(ctx, blobs, opts...) + if err != nil { + return resp, err + } + + return s.ConfirmTx(ctx, resp.TxHash) +} + +func (s *Signer) broadcastPayForBlob(ctx context.Context, blobs []*blob.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + txBytes, seqNum, err := s.createPayForBlobs(blobs, opts...) + if err != nil { + return nil, err + } + + return s.broadcastTx(ctx, txBytes, seqNum) +} + +// CreateTx forms a transaction from the provided messages and signs it. TxOptions may be optionally +// used to set the gas limit and fee. +func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + + return s.createTx(msgs, opts...) +} + +func (s *Signer) createTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { + txBuilder := s.txBuilder(opts...) + if err := txBuilder.SetMsgs(msgs...); err != nil { + return nil, err + } + + if err := s.signTransaction(txBuilder, s.getAndIncrementSequence()); err != nil { + return nil, err + } + + return txBuilder.GetTx(), nil +} + +func (s *Signer) CreatePayForBlob(blobs []*tmproto.Blob, opts ...TxOption) ([]byte, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + blobTx, _, err := s.createPayForBlobs(blobs, opts...) + return blobTx, err +} + +func (s *Signer) createPayForBlobs(blobs []*tmproto.Blob, opts ...TxOption) ([]byte, uint64, error) { + msg, err := blob.NewMsgPayForBlobs(s.address.String(), blobs...) + if err != nil { + return nil, 0, err + } + + tx, err := s.createTx([]sdktypes.Msg{msg}, opts...) + if err != nil { + return nil, 0, err + } + + seqNum, err := getSequenceNumber(tx) + if err != nil { + panic(err) + } + + txBytes, err := s.EncodeTx(tx) + if err != nil { + return nil, 0, err + } + + blobTx, err := tmtypes.MarshalBlobTx(txBytes, blobs...) + return blobTx, seqNum, err +} + +func (s *Signer) EncodeTx(tx sdktypes.Tx) ([]byte, error) { + return s.enc.TxEncoder()(tx) +} + +func (s *Signer) DecodeTx(txBytes []byte) (authsigning.Tx, error) { + tx, err := s.enc.TxDecoder()(txBytes) + if err != nil { + return nil, err + } + authTx, ok := tx.(authsigning.Tx) + if !ok { + return nil, errors.New("not an authsigning transaction") + } + return authTx, nil +} + +// BroadcastTx submits the provided transaction bytes to the chain and returns the response. +func (s *Signer) BroadcastTx(ctx context.Context, tx authsigning.Tx) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + txBytes, err := s.EncodeTx(tx) + if err != nil { + return nil, err + } + sequence, err := getSequenceNumber(tx) + if err != nil { + return nil, err + } + return s.broadcastTx(ctx, txBytes, sequence) +} + +// CONTRACT: assumes the caller has the lock +func (s *Signer) broadcastTx(ctx context.Context, txBytes []byte, sequence uint64) (*sdktypes.TxResponse, error) { + if _, exists := s.outboundSequences[sequence]; exists { + return s.retryBroadcastingTx(ctx, txBytes, sequence+1) + } + + if sequence < s.networkSequence { + s.localSequence = s.networkSequence + return s.retryBroadcastingTx(ctx, txBytes, s.localSequence) + } + + txClient := sdktx.NewServiceClient(s.grpc) + resp, err := txClient.BroadcastTx( + ctx, + &sdktx.BroadcastTxRequest{ + Mode: sdktx.BroadcastMode_BROADCAST_MODE_SYNC, + TxBytes: txBytes, + }, + ) + if err != nil { + return nil, err + } + if apperrors.IsNonceMismatchCode(resp.TxResponse.Code) { + // extract what the lastCommittedNonce on chain is + nextSequence, err := apperrors.ParseExpectedSequence(resp.TxResponse.RawLog) + if err != nil { + return nil, fmt.Errorf("parsing nonce mismatch upon retry: %w", err) + } + s.networkSequence = nextSequence + s.localSequence = nextSequence + // FIXME: We can't actually resign the transaction. A malicious node + // may manipulate us into signing the same transaction several times + // and then executing them. We need some proof of what the last network + // sequence is rather than relying on an error provided by the node + // return s.retryBroadcastingTx(ctx, txBytes, nextSequence) + // Ref: https://github.com/celestiaorg/celestia-app/issues/3256 + // return s.retryBroadcastingTx(ctx, txBytes, nextSequence) + } else if resp.TxResponse.Code == abci.CodeTypeOK { + s.outboundSequences[sequence] = struct{}{} + s.reverseTxHashSequenceMap[resp.TxResponse.TxHash] = sequence + return resp.TxResponse, nil + } + return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) +} + +// retryBroadcastingTx creates a new transaction by copying over an existing transaction but creates a new signature with the +// new sequence number. It then calls `broadcastTx` and attempts to submit the transaction +func (s *Signer) retryBroadcastingTx(ctx context.Context, txBytes []byte, newSequenceNumber uint64) (*sdktypes.TxResponse, error) { + blobTx, isBlobTx := tmtypes.UnmarshalBlobTx(txBytes) + if isBlobTx { + txBytes = blobTx.Tx + } + tx, err := s.DecodeTx(txBytes) + if err != nil { + return nil, err + } + txBuilder := s.txBuilder() + if err := txBuilder.SetMsgs(tx.GetMsgs()...); err != nil { + return nil, err + } + if granter := tx.FeeGranter(); granter != nil { + txBuilder.SetFeeGranter(granter) + } + if payer := tx.FeePayer(); payer != nil { + txBuilder.SetFeePayer(payer) + } + if memo := tx.GetMemo(); memo != "" { + txBuilder.SetMemo(memo) + } + if fee := tx.GetFee(); fee != nil { + txBuilder.SetFeeAmount(fee) + } + if gas := tx.GetGas(); gas > 0 { + txBuilder.SetGasLimit(gas) + } + + if err := s.signTransaction(txBuilder, newSequenceNumber); err != nil { + return nil, fmt.Errorf("resigning transaction: %w", err) + } + + newTxBytes, err := s.EncodeTx(txBuilder.GetTx()) + if err != nil { + return nil, err + } + + // rewrap the blob tx if it was originally a blob tx + if isBlobTx { + newTxBytes, err = tmtypes.MarshalBlobTx(newTxBytes, blobTx.Blobs...) + if err != nil { + return nil, err + } + } + + return s.broadcastTx(ctx, newTxBytes, newSequenceNumber) +} + +// ConfirmTx periodically pings the provided node for the commitment of a transaction by its +// hash. It will continually loop until the context is cancelled, the tx is found or an error +// is encountered. +func (s *Signer) ConfirmTx(ctx context.Context, txHash string) (*sdktypes.TxResponse, error) { + txClient := sdktx.NewServiceClient(s.grpc) + + pollTime := s.getPollTime() + timer := time.NewTimer(0) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return &sdktypes.TxResponse{}, ctx.Err() + case <-timer.C: + resp, err := txClient.GetTx(ctx, &sdktx.GetTxRequest{Hash: txHash}) + if err == nil { + if resp.TxResponse.Code != 0 { + s.updateNetworkSequence(txHash, false) + return resp.TxResponse, fmt.Errorf("tx was included but failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) + } + s.updateNetworkSequence(txHash, true) + return resp.TxResponse, nil + } + // FIXME: this is a relatively brittle of working out whether to retry or not. The tx might be not found for other + // reasons. It may have been removed from the mempool at a later point. We should build an endpoint that gives the + // signer more information on the status of their transaction and then update the logic here + if !strings.Contains(err.Error(), "not found") { + return &sdktypes.TxResponse{}, err + } + + timer.Reset(pollTime) + } + } +} + +func (s *Signer) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (uint64, error) { + txBuilder := s.txBuilder(opts...) + if err := txBuilder.SetMsgs(msgs...); err != nil { + return 0, err + } + + if err := s.signTransaction(txBuilder, s.LocalSequence()); err != nil { + return 0, err + } + + txBytes, err := s.enc.TxEncoder()(txBuilder.GetTx()) + if err != nil { + return 0, err + } + + resp, err := sdktx.NewServiceClient(s.grpc).Simulate(ctx, &sdktx.SimulateRequest{ + TxBytes: txBytes, + }) + if err != nil { + return 0, err + } + + return resp.GasInfo.GasUsed, nil +} + +// ChainID returns the chain ID of the signer. +func (s *Signer) ChainID() string { + return s.chainID +} + +// AccountNumber returns the account number of the signer. +func (s *Signer) AccountNumber() uint64 { + return s.accountNumber +} + +// Address returns the address of the signer. +func (s *Signer) Address() sdktypes.AccAddress { + return s.address +} + +// SetPollTime sets how often the signer should poll for the confirmation of the transaction +func (s *Signer) SetPollTime(pollTime time.Duration) { + s.mtx.Lock() + defer s.mtx.Unlock() + s.pollTime = pollTime +} + +func (s *Signer) getPollTime() time.Duration { + s.mtx.Lock() + defer s.mtx.Unlock() + return s.pollTime +} + +// PubKey returns the public key of the signer +func (s *Signer) PubKey() cryptotypes.PubKey { + return s.pk +} + +// DEPRECATED: use Sequence instead +func (s *Signer) GetSequence() uint64 { + return s.getAndIncrementSequence() +} + +// LocalSequence returns the next sequence number of the signers +// locally saved +func (s *Signer) LocalSequence() uint64 { + s.mtx.RLock() + defer s.mtx.RUnlock() + return s.localSequence +} + +func (s *Signer) NetworkSequence() uint64 { + s.mtx.RLock() + defer s.mtx.RUnlock() + return s.networkSequence +} + +// getAndIncrementSequence gets the latest signed sequence and increments the +// local sequence number +func (s *Signer) getAndIncrementSequence() uint64 { + defer func() { s.localSequence++ }() + return s.localSequence +} + +// ForceSetSequence manually overrides the current local and network level +// sequence number. Be careful when invoking this as it may cause the +// transactions to reject the sequence if it doesn't match the one in state +func (s *Signer) ForceSetSequence(seq uint64) { + s.mtx.Lock() + defer s.mtx.Unlock() + s.localSequence = seq + s.networkSequence = seq +} + +// updateNetworkSequence is called once a transaction is confirmed +// and updates the chains last known sequence number +func (s *Signer) updateNetworkSequence(txHash string, success bool) { + s.mtx.Lock() + defer s.mtx.Unlock() + sequence, exists := s.reverseTxHashSequenceMap[txHash] + if !exists { + return + } + if success && sequence >= s.networkSequence { + s.networkSequence = sequence + 1 + } + delete(s.outboundSequences, sequence) + delete(s.reverseTxHashSequenceMap, txHash) +} + +// Keyring exposes the signers underlying keyring +func (s *Signer) Keyring() keyring.Keyring { + return s.keys +} + +func (s *Signer) signTransaction(builder client.TxBuilder, sequence uint64) error { + signers := builder.GetTx().GetSigners() + if len(signers) != 1 { + return fmt.Errorf("expected 1 signer, got %d", len(signers)) + } + + if !s.address.Equals(signers[0]) { + return fmt.Errorf("expected signer %s, got %s", s.address.String(), signers[0].String()) + } + + // To ensure we have the correct bytes to sign over we produce + // a dry run of the signing data + err := builder.SetSignatures(s.getSignatureV2(sequence, nil)) + if err != nil { + return fmt.Errorf("error setting draft signatures: %w", err) + } + + // now we can use the data to produce the signature from the signer + signature, err := s.createSignature(builder, sequence) + if err != nil { + return fmt.Errorf("error creating signature: %w", err) + } + + err = builder.SetSignatures(s.getSignatureV2(sequence, signature)) + if err != nil { + return fmt.Errorf("error setting signatures: %w", err) + } + + return nil +} + +func (s *Signer) createSignature(builder client.TxBuilder, sequence uint64) ([]byte, error) { + signerData := authsigning.SignerData{ + Address: s.address.String(), + ChainID: s.ChainID(), + AccountNumber: s.accountNumber, + Sequence: sequence, + PubKey: s.pk, + } + + bytesToSign, err := s.enc.SignModeHandler().GetSignBytes( + signing.SignMode_SIGN_MODE_DIRECT, + signerData, + builder.GetTx(), + ) + if err != nil { + return nil, fmt.Errorf("error getting sign bytes: %w", err) + } + + signature, _, err := s.keys.SignByAddress(s.address, bytesToSign) + if err != nil { + return nil, fmt.Errorf("error signing bytes: %w", err) + } + + return signature, nil +} + +// txBuilder returns the default sdk Tx builder using the celestia-app encoding config +func (s *Signer) txBuilder(opts ...TxOption) client.TxBuilder { + builder := s.enc.NewTxBuilder() + for _, opt := range opts { + builder = opt(builder) + } + return builder +} + +// QueryAccount fetches the account number and sequence number from the celestia-app node. +func QueryAccount(ctx context.Context, conn *grpc.ClientConn, encCfg encoding.Config, address string) (accNum uint64, seqNum uint64, err error) { + qclient := authtypes.NewQueryClient(conn) + resp, err := qclient.Account( + ctx, + &authtypes.QueryAccountRequest{Address: address}, + ) + if err != nil { + return accNum, seqNum, err + } + + var acc authtypes.AccountI + err = encCfg.InterfaceRegistry.UnpackAny(resp.Account, &acc) + if err != nil { + return accNum, seqNum, err + } + + accNum, seqNum = acc.GetAccountNumber(), acc.GetSequence() + return accNum, seqNum, nil +} + +func (s *Signer) getSignatureV2(sequence uint64, signature []byte) signing.SignatureV2 { + sigV2 := signing.SignatureV2{ + Data: &signing.SingleSignatureData{ + SignMode: signing.SignMode_SIGN_MODE_DIRECT, + Signature: signature, + }, + Sequence: sequence, + } + if sequence == 0 { + sigV2.PubKey = s.pk + } + return sigV2 +} + +func getSequenceNumber(tx authsigning.Tx) (uint64, error) { + sigs, err := tx.GetSignaturesV2() + if err != nil { + return 0, err + } + if len(sigs) > 1 { + return 0, fmt.Errorf("only a signle signature is supported, got %d", len(sigs)) + } + + return sigs[0].Sequence, nil +} diff --git a/pkg/user/signer_test.go b/pkg/user/legacy_signer_test.go similarity index 100% rename from pkg/user/signer_test.go rename to pkg/user/legacy_signer_test.go diff --git a/pkg/user/signer.go b/pkg/user/signer.go index 0f84556bed..1bc450a7bc 100644 --- a/pkg/user/signer.go +++ b/pkg/user/signer.go @@ -1,219 +1,105 @@ package user import ( - "context" "errors" "fmt" - "strings" - "sync" - "time" - "github.com/celestiaorg/celestia-app/app/encoding" - apperrors "github.com/celestiaorg/celestia-app/app/errors" - blob "github.com/celestiaorg/celestia-app/x/blob/types" "github.com/cosmos/cosmos-sdk/client" - "github.com/cosmos/cosmos-sdk/client/grpc/tmservice" "github.com/cosmos/cosmos-sdk/crypto/keyring" cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" sdktypes "github.com/cosmos/cosmos-sdk/types" - sdktx "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" - authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" - abci "github.com/tendermint/tendermint/abci/types" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" tmtypes "github.com/tendermint/tendermint/types" - "google.golang.org/grpc" -) -const DefaultPollTime = 3 * time.Second - -// Signer is an abstraction for building, signing, and broadcasting Celestia transactions -type Signer struct { - keys keyring.Keyring - address sdktypes.AccAddress - enc client.TxConfig - grpc *grpc.ClientConn - pk cryptotypes.PubKey - chainID string - accountNumber uint64 - - mtx sync.RWMutex - // how often to poll the network for confirmation of a transaction - pollTime time.Duration - // the signers local view of the sequence number - localSequence uint64 - // the chains last known sequence number - networkSequence uint64 - // lookup map of all pending and yet to be confirmed outbound transactions - outboundSequences map[uint64]struct{} - // a reverse map for confirming which sequence numbers have been committed - reverseTxHashSequenceMap map[string]uint64 -} + blobtypes "github.com/celestiaorg/celestia-app/x/blob/types" +) -// NewSigner returns a new signer using the provided keyring -func NewSigner( +// TxSigner is struct for building and signing Celestia transactions +// It supports multiple accounts wrapping a Keyring. +// NOTE: All transactions may only have a single signer +// TxSigner is not thread-safe. +type TxSigner struct { + keys keyring.Keyring + enc client.TxConfig + chainID string + // FIXME: the signer is currently incapable of detecting an appversion + // change and could produce incorrect PFBs if it the network is at an + // appVersion that the signer does not support + appVersion uint64 + + // set of accounts that the signer can manage. Should match the keys on the keyring + accounts map[string]*Account + addressToAccountMap map[string]string +} + +// NewTxSigner returns a new signer using the provided keyring +// There must be at least one account in the keyring +// The first account provided will be set as the default +func NewTxSigner( keys keyring.Keyring, - conn *grpc.ClientConn, - address sdktypes.AccAddress, - enc client.TxConfig, + encCfg client.TxConfig, chainID string, - accountNumber uint64, - sequence uint64, -) (*Signer, error) { - // check that the address exists - record, err := keys.KeyByAddress(address) - if err != nil { - return nil, err - } - - pk, err := record.GetPubKey() - if err != nil { - return nil, err - } - - return &Signer{ - keys: keys, - address: address, - grpc: conn, - enc: enc, - pk: pk, - chainID: chainID, - accountNumber: accountNumber, - localSequence: sequence, - networkSequence: sequence, - pollTime: DefaultPollTime, - outboundSequences: make(map[uint64]struct{}), - reverseTxHashSequenceMap: make(map[string]uint64), - }, nil -} - -// SetupSingleSigner sets up a signer based on the provided keyring. The keyring -// must contain exactly one key. It extracts the address from the key and uses -// the grpc connection to populate the chainID, account number, and sequence -// number. -func SetupSingleSigner(ctx context.Context, keys keyring.Keyring, conn *grpc.ClientConn, encCfg encoding.Config) (*Signer, error) { - records, err := keys.List() - if err != nil { - return nil, err - } - - if len(records) != 1 { - return nil, errors.New("keyring must contain exactly one key") - } - - address, err := records[0].GetAddress() - if err != nil { - return nil, err - } - - return SetupSigner(ctx, keys, conn, address, encCfg) -} - -// SetupSigner uses the underlying grpc connection to populate the chainID, accountNumber and sequence number of the -// account. -func SetupSigner( - ctx context.Context, - keys keyring.Keyring, - conn *grpc.ClientConn, - address sdktypes.AccAddress, - encCfg encoding.Config, -) (*Signer, error) { - resp, err := tmservice.NewServiceClient(conn).GetLatestBlock(ctx, &tmservice.GetLatestBlockRequest{}) - if err != nil { - return nil, err - } - - chainID := resp.SdkBlock.Header.ChainID - accNum, seqNum, err := QueryAccount(ctx, conn, encCfg, address.String()) - if err != nil { - return nil, err - } - - return NewSigner(keys, conn, address, encCfg.TxConfig, chainID, accNum, seqNum) -} - -// SubmitTx forms a transaction from the provided messages, signs it, and submits it to the chain. TxOptions -// may be provided to set the fee and gas limit. -func (s *Signer) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) { - tx, err := s.CreateTx(msgs, opts...) - if err != nil { - return nil, err - } - - resp, err := s.BroadcastTx(ctx, tx) - if err != nil { - return resp, err - } - - return s.ConfirmTx(ctx, resp.TxHash) -} - -// SubmitPayForBlob forms a transaction from the provided blobs, signs it, and submits it to the chain. -// TxOptions may be provided to set the fee and gas limit. -func (s *Signer) SubmitPayForBlob(ctx context.Context, blobs []*tmproto.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { - resp, err := s.broadcastPayForBlob(ctx, blobs, opts...) - if err != nil { - return resp, err + appVersion uint64, + accounts ...*Account, +) (*TxSigner, error) { + s := &TxSigner{ + keys: keys, + chainID: chainID, + enc: encCfg, + accounts: make(map[string]*Account), + addressToAccountMap: make(map[string]string), + appVersion: appVersion, + } + + for _, acc := range accounts { + if err := s.AddAccount(acc); err != nil { + return nil, err + } } - return s.ConfirmTx(ctx, resp.TxHash) + return s, nil } -func (s *Signer) broadcastPayForBlob(ctx context.Context, blobs []*blob.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { - s.mtx.Lock() - defer s.mtx.Unlock() - txBytes, seqNum, err := s.createPayForBlobs(blobs, opts...) +// CreateTx forms a transaction from the provided messages and signs it. +// TxOptions may be optionally used to set the gas limit and fee. +func (s *TxSigner) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) ([]byte, error) { + tx, _, _, err := s.SignTx(msgs, opts...) if err != nil { return nil, err } - - return s.broadcastTx(ctx, txBytes, seqNum) + return s.EncodeTx(tx) } -// CreateTx forms a transaction from the provided messages and signs it. TxOptions may be optionally -// used to set the gas limit and fee. -func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { - s.mtx.Lock() - defer s.mtx.Unlock() - - return s.createTx(msgs, opts...) -} - -func (s *Signer) createTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { +func (s *TxSigner) SignTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, string, uint64, error) { txBuilder := s.txBuilder(opts...) if err := txBuilder.SetMsgs(msgs...); err != nil { - return nil, err + return nil, "", 0, err } - if err := s.signTransaction(txBuilder, s.getAndIncrementSequence()); err != nil { - return nil, err + signer, sequence, err := s.signTransaction(txBuilder) + if err != nil { + return nil, "", 0, err } - return txBuilder.GetTx(), nil -} - -func (s *Signer) CreatePayForBlob(blobs []*tmproto.Blob, opts ...TxOption) ([]byte, error) { - s.mtx.Lock() - defer s.mtx.Unlock() - blobTx, _, err := s.createPayForBlobs(blobs, opts...) - return blobTx, err + return txBuilder.GetTx(), signer, sequence, nil } -func (s *Signer) createPayForBlobs(blobs []*tmproto.Blob, opts ...TxOption) ([]byte, uint64, error) { - msg, err := blob.NewMsgPayForBlobs(s.address.String(), blobs...) - if err != nil { - return nil, 0, err +func (s *TxSigner) CreatePayForBlobs(accountName string, blobs []*tmproto.Blob, opts ...TxOption) ([]byte, uint64, error) { + acc, exists := s.accounts[accountName] + if !exists { + return nil, 0, fmt.Errorf("account %s not found", accountName) } - tx, err := s.createTx([]sdktypes.Msg{msg}, opts...) + msg, err := blobtypes.NewMsgPayForBlobs(acc.address.String(), blobs...) if err != nil { return nil, 0, err } - seqNum, err := getSequenceNumber(tx) + tx, _, sequence, err := s.SignTx([]sdktypes.Msg{msg}, opts...) if err != nil { - panic(err) + return nil, 0, err } txBytes, err := s.EncodeTx(tx) @@ -222,14 +108,14 @@ func (s *Signer) createPayForBlobs(blobs []*tmproto.Blob, opts ...TxOption) ([]b } blobTx, err := tmtypes.MarshalBlobTx(txBytes, blobs...) - return blobTx, seqNum, err + return blobTx, sequence, err } -func (s *Signer) EncodeTx(tx sdktypes.Tx) ([]byte, error) { +func (s *TxSigner) EncodeTx(tx sdktypes.Tx) ([]byte, error) { return s.enc.TxEncoder()(tx) } -func (s *Signer) DecodeTx(txBytes []byte) (authsigning.Tx, error) { +func (s *TxSigner) DecodeTx(txBytes []byte) (authsigning.Tx, error) { tx, err := s.enc.TxDecoder()(txBytes) if err != nil { return nil, err @@ -241,306 +127,132 @@ func (s *Signer) DecodeTx(txBytes []byte) (authsigning.Tx, error) { return authTx, nil } -// BroadcastTx submits the provided transaction bytes to the chain and returns the response. -func (s *Signer) BroadcastTx(ctx context.Context, tx authsigning.Tx) (*sdktypes.TxResponse, error) { - s.mtx.Lock() - defer s.mtx.Unlock() - txBytes, err := s.EncodeTx(tx) - if err != nil { - return nil, err - } - sequence, err := getSequenceNumber(tx) - if err != nil { - return nil, err - } - return s.broadcastTx(ctx, txBytes, sequence) +// ChainID returns the chain ID of the signer. +func (s *TxSigner) ChainID() string { + return s.chainID } -// CONTRACT: assumes the caller has the lock -func (s *Signer) broadcastTx(ctx context.Context, txBytes []byte, sequence uint64) (*sdktypes.TxResponse, error) { - if _, exists := s.outboundSequences[sequence]; exists { - return s.retryBroadcastingTx(ctx, txBytes, sequence+1) - } - - if sequence < s.networkSequence { - s.localSequence = s.networkSequence - return s.retryBroadcastingTx(ctx, txBytes, s.localSequence) - } +// Account returns an account of the signer from the key name +func (s *TxSigner) Account(name string) *Account { + return s.accounts[name] +} - txClient := sdktx.NewServiceClient(s.grpc) - resp, err := txClient.BroadcastTx( - ctx, - &sdktx.BroadcastTxRequest{ - Mode: sdktx.BroadcastMode_BROADCAST_MODE_SYNC, - TxBytes: txBytes, - }, - ) - if err != nil { - return nil, err - } - if apperrors.IsNonceMismatchCode(resp.TxResponse.Code) { - // extract what the lastCommittedNonce on chain is - nextSequence, err := apperrors.ParseExpectedSequence(resp.TxResponse.RawLog) - if err != nil { - return nil, fmt.Errorf("parsing nonce mismatch upon retry: %w", err) - } - s.networkSequence = nextSequence - s.localSequence = nextSequence - // FIXME: We can't actually resign the transaction. A malicious node - // may manipulate us into signing the same transaction several times - // and then executing them. We need some proof of what the last network - // sequence is rather than relying on an error provided by the node - // return s.retryBroadcastingTx(ctx, txBytes, nextSequence) - // Ref: https://github.com/celestiaorg/celestia-app/issues/3256 - // return s.retryBroadcastingTx(ctx, txBytes, nextSequence) - } else if resp.TxResponse.Code == abci.CodeTypeOK { - s.outboundSequences[sequence] = struct{}{} - s.reverseTxHashSequenceMap[resp.TxResponse.TxHash] = sequence - return resp.TxResponse, nil +// AccountByAddress returns the account associated with the given address +func (s *TxSigner) AccountByAddress(address sdktypes.AccAddress) *Account { + accountName, exists := s.addressToAccountMap[address.String()] + if !exists { + return nil } - return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) + return s.accounts[accountName] } -// retryBroadcastingTx creates a new transaction by copying over an existing transaction but creates a new signature with the -// new sequence number. It then calls `broadcastTx` and attempts to submit the transaction -func (s *Signer) retryBroadcastingTx(ctx context.Context, txBytes []byte, newSequenceNumber uint64) (*sdktypes.TxResponse, error) { - blobTx, isBlobTx := tmtypes.UnmarshalBlobTx(txBytes) - if isBlobTx { - txBytes = blobTx.Tx - } - tx, err := s.DecodeTx(txBytes) - if err != nil { - return nil, err - } - txBuilder := s.txBuilder() - if err := txBuilder.SetMsgs(tx.GetMsgs()...); err != nil { - return nil, err - } - if granter := tx.FeeGranter(); granter != nil { - txBuilder.SetFeeGranter(granter) - } - if payer := tx.FeePayer(); payer != nil { - txBuilder.SetFeePayer(payer) - } - if memo := tx.GetMemo(); memo != "" { - txBuilder.SetMemo(memo) - } - if fee := tx.GetFee(); fee != nil { - txBuilder.SetFeeAmount(fee) - } - if gas := tx.GetGas(); gas > 0 { - txBuilder.SetGasLimit(gas) +func (s *TxSigner) Accounts() []*Account { + accounts := make([]*Account, len(s.accounts)) + i := 0 + for _, acc := range s.accounts { + accounts[i] = acc + i++ } + return accounts +} - if err := s.signTransaction(txBuilder, newSequenceNumber); err != nil { - return nil, fmt.Errorf("resigning transaction: %w", err) +func (s *TxSigner) findAccount(txbuilder client.TxBuilder) (*Account, error) { + signers := txbuilder.GetTx().GetSigners() + if len(signers) == 0 { + return nil, fmt.Errorf("message has no signer") } - - newTxBytes, err := s.EncodeTx(txBuilder.GetTx()) - if err != nil { - return nil, err + accountName, exists := s.addressToAccountMap[signers[0].String()] + if !exists { + return nil, fmt.Errorf("account %s not found", signers[0].String()) } + return s.accounts[accountName], nil +} - // rewrap the blob tx if it was originally a blob tx - if isBlobTx { - newTxBytes, err = tmtypes.MarshalBlobTx(newTxBytes, blobTx.Blobs...) - if err != nil { - return nil, err - } +func (s *TxSigner) IncrementSequence(accountName string) error { + acc, exists := s.accounts[accountName] + if !exists { + return fmt.Errorf("account %s does not exist", accountName) } - - return s.broadcastTx(ctx, newTxBytes, newSequenceNumber) + acc.sequence++ + return nil } -// ConfirmTx periodically pings the provided node for the commitment of a transaction by its -// hash. It will continually loop until the context is cancelled, the tx is found or an error -// is encountered. -func (s *Signer) ConfirmTx(ctx context.Context, txHash string) (*sdktypes.TxResponse, error) { - txClient := sdktx.NewServiceClient(s.grpc) - - pollTime := s.getPollTime() - timer := time.NewTimer(0) - defer timer.Stop() - - for { - select { - case <-ctx.Done(): - return &sdktypes.TxResponse{}, ctx.Err() - case <-timer.C: - resp, err := txClient.GetTx(ctx, &sdktx.GetTxRequest{Hash: txHash}) - if err == nil { - if resp.TxResponse.Code != 0 { - s.updateNetworkSequence(txHash, false) - return resp.TxResponse, fmt.Errorf("tx was included but failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) - } - s.updateNetworkSequence(txHash, true) - return resp.TxResponse, nil - } - // FIXME: this is a relatively brittle of working out whether to retry or not. The tx might be not found for other - // reasons. It may have been removed from the mempool at a later point. We should build an endpoint that gives the - // signer more information on the status of their transaction and then update the logic here - if !strings.Contains(err.Error(), "not found") { - return &sdktypes.TxResponse{}, err - } - - timer.Reset(pollTime) - } +func (s *TxSigner) SetSequence(accountName string, seq uint64) error { + acc, exists := s.accounts[accountName] + if !exists { + return fmt.Errorf("account %s does not exist", accountName) } + + acc.sequence = seq + return nil } -func (s *Signer) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (uint64, error) { - txBuilder := s.txBuilder(opts...) - if err := txBuilder.SetMsgs(msgs...); err != nil { - return 0, err +func (s *TxSigner) AddAccount(acc *Account) error { + if acc == nil { + return errors.New("account is nil") } - if err := s.signTransaction(txBuilder, s.LocalSequence()); err != nil { - return 0, err + record, err := s.keys.Key(acc.name) + if err != nil { + return fmt.Errorf("retrieving key for account %s: %w", acc.name, err) } - txBytes, err := s.enc.TxEncoder()(txBuilder.GetTx()) + addr, err := record.GetAddress() if err != nil { - return 0, err + return fmt.Errorf("getting address for key %s: %w", acc.pubKey, err) } - resp, err := sdktx.NewServiceClient(s.grpc).Simulate(ctx, &sdktx.SimulateRequest{ - TxBytes: txBytes, - }) + pk, err := record.GetPubKey() if err != nil { - return 0, err + return fmt.Errorf("getting public key for account %s: %w", acc.name, err) } - return resp.GasInfo.GasUsed, nil -} - -// ChainID returns the chain ID of the signer. -func (s *Signer) ChainID() string { - return s.chainID -} - -// AccountNumber returns the account number of the signer. -func (s *Signer) AccountNumber() uint64 { - return s.accountNumber -} - -// Address returns the address of the signer. -func (s *Signer) Address() sdktypes.AccAddress { - return s.address -} - -// SetPollTime sets how often the signer should poll for the confirmation of the transaction -func (s *Signer) SetPollTime(pollTime time.Duration) { - s.mtx.Lock() - defer s.mtx.Unlock() - s.pollTime = pollTime -} - -func (s *Signer) getPollTime() time.Duration { - s.mtx.Lock() - defer s.mtx.Unlock() - return s.pollTime -} - -// PubKey returns the public key of the signer -func (s *Signer) PubKey() cryptotypes.PubKey { - return s.pk -} - -// DEPRECATED: use Sequence instead -func (s *Signer) GetSequence() uint64 { - return s.getAndIncrementSequence() -} - -// LocalSequence returns the next sequence number of the signers -// locally saved -func (s *Signer) LocalSequence() uint64 { - s.mtx.RLock() - defer s.mtx.RUnlock() - return s.localSequence -} - -func (s *Signer) NetworkSequence() uint64 { - s.mtx.RLock() - defer s.mtx.RUnlock() - return s.networkSequence -} - -// getAndIncrementSequence gets the latest signed sequence and increments the -// local sequence number -func (s *Signer) getAndIncrementSequence() uint64 { - defer func() { s.localSequence++ }() - return s.localSequence -} - -// ForceSetSequence manually overrides the current local and network level -// sequence number. Be careful when invoking this as it may cause the -// transactions to reject the sequence if it doesn't match the one in state -func (s *Signer) ForceSetSequence(seq uint64) { - s.mtx.Lock() - defer s.mtx.Unlock() - s.localSequence = seq - s.networkSequence = seq -} - -// updateNetworkSequence is called once a transaction is confirmed -// and updates the chains last known sequence number -func (s *Signer) updateNetworkSequence(txHash string, success bool) { - s.mtx.Lock() - defer s.mtx.Unlock() - sequence, exists := s.reverseTxHashSequenceMap[txHash] - if !exists { - return - } - if success && sequence >= s.networkSequence { - s.networkSequence = sequence + 1 - } - delete(s.outboundSequences, sequence) - delete(s.reverseTxHashSequenceMap, txHash) + acc.address = addr + acc.pubKey = pk + s.accounts[acc.name] = acc + s.addressToAccountMap[addr.String()] = acc.name + return nil } // Keyring exposes the signers underlying keyring -func (s *Signer) Keyring() keyring.Keyring { +func (s *TxSigner) Keyring() keyring.Keyring { return s.keys } -func (s *Signer) signTransaction(builder client.TxBuilder, sequence uint64) error { - signers := builder.GetTx().GetSigners() - if len(signers) != 1 { - return fmt.Errorf("expected 1 signer, got %d", len(signers)) - } - - if !s.address.Equals(signers[0]) { - return fmt.Errorf("expected signer %s, got %s", s.address.String(), signers[0].String()) +func (s *TxSigner) signTransaction(builder client.TxBuilder) (string, uint64, error) { + account, err := s.findAccount(builder) + if err != nil { + return "", 0, err } // To ensure we have the correct bytes to sign over we produce // a dry run of the signing data - err := builder.SetSignatures(s.getSignatureV2(sequence, nil)) + err = builder.SetSignatures(s.getSignatureV2(account.sequence, account.pubKey, nil)) if err != nil { - return fmt.Errorf("error setting draft signatures: %w", err) + return "", 0, fmt.Errorf("error setting draft signatures: %w", err) } // now we can use the data to produce the signature from the signer - signature, err := s.createSignature(builder, sequence) + signature, err := s.createSignature(builder, account, account.sequence) if err != nil { - return fmt.Errorf("error creating signature: %w", err) + return "", 0, fmt.Errorf("error creating signature: %w", err) } - err = builder.SetSignatures(s.getSignatureV2(sequence, signature)) + err = builder.SetSignatures(s.getSignatureV2(account.sequence, account.pubKey, signature)) if err != nil { - return fmt.Errorf("error setting signatures: %w", err) + return "", 0, fmt.Errorf("error setting signatures: %w", err) } - return nil + return account.name, account.sequence, nil } -func (s *Signer) createSignature(builder client.TxBuilder, sequence uint64) ([]byte, error) { +func (s *TxSigner) createSignature(builder client.TxBuilder, account *Account, sequence uint64) ([]byte, error) { signerData := authsigning.SignerData{ - Address: s.address.String(), + Address: account.address.String(), ChainID: s.ChainID(), - AccountNumber: s.accountNumber, + AccountNumber: account.accountNumber, Sequence: sequence, - PubKey: s.pk, + PubKey: account.pubKey, } bytesToSign, err := s.enc.SignModeHandler().GetSignBytes( @@ -552,7 +264,7 @@ func (s *Signer) createSignature(builder client.TxBuilder, sequence uint64) ([]b return nil, fmt.Errorf("error getting sign bytes: %w", err) } - signature, _, err := s.keys.SignByAddress(s.address, bytesToSign) + signature, _, err := s.keys.Sign(account.name, bytesToSign) if err != nil { return nil, fmt.Errorf("error signing bytes: %w", err) } @@ -561,7 +273,7 @@ func (s *Signer) createSignature(builder client.TxBuilder, sequence uint64) ([]b } // txBuilder returns the default sdk Tx builder using the celestia-app encoding config -func (s *Signer) txBuilder(opts ...TxOption) client.TxBuilder { +func (s *TxSigner) txBuilder(opts ...TxOption) client.TxBuilder { builder := s.enc.NewTxBuilder() for _, opt := range opts { builder = opt(builder) @@ -569,28 +281,7 @@ func (s *Signer) txBuilder(opts ...TxOption) client.TxBuilder { return builder } -// QueryAccount fetches the account number and sequence number from the celestia-app node. -func QueryAccount(ctx context.Context, conn *grpc.ClientConn, encCfg encoding.Config, address string) (accNum uint64, seqNum uint64, err error) { - qclient := authtypes.NewQueryClient(conn) - resp, err := qclient.Account( - ctx, - &authtypes.QueryAccountRequest{Address: address}, - ) - if err != nil { - return accNum, seqNum, err - } - - var acc authtypes.AccountI - err = encCfg.InterfaceRegistry.UnpackAny(resp.Account, &acc) - if err != nil { - return accNum, seqNum, err - } - - accNum, seqNum = acc.GetAccountNumber(), acc.GetSequence() - return accNum, seqNum, nil -} - -func (s *Signer) getSignatureV2(sequence uint64, signature []byte) signing.SignatureV2 { +func (s *TxSigner) getSignatureV2(sequence uint64, pubKey cryptotypes.PubKey, signature []byte) signing.SignatureV2 { sigV2 := signing.SignatureV2{ Data: &signing.SingleSignatureData{ SignMode: signing.SignMode_SIGN_MODE_DIRECT, @@ -599,19 +290,7 @@ func (s *Signer) getSignatureV2(sequence uint64, signature []byte) signing.Signa Sequence: sequence, } if sequence == 0 { - sigV2.PubKey = s.pk + sigV2.PubKey = pubKey } return sigV2 } - -func getSequenceNumber(tx authsigning.Tx) (uint64, error) { - sigs, err := tx.GetSignaturesV2() - if err != nil { - return 0, err - } - if len(sigs) > 1 { - return 0, fmt.Errorf("only a signle signature is supported, got %d", len(sigs)) - } - - return sigs[0].Sequence, nil -} diff --git a/pkg/user/tx_client.go b/pkg/user/tx_client.go new file mode 100644 index 0000000000..e7d47d4823 --- /dev/null +++ b/pkg/user/tx_client.go @@ -0,0 +1,463 @@ +package user + +import ( + "bytes" + "context" + "errors" + "fmt" + "strings" + "sync" + "time" + + "github.com/cosmos/cosmos-sdk/client/grpc/tmservice" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + "github.com/cosmos/cosmos-sdk/crypto/keyring" + sdktypes "github.com/cosmos/cosmos-sdk/types" + sdktx "github.com/cosmos/cosmos-sdk/types/tx" + abci "github.com/tendermint/tendermint/abci/types" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + tmtypes "github.com/tendermint/tendermint/types" + "google.golang.org/grpc" + + "github.com/celestiaorg/celestia-app/app/encoding" + apperrors "github.com/celestiaorg/celestia-app/app/errors" +) + +const ( + DefaultPollTime = 3 * time.Second + DefaultGasMultiplier float64 = 1.1 +) + +type Option func(s *TxClient) + +// WithGasMultiplier is a functional option allows to configure the gas multiplier. +func WithGasMultiplier(multiplier float64) Option { + return func(c *TxClient) { + c.gasMultiplier = multiplier + } +} + +func WithPollTime(time time.Duration) Option { + return func(c *TxClient) { + c.pollTime = time + } +} + +func WithDefaultAddress(address sdktypes.AccAddress) Option { + return func(c *TxClient) { + record, err := c.signer.keys.KeyByAddress(address) + if err != nil { + panic(err) + } + c.defaultAccount = record.Name + } +} + +func WithDefaultAccount(name string) Option { + return func(c *TxClient) { + if _, err := c.signer.keys.Key(name); err != nil { + panic(err) + } + c.defaultAccount = name + } +} + +// TxClient is an abstraction for building, signing, and broadcasting Celestia transactions +// It supports multiple accounts. If none is specified, it will +// try use the default account. +// TxClient is thread-safe. +type TxClient struct { + mtx sync.Mutex + signer *TxSigner + registry codectypes.InterfaceRegistry + grpc *grpc.ClientConn + // how often to poll the network for confirmation of a transaction + pollTime time.Duration + // gasMultiplier is used to increase gas limit as it is sometimes underestimated + gasMultiplier float64 + defaultAccount string + defaultAddress sdktypes.AccAddress +} + +// NewTxClient returns a new signer using the provided keyring +func NewTxClient( + signer *TxSigner, + conn *grpc.ClientConn, + registry codectypes.InterfaceRegistry, + options ...Option, +) (*TxClient, error) { + records, err := signer.keys.List() + if err != nil { + return nil, fmt.Errorf("retrieving keys: %w", err) + } + + if len(records) == 0 { + return nil, errors.New("signer must have at least one key") + } + + addr, err := records[0].GetAddress() + if err != nil { + return nil, err + } + + txClient := &TxClient{ + signer: signer, + registry: registry, + grpc: conn, + pollTime: DefaultPollTime, + gasMultiplier: DefaultGasMultiplier, + defaultAccount: records[0].Name, + defaultAddress: addr, + } + + for _, opt := range options { + opt(txClient) + } + + return txClient, nil +} + +// SetupTxClient uses the underlying grpc connection to populate the chainID, accountNumber and sequence number of all +// the accounts in the keyring. +func SetupTxClient( + ctx context.Context, + keys keyring.Keyring, + conn *grpc.ClientConn, + encCfg encoding.Config, + options ...Option, +) (*TxClient, error) { + resp, err := tmservice.NewServiceClient(conn).GetLatestBlock( + ctx, + &tmservice.GetLatestBlockRequest{}, + ) + if err != nil { + return nil, err + } + + chainID := resp.SdkBlock.Header.ChainID + appVersion := resp.SdkBlock.Header.Version.App + + records, err := keys.List() + if err != nil { + return nil, err + } + + accounts := make([]*Account, 0, len(records)) + for _, record := range records { + addr, err := record.GetAddress() + if err != nil { + return nil, err + } + accNum, seqNum, err := QueryAccountInfo(ctx, conn, encCfg.InterfaceRegistry, addr) + if err != nil { + // skip over the accounts that don't exist in state + continue + } + + accounts = append(accounts, NewAccount(record.Name, accNum, seqNum)) + } + + signer, err := NewTxSigner(keys, encCfg.TxConfig, chainID, appVersion, accounts...) + if err != nil { + return nil, fmt.Errorf("failed to create signer: %w", err) + } + + return NewTxClient(signer, conn, encCfg.InterfaceRegistry, options...) +} + +// SubmitPayForBlob forms a transaction from the provided blobs, signs it, and submits it to the chain. +// TxOptions may be provided to set the fee and gas limit. +func (s *TxClient) SubmitPayForBlob(ctx context.Context, blobs []*tmproto.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + resp, err := s.BroadcastPayForBlob(ctx, blobs, opts...) + if err != nil { + return resp, err + } + + return s.ConfirmTx(ctx, resp.TxHash) +} + +func (s *TxClient) SubmitPayForBlobsWithAccount(ctx context.Context, account string, blobs []*tmproto.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + resp, err := s.BroadcastPayForBlobWithAccount(ctx, account, blobs, opts...) + if err != nil { + return resp, err + } + + return s.ConfirmTx(ctx, resp.TxHash) +} + +// BroadcastPayForBlob signs and broadcasts a transaction to pay for blobs. +// It does not confirm that the transaction has been committed on chain. +func (s *TxClient) BroadcastPayForBlob(ctx context.Context, blobs []*tmproto.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + return s.BroadcastPayForBlobWithAccount(ctx, s.defaultAccount, blobs, opts...) +} + +func (s *TxClient) BroadcastPayForBlobWithAccount(ctx context.Context, account string, blobs []*tmproto.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + if err := s.checkAccountLoaded(ctx, account); err != nil { + return nil, err + } + + txBytes, _, err := s.signer.CreatePayForBlobs(account, blobs, opts...) + if err != nil { + return nil, err + } + + return s.broadcastTx(ctx, txBytes, account) +} + +// SubmitTx forms a transaction from the provided messages, signs it, and submits it to the chain. TxOptions +// may be provided to set the fee and gas limit. +func (s *TxClient) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) { + resp, err := s.BroadcastTx(ctx, msgs, opts...) + if err != nil { + return resp, err + } + + return s.ConfirmTx(ctx, resp.TxHash) +} + +func (s *TxClient) BroadcastTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + account, err := s.getAccountNameFromMsgs(msgs) + if err != nil { + return nil, err + } + + if err := s.checkAccountLoaded(ctx, account); err != nil { + return nil, err + } + + tx, account, _, err := s.signer.SignTx(msgs, opts...) + if err != nil { + return nil, err + } + + txBytes, err := s.signer.EncodeTx(tx) + if err != nil { + return nil, err + } + + return s.broadcastTx(ctx, txBytes, account) +} + +func (s *TxClient) broadcastTx(ctx context.Context, txBytes []byte, signer string) (*sdktypes.TxResponse, error) { + txClient := sdktx.NewServiceClient(s.grpc) + resp, err := txClient.BroadcastTx( + ctx, + &sdktx.BroadcastTxRequest{ + Mode: sdktx.BroadcastMode_BROADCAST_MODE_SYNC, + TxBytes: txBytes, + }, + ) + if err != nil { + return nil, err + } + if resp.TxResponse.Code != abci.CodeTypeOK { + if apperrors.IsNonceMismatchCode(resp.TxResponse.Code) { + // query the account to update the sequence number on-chain for the account + _, seqNum, err := QueryAccountInfo(ctx, s.grpc, s.registry, s.signer.accounts[signer].address) + if err != nil { + return nil, fmt.Errorf("querying account for new sequence number: %w\noriginal tx response: %s", err, resp.TxResponse.RawLog) + } + if err := s.signer.SetSequence(signer, seqNum); err != nil { + return nil, fmt.Errorf("setting sequence: %w", err) + } + return s.retryBroadcastingTx(ctx, txBytes) + } + return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) + } + + // after the transaction has been submitted, we can increment the + // sequence of the signer + if err := s.signer.IncrementSequence(signer); err != nil { + return nil, fmt.Errorf("increment sequencing: %w", err) + } + return resp.TxResponse, nil +} + +// retryBroadcastingTx creates a new transaction by copying over an existing transaction but creates a new signature with the +// new sequence number. It then calls `broadcastTx` and attempts to submit the transaction +func (s *TxClient) retryBroadcastingTx(ctx context.Context, txBytes []byte) (*sdktypes.TxResponse, error) { + blobTx, isBlobTx := tmtypes.UnmarshalBlobTx(txBytes) + if isBlobTx { + txBytes = blobTx.Tx + } + tx, err := s.signer.DecodeTx(txBytes) + if err != nil { + return nil, err + } + txBuilder := s.signer.txBuilder() + if err := txBuilder.SetMsgs(tx.GetMsgs()...); err != nil { + return nil, err + } + if granter := tx.FeeGranter(); granter != nil { + txBuilder.SetFeeGranter(granter) + } + if payer := tx.FeePayer(); payer != nil { + txBuilder.SetFeePayer(payer) + } + if memo := tx.GetMemo(); memo != "" { + txBuilder.SetMemo(memo) + } + if fee := tx.GetFee(); fee != nil { + txBuilder.SetFeeAmount(fee) + } + if gas := tx.GetGas(); gas > 0 { + txBuilder.SetGasLimit(gas) + } + + signer, _, err := s.signer.signTransaction(txBuilder) + if err != nil { + return nil, fmt.Errorf("resigning transaction: %w", err) + } + + newTxBytes, err := s.signer.EncodeTx(txBuilder.GetTx()) + if err != nil { + return nil, err + } + + // rewrap the blob tx if it was originally a blob tx + if isBlobTx { + newTxBytes, err = tmtypes.MarshalBlobTx(newTxBytes, blobTx.Blobs...) + if err != nil { + return nil, err + } + } + + return s.broadcastTx(ctx, newTxBytes, signer) +} + +// ConfirmTx periodically pings the provided node for the commitment of a transaction by its +// hash. It will continually loop until the context is cancelled, the tx is found or an error +// is encountered. +func (s *TxClient) ConfirmTx(ctx context.Context, txHash string) (*sdktypes.TxResponse, error) { + txClient := sdktx.NewServiceClient(s.grpc) + + pollTicker := time.NewTicker(s.pollTime) + defer pollTicker.Stop() + + for { + resp, err := txClient.GetTx(ctx, &sdktx.GetTxRequest{Hash: txHash}) + if err == nil { + if resp.TxResponse.Code != 0 { + return resp.TxResponse, fmt.Errorf("tx was included but failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) + } + return resp.TxResponse, nil + } + // FIXME: this is a relatively brittle of working out whether to retry or not. The tx might be not found for other + // reasons. It may have been removed from the mempool at a later point. We should build an endpoint that gives the + // signer more information on the status of their transaction and then update the logic here + if !strings.Contains(err.Error(), "not found") { + return &sdktypes.TxResponse{}, err + } + + // Wait for the next round. + select { + case <-ctx.Done(): + return &sdktypes.TxResponse{}, ctx.Err() + case <-pollTicker.C: + } + } +} + +func (s *TxClient) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (uint64, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + + txBuilder := s.signer.txBuilder(opts...) + if err := txBuilder.SetMsgs(msgs...); err != nil { + return 0, err + } + + _, _, err := s.signer.signTransaction(txBuilder) + if err != nil { + return 0, err + } + + txBytes, err := s.signer.EncodeTx(txBuilder.GetTx()) + if err != nil { + return 0, err + } + + resp, err := sdktx.NewServiceClient(s.grpc).Simulate(ctx, &sdktx.SimulateRequest{ + TxBytes: txBytes, + }) + if err != nil { + return 0, err + } + + gasLimit := uint64(float64(resp.GasInfo.GasUsed) * s.gasMultiplier) + return gasLimit, nil +} + +// Account returns an account of the signer from the key name. Also returns a bool if the +// account exists. +// Thread-safe +func (s *TxClient) Account(name string) (*Account, bool) { + s.mtx.Lock() + defer s.mtx.Unlock() + acc, exists := s.signer.accounts[name] + if !exists { + return nil, false + } + return acc.Copy(), true +} + +func (s *TxClient) AccountByAddress(address sdktypes.AccAddress) *Account { + s.mtx.Lock() + defer s.mtx.Unlock() + return s.signer.AccountByAddress(address) +} + +func (s *TxClient) DefaultAddress() sdktypes.AccAddress { + return s.defaultAddress +} + +func (s *TxClient) DefaultAccountName() string { return s.defaultAccount } + +func (s *TxClient) checkAccountLoaded(ctx context.Context, account string) error { + if _, exists := s.signer.accounts[account]; exists { + return nil + } + record, err := s.signer.keys.Key(account) + if err != nil { + return fmt.Errorf("trying to find account %s on keyring: %w", account, err) + } + addr, err := record.GetAddress() + if err != nil { + return fmt.Errorf("retrieving address from keyring: %w", err) + } + accNum, sequence, err := QueryAccountInfo(ctx, s.grpc, s.registry, addr) + if err != nil { + return fmt.Errorf("querying account %s: %w", account, err) + } + return s.signer.AddAccount(NewAccount(account, accNum, sequence)) +} + +func (s *TxClient) getAccountNameFromMsgs(msgs []sdktypes.Msg) (string, error) { + var addr sdktypes.AccAddress + for _, msg := range msgs { + signers := msg.GetSigners() + if len(signers) != 1 { + return "", fmt.Errorf("only one signer per transaction supported, got %d", len(signers)) + } + if addr == nil { + addr = signers[0] + } + if !bytes.Equal(addr, signers[0]) { + return "", errors.New("not supported: got two different signers across multiple messages") + } + } + record, err := s.signer.keys.KeyByAddress(addr) + if err != nil { + return "", err + } + return record.Name, nil +} + +// Signer exposes the tx clients underlying signer +func (s *TxClient) Signer() *TxSigner { + return s.signer +} diff --git a/pkg/user/tx_client_test.go b/pkg/user/tx_client_test.go new file mode 100644 index 0000000000..07ade404da --- /dev/null +++ b/pkg/user/tx_client_test.go @@ -0,0 +1,168 @@ +package user_test + +import ( + "context" + "testing" + "time" + + "github.com/celestiaorg/celestia-app/test/util/testfactory" + sdk "github.com/cosmos/cosmos-sdk/types" + bank "github.com/cosmos/cosmos-sdk/x/bank/types" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + abci "github.com/tendermint/tendermint/abci/types" + "github.com/tendermint/tendermint/libs/rand" + + "github.com/celestiaorg/celestia-app/app" + "github.com/celestiaorg/celestia-app/app/encoding" + "github.com/celestiaorg/celestia-app/pkg/user" + "github.com/celestiaorg/celestia-app/test/util/blobfactory" + "github.com/celestiaorg/celestia-app/test/util/testnode" +) + +func TestTxClientTestSuite(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode.") + } + suite.Run(t, new(TxClientTestSuite)) +} + +type TxClientTestSuite struct { + suite.Suite + + ctx testnode.Context + encCfg encoding.Config + txClient *user.TxClient +} + +func (s *TxClientTestSuite) SetupSuite() { + s.encCfg = encoding.MakeConfig(app.ModuleEncodingRegisters...) + s.ctx, _, _ = testnode.NewNetwork(s.T(), testnode.DefaultConfig().WithAccounts([]string{"a"})) + _, err := s.ctx.WaitForHeight(1) + s.Require().NoError(err) + s.txClient, err = user.SetupTxClient(s.ctx.GoContext(), s.ctx.Keyring, s.ctx.GRPCClient, s.encCfg, user.WithGasMultiplier(1.2)) + s.Require().NoError(err) +} + +func (s *TxClientTestSuite) TestSubmitPayForBlob() { + t := s.T() + blobs := blobfactory.ManyRandBlobs(t, rand.NewRand(), 1e3, 1e4) + fee := user.SetFee(1e6) + gas := user.SetGasLimit(1e6) + subCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + resp, err := s.txClient.SubmitPayForBlob(subCtx, blobs, fee, gas) + require.NoError(t, err) + require.EqualValues(t, 0, resp.Code) +} + +func (s *TxClientTestSuite) TestSubmitTx() { + t := s.T() + fee := user.SetFee(1e6) + gas := user.SetGasLimit(1e6) + addr := s.txClient.DefaultAddress() + msg := bank.NewMsgSend(addr, testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) + resp, err := s.txClient.SubmitTx(s.ctx.GoContext(), []sdk.Msg{msg}, fee, gas) + require.NoError(t, err) + require.EqualValues(t, 0, resp.Code) +} + +func (s *TxClientTestSuite) TestConfirmTx() { + t := s.T() + + fee := user.SetFee(1e6) + gas := user.SetGasLimit(1e6) + + t.Run("deadline exceeded when the context times out", func(t *testing.T) { + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), time.Second) + defer cancel() + _, err := s.txClient.ConfirmTx(ctx, "E32BD15CAF57AF15D17B0D63CF4E63A9835DD1CEBB059C335C79586BC3013728") + require.Error(t, err) + require.Contains(t, err.Error(), context.DeadlineExceeded.Error()) + }) + + t.Run("should error when tx is not found", func(t *testing.T) { + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), 5*time.Second) + defer cancel() + _, err := s.txClient.ConfirmTx(ctx, "not found tx") + require.Error(t, err) + }) + + t.Run("should success when tx is found immediately", func(t *testing.T) { + addr := s.txClient.DefaultAddress() + msg := bank.NewMsgSend(addr, testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) + resp, err := s.txClient.BroadcastTx(s.ctx.GoContext(), []sdk.Msg{msg}, fee, gas) + require.NoError(t, err) + require.NotNil(t, resp) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), 30*time.Second) + defer cancel() + resp, err = s.txClient.ConfirmTx(ctx, resp.TxHash) + require.NoError(t, err) + require.Equal(t, abci.CodeTypeOK, resp.Code) + }) + + t.Run("should error when tx is found with a non-zero error code", func(t *testing.T) { + balance := s.queryCurrentBalance(t) + addr := s.txClient.DefaultAddress() + // Create a msg send with out of balance, ensure this tx fails + msg := bank.NewMsgSend(addr, testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 1+balance))) + resp, err := s.txClient.BroadcastTx(s.ctx.GoContext(), []sdk.Msg{msg}, fee, gas) + require.NoError(t, err) + require.NotNil(t, resp) + resp, err = s.txClient.ConfirmTx(s.ctx.GoContext(), resp.TxHash) + require.Error(t, err) + require.NotEqual(t, abci.CodeTypeOK, resp.Code) + }) +} + +func (s *TxClientTestSuite) TestGasEstimation() { + addr := s.txClient.DefaultAddress() + msg := bank.NewMsgSend(addr, testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) + gas, err := s.txClient.EstimateGas(s.ctx.GoContext(), []sdk.Msg{msg}) + require.NoError(s.T(), err) + require.Greater(s.T(), gas, uint64(0)) +} + +// TestGasConsumption verifies that the amount deducted from a user's balance is +// based on the fee provided in the tx instead of the gas used by the tx. This +// behavior leads to poor UX because tx submitters must over-estimate the amount +// of gas that their tx will consume and they are not refunded for the excess. +func (s *TxClientTestSuite) TestGasConsumption() { + t := s.T() + + utiaToSend := int64(1) + addr := s.txClient.DefaultAddress() + msg := bank.NewMsgSend(addr, testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, utiaToSend))) + + gasPrice := int64(1) + gasLimit := uint64(1e6) + fee := uint64(1e6) // 1 TIA + // Note: gas price * gas limit = fee amount. So by setting gasLimit and fee + // to the same value, these options set a gas price of 1utia. + options := []user.TxOption{user.SetGasLimit(gasLimit), user.SetFee(fee)} + + balanceBefore := s.queryCurrentBalance(t) + resp, err := s.txClient.SubmitTx(s.ctx.GoContext(), []sdk.Msg{msg}, options...) + require.NoError(t, err) + + require.EqualValues(t, abci.CodeTypeOK, resp.Code) + balanceAfter := s.queryCurrentBalance(t) + + // verify that the amount deducted depends on the fee set in the tx. + amountDeducted := balanceBefore - balanceAfter - utiaToSend + require.Equal(t, int64(fee), amountDeducted) + + // verify that the amount deducted does not depend on the actual gas used. + gasUsedBasedDeduction := resp.GasUsed * gasPrice + require.NotEqual(t, gasUsedBasedDeduction, amountDeducted) + // The gas used based deduction should be less than the fee because the fee is 1 TIA. + require.Less(t, gasUsedBasedDeduction, int64(fee)) +} + +func (s *TxClientTestSuite) queryCurrentBalance(t *testing.T) int64 { + balanceQuery := bank.NewQueryClient(s.ctx.GRPCClient) + addr := s.txClient.DefaultAddress() + balanceResp, err := balanceQuery.AllBalances(s.ctx.GoContext(), &bank.QueryAllBalancesRequest{Address: addr.String()}) + require.NoError(t, err) + return balanceResp.Balances.AmountOf(app.BondDenom).Int64() +}