diff --git a/Makefile b/Makefile index 468374dbfe..a10b83d88f 100644 --- a/Makefile +++ b/Makefile @@ -146,7 +146,7 @@ test-race: # TODO: Remove the -skip flag once the following tests no longer contain data races. # https://github.com/celestiaorg/celestia-app/issues/1369 @echo "--> Running tests in race mode" - @go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestBlobstreamRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestBlobstreamCLI|TestUpgrade|TestMaliciousTestNode|TestBigBlobSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext|TestBlobstream|TestCLITestSuite|TestLegacyUpgrade" + @go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestBlobstreamRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestBlobstreamCLI|TestUpgrade|TestMaliciousTestNode|TestBigBlobSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext|TestBlobstream|TestCLITestSuite|TestLegacyUpgrade|TestSignerTwins|TestConcurrentTxSubmission" .PHONY: test-race ## test-bench: Run unit tests in bench mode. diff --git a/app/errors/insufficient_gas_price_test.go b/app/errors/insufficient_gas_price_test.go index 725de95819..2129dcea34 100644 --- a/app/errors/insufficient_gas_price_test.go +++ b/app/errors/insufficient_gas_price_test.go @@ -47,9 +47,7 @@ func TestInsufficientMinGasPriceIntegration(t *testing.T) { msg, err := blob.NewMsgPayForBlobs(signer.Address().String(), appconsts.LatestVersion, b) require.NoError(t, err) - tx, err := signer.CreateTx([]sdk.Msg{msg}, user.SetGasLimit(gasLimit), user.SetFeeAmount(fee)) - require.NoError(t, err) - sdkTx, err := enc.TxConfig.TxDecoder()(tx) + sdkTx, err := signer.CreateTx([]sdk.Msg{msg}, user.SetGasLimit(gasLimit), user.SetFeeAmount(fee)) require.NoError(t, err) decorator := ante.NewDeductFeeDecorator(testApp.AccountKeeper, testApp.BankKeeper, testApp.FeeGrantKeeper, nil) diff --git a/app/errors/nonce_mismatch.go b/app/errors/nonce_mismatch.go index 2726d61060..8209aac8b7 100644 --- a/app/errors/nonce_mismatch.go +++ b/app/errors/nonce_mismatch.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strconv" + "strings" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" ) @@ -13,6 +14,11 @@ func IsNonceMismatch(err error) bool { return errors.Is(err, sdkerrors.ErrWrongSequence) } +// IsNonceMismatch checks if the error code matches the sequence mismatch. +func IsNonceMismatchCode(code uint32) bool { + return code == sdkerrors.ErrWrongSequence.ABCICode() +} + // ParseNonceMismatch extracts the expected sequence number from the // ErrWrongSequence error. func ParseNonceMismatch(err error) (uint64, error) { @@ -20,9 +26,19 @@ func ParseNonceMismatch(err error) (uint64, error) { return 0, errors.New("error is not a sequence mismatch") } - numbers := regexpInt.FindAllString(err.Error(), -1) + return ParseExpectedSequence(err.Error()) +} + +// ParseExpectedSequence extracts the expected sequence number from the +// ErrWrongSequence error. +func ParseExpectedSequence(str string) (uint64, error) { + if !strings.HasPrefix(str, "account sequence mismatch") { + return 0, fmt.Errorf("unexpected wrong sequence error: %s", str) + } + + numbers := regexpInt.FindAllString(str, -1) if len(numbers) != 2 { - return 0, fmt.Errorf("unexpected wrong sequence error: %w", err) + return 0, fmt.Errorf("expected two numbers in string, got %d", len(numbers)) } // the first number is the expected sequence number diff --git a/app/errors/nonce_mismatch_test.go b/app/errors/nonce_mismatch_test.go index 861b4c2307..f784431c64 100644 --- a/app/errors/nonce_mismatch_test.go +++ b/app/errors/nonce_mismatch_test.go @@ -41,9 +41,7 @@ func TestNonceMismatchIntegration(t *testing.T) { msg, err := blob.NewMsgPayForBlobs(signer.Address().String(), appconsts.LatestVersion, b) require.NoError(t, err) - tx, err := signer.CreateTx([]sdk.Msg{msg}) - require.NoError(t, err) - sdkTx, err := enc.TxConfig.TxDecoder()(tx) + sdkTx, err := signer.CreateTx([]sdk.Msg{msg}) require.NoError(t, err) decorator := ante.NewSigVerificationDecorator(testApp.AccountKeeper, encCfg.TxConfig.SignModeHandler()) diff --git a/app/test/big_blob_test.go b/app/test/big_blob_test.go index 4e1072b530..4cefc87e94 100644 --- a/app/test/big_blob_test.go +++ b/app/test/big_blob_test.go @@ -7,6 +7,7 @@ import ( "github.com/celestiaorg/celestia-app/app" "github.com/celestiaorg/celestia-app/app/encoding" + "github.com/celestiaorg/celestia-app/pkg/appconsts" "github.com/celestiaorg/celestia-app/pkg/user" "github.com/celestiaorg/celestia-app/test/util/testfactory" "github.com/celestiaorg/celestia-app/test/util/testnode" @@ -77,12 +78,10 @@ func (s *BigBlobSuite) TestErrBlobsTooLarge() { for _, tc := range testCases { s.Run(tc.name, func() { - blobTx, err := signer.CreatePayForBlob([]*blob.Blob{tc.blob}, user.SetGasLimit(1e9), user.SetFee(2000000)) - require.NoError(t, err) subCtx, cancel := context.WithTimeout(s.cctx.GoContext(), 30*time.Second) defer cancel() - res, err := signer.BroadcastTx(subCtx, blobTx) - require.NoError(t, err) + res, err := signer.SubmitPayForBlob(subCtx, []*blob.Blob{tc.blob}, user.SetGasLimitAndFee(1e9, appconsts.DefaultGlobalMinGasPrice)) + require.Error(t, err) require.NotNil(t, res) require.Equal(t, tc.want, res.Code, res.Logs) }) diff --git a/app/test/priority_test.go b/app/test/priority_test.go index a8acbfc8ff..a9eeda96db 100644 --- a/app/test/priority_test.go +++ b/app/test/priority_test.go @@ -3,6 +3,7 @@ package app_test import ( "encoding/hex" "sort" + "sync" "testing" "time" @@ -71,42 +72,46 @@ func (s *PriorityTestSuite) TestPriorityByGasPrice() { t := s.T() // quickly submit blobs with a random fee - hashes := make([]string, 0, len(s.signers)) + + hashes := make(chan string, len(s.signers)) + blobSize := uint32(100) + gasLimit := blobtypes.DefaultEstimateGas([]uint32{blobSize}) + wg := &sync.WaitGroup{} for _, signer := range s.signers { - blobSize := uint32(100) - gasLimit := blobtypes.DefaultEstimateGas([]uint32{blobSize}) - gasPrice := s.rand.Float64() - btx, err := signer.CreatePayForBlob( - blobfactory.ManyBlobs( - s.rand, - []namespace.Namespace{namespace.RandomBlobNamespace()}, - []int{100}), - user.SetGasLimitAndFee(gasLimit, gasPrice), - ) - require.NoError(t, err) - resp, err := signer.BroadcastTx(s.cctx.GoContext(), btx) - require.NoError(t, err) - require.Equal(t, abci.CodeTypeOK, resp.Code, resp.RawLog) - hashes = append(hashes, resp.TxHash) + wg.Add(1) + go func() { + defer wg.Done() + gasPrice := float64(s.rand.Intn(1000)+1) / 1000 + resp, err := signer.SubmitPayForBlob( + s.cctx.GoContext(), + blobfactory.ManyBlobs( + s.rand, + []namespace.Namespace{namespace.RandomBlobNamespace()}, + []int{100}), + user.SetGasLimitAndFee(gasLimit, gasPrice), + ) + require.NoError(t, err) + require.Equal(t, abci.CodeTypeOK, resp.Code, resp.RawLog) + hashes <- resp.TxHash + }() } + wg.Wait() + close(hashes) + err := s.cctx.WaitForNextBlock() require.NoError(t, err) // get the responses for each tx for analysis and sort by height // note: use rpc types because they contain the tx index heightMap := make(map[int64][]*rpctypes.ResultTx) - for _, hash := range hashes { - resp, err := s.signers[0].ConfirmTx(s.cctx.GoContext(), hash) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, abci.CodeTypeOK, resp.Code) + for hash := range hashes { // use the core rpc type because it contains the tx index hash, err := hex.DecodeString(hash) require.NoError(t, err) coreRes, err := s.cctx.Client.Tx(s.cctx.GoContext(), hash, false) require.NoError(t, err) - heightMap[resp.Height] = append(heightMap[resp.Height], coreRes) + heightMap[coreRes.Height] = append(heightMap[coreRes.Height], coreRes) } require.GreaterOrEqual(t, len(heightMap), 1) @@ -123,7 +128,7 @@ func (s *PriorityTestSuite) TestPriorityByGasPrice() { // check that there was at least one block with more than three transactions // in it. This is more of a sanity check than a test. - require.True(t, highestNumOfTxsPerBlock > 3) + require.Greater(t, highestNumOfTxsPerBlock, 3) } func sortByIndex(txs []*rpctypes.ResultTx) []*rpctypes.ResultTx { @@ -135,14 +140,14 @@ func sortByIndex(txs []*rpctypes.ResultTx) []*rpctypes.ResultTx { func isSortedByFee(t *testing.T, ecfg encoding.Config, responses []*rpctypes.ResultTx) bool { for i := 0; i < len(responses)-1; i++ { - if gasPrice(t, ecfg, responses[i]) <= gasPrice(t, ecfg, responses[i+1]) { + if getGasPrice(t, ecfg, responses[i]) <= getGasPrice(t, ecfg, responses[i+1]) { return false } } return true } -func gasPrice(t *testing.T, ecfg encoding.Config, resp *rpctypes.ResultTx) float64 { +func getGasPrice(t *testing.T, ecfg encoding.Config, resp *rpctypes.ResultTx) float64 { sdkTx, err := ecfg.TxConfig.TxDecoder()(resp.Tx) require.NoError(t, err) feeTx := sdkTx.(sdk.FeeTx) diff --git a/go.work.sum b/go.work.sum index 634e906ebf..b185463cf8 100644 --- a/go.work.sum +++ b/go.work.sum @@ -925,6 +925,7 @@ github.com/minio/sha256-simd v1.0.0/go.mod h1:OuYzVNI5vcoYIAmbIvHPl3N3jUzVedXbKy github.com/moricho/tparallel v0.3.0/go.mod h1:leENX2cUv7Sv2qDgdi0D0fCftN8fRC67Bcn8pqzeYNI= github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J1GEMiLbxo1LJaP8RfCpH6pymGZus= github.com/nakabonne/nestif v0.3.1/go.mod h1:9EtoZochLn5iUprVDmDjqGKPofoUEBL8U4Ngq6aY7OE= github.com/nats-io/jwt/v2 v2.0.3/go.mod h1:VRP+deawSXyhNjXmxPCHskrR6Mq50BqpEI5SEcNiGlY= github.com/nats-io/nats-server/v2 v2.5.0/go.mod h1:Kj86UtrXAL6LwYRA6H4RqzkHhK0Vcv2ZnKD5WbQ1t3g= diff --git a/pkg/user/e2e_test.go b/pkg/user/e2e_test.go new file mode 100644 index 0000000000..dcc2d80549 --- /dev/null +++ b/pkg/user/e2e_test.go @@ -0,0 +1,100 @@ +package user_test + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/celestiaorg/celestia-app/pkg/appconsts" + "github.com/celestiaorg/celestia-app/pkg/user" + "github.com/celestiaorg/celestia-app/test/util/blobfactory" + "github.com/celestiaorg/celestia-app/test/util/testnode" + "github.com/celestiaorg/go-square/blob" + "github.com/stretchr/testify/require" + tmrand "github.com/tendermint/tendermint/libs/rand" +) + +func TestConcurrentTxSubmission(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + // Setup network + tmConfig := testnode.DefaultTendermintConfig() + tmConfig.Consensus.TimeoutCommit = 10 * time.Second + ctx, _, _ := testnode.NewNetwork(t, testnode.DefaultConfig().WithTendermintConfig(tmConfig)) + _, err := ctx.WaitForHeight(1) + require.NoError(t, err) + + // Setup signer + signer, err := testnode.NewSingleSignerFromContext(ctx) + require.NoError(t, err) + + // Pregenerate all the blobs + numTxs := 10 + blobs := blobfactory.ManyRandBlobs(tmrand.NewRand(), blobfactory.Repeat(2048, numTxs)...) + + // Prepare transactions + var ( + wg sync.WaitGroup + errCh = make(chan error) + ) + + subCtx, cancel := context.WithCancel(ctx.GoContext()) + defer cancel() + time.AfterFunc(time.Minute, cancel) + for i := 0; i < numTxs; i++ { + wg.Add(1) + go func(b *blob.Blob) { + defer wg.Done() + _, err := signer.SubmitPayForBlob(subCtx, []*blob.Blob{b}, user.SetGasLimitAndFee(500_000, appconsts.DefaultGlobalMinGasPrice)) + if err != nil && !errors.Is(err, context.Canceled) { + // only catch the first error + select { + case errCh <- err: + cancel() + default: + } + } + }(blobs[i]) + } + wg.Wait() + + select { + case err := <-errCh: + require.NoError(t, err) + default: + } +} + +func TestSignerTwins(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + // Setup network + tmConfig := testnode.DefaultTendermintConfig() + tmConfig.Consensus.TimeoutCommit = 10 * time.Second + ctx, _, _ := testnode.NewNetwork(t, testnode.DefaultConfig().WithTendermintConfig(tmConfig)) + _, err := ctx.WaitForHeight(1) + require.NoError(t, err) + + signer1, err := testnode.NewSingleSignerFromContext(ctx) + require.NoError(t, err) + signer2, err := testnode.NewSingleSignerFromContext(ctx) + require.NoError(t, err) + + blobs := blobfactory.ManyRandBlobs(tmrand.NewRand(), blobfactory.Repeat(2048, 8)...) + + _, err = signer1.SubmitPayForBlob(ctx.GoContext(), blobs[:1], user.SetGasLimitAndFee(500_000, appconsts.DefaultGlobalMinGasPrice)) + require.NoError(t, err) + + _, err = signer2.SubmitPayForBlob(ctx.GoContext(), blobs[1:3], user.SetGasLimitAndFee(500_000, appconsts.DefaultGlobalMinGasPrice)) + require.NoError(t, err) + + signer1.ForceSetSequence(4) + _, err = signer1.SubmitPayForBlob(ctx.GoContext(), blobs[3:5], user.SetGasLimitAndFee(500_000, appconsts.DefaultGlobalMinGasPrice)) + require.NoError(t, err) +} diff --git a/pkg/user/signer.go b/pkg/user/signer.go index 2a4a323a30..f56e592f3e 100644 --- a/pkg/user/signer.go +++ b/pkg/user/signer.go @@ -9,6 +9,7 @@ import ( "time" "github.com/celestiaorg/celestia-app/app/encoding" + apperrors "github.com/celestiaorg/celestia-app/app/errors" blobtypes "github.com/celestiaorg/celestia-app/x/blob/types" "github.com/celestiaorg/go-square/blob" "github.com/cosmos/cosmos-sdk/client" @@ -16,10 +17,11 @@ import ( "github.com/cosmos/cosmos-sdk/crypto/keyring" cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" sdktypes "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/types/tx" + sdktx "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + abci "github.com/tendermint/tendermint/abci/types" "google.golang.org/grpc" ) @@ -34,12 +36,22 @@ type Signer struct { pk cryptotypes.PubKey chainID string accountNumber uint64 - appVersion uint64 - pollTime time.Duration - - mtx sync.RWMutex - lastSignedSequence uint64 - lastConfirmedSequence uint64 + // FIXME: the signer is currently incapable of detecting an appversion + // change and could produce incorrect PFBs if it the network is at an + // appVersion that the signer does not support + appVersion uint64 + + mtx sync.RWMutex + // how often to poll the network for confirmation of a transaction + pollTime time.Duration + // the signers local view of the sequence number + localSequence uint64 + // the chains last known sequence number + networkSequence uint64 + // lookup map of all pending and yet to be confirmed outbound transactions + outboundSequences map[uint64]struct{} + // a reverse map for confirming which sequence numbers have been committed + reverseTxHashSequenceMap map[string]uint64 } // NewSigner returns a new signer using the provided keyring @@ -64,17 +76,19 @@ func NewSigner( } return &Signer{ - keys: keys, - address: address, - grpc: conn, - enc: enc, - pk: pk, - chainID: chainID, - accountNumber: accountNumber, - appVersion: appVersion, - lastSignedSequence: sequence, - lastConfirmedSequence: sequence, - pollTime: DefaultPollTime, + keys: keys, + address: address, + grpc: conn, + enc: enc, + pk: pk, + chainID: chainID, + accountNumber: accountNumber, + appVersion: appVersion, + localSequence: sequence, + networkSequence: sequence, + pollTime: DefaultPollTime, + outboundSequences: make(map[uint64]struct{}), + reverseTxHashSequenceMap: make(map[string]uint64), }, nil } @@ -130,17 +144,14 @@ func SetupSigner( // SubmitTx forms a transaction from the provided messages, signs it, and submits it to the chain. TxOptions // may be provided to set the fee and gas limit. func (s *Signer) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) { - txBytes, err := s.CreateTx(msgs, opts...) + tx, err := s.CreateTx(msgs, opts...) if err != nil { return nil, err } - resp, err := s.BroadcastTx(ctx, txBytes) + resp, err := s.BroadcastTx(ctx, tx) if err != nil { - return nil, err - } - if resp.Code != 0 { - return resp, fmt.Errorf("tx failed with code %d: %s", resp.Code, resp.RawLog) + return resp, err } return s.ConfirmTx(ctx, resp.TxHash) @@ -149,25 +160,35 @@ func (s *Signer) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOp // SubmitPayForBlob forms a transaction from the provided blobs, signs it, and submits it to the chain. // TxOptions may be provided to set the fee and gas limit. func (s *Signer) SubmitPayForBlob(ctx context.Context, blobs []*blob.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { - txBytes, err := s.CreatePayForBlob(blobs, opts...) + resp, err := s.broadcastPayForBlob(ctx, blobs, opts...) if err != nil { - return nil, err + return resp, err } - resp, err := s.BroadcastTx(ctx, txBytes) + return s.ConfirmTx(ctx, resp.TxHash) +} + +func (s *Signer) broadcastPayForBlob(ctx context.Context, blobs []*blob.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + txBytes, seqNum, err := s.createPayForBlobs(blobs, opts...) if err != nil { return nil, err } - if resp.Code != 0 { - return resp, fmt.Errorf("tx failed with code %d: %s", resp.Code, resp.RawLog) - } - return s.ConfirmTx(ctx, resp.TxHash) + return s.broadcastTx(ctx, txBytes, seqNum) } // CreateTx forms a transaction from the provided messages and signs it. TxOptions may be optionally // used to set the gas limit and fee. -func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) ([]byte, error) { +func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + + return s.createTx(msgs, opts...) +} + +func (s *Signer) createTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { txBuilder := s.txBuilder(opts...) if err := txBuilder.SetMsgs(msgs...); err != nil { return nil, err @@ -177,62 +198,185 @@ func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) ([]byte, error) return nil, err } - return s.enc.TxEncoder()(txBuilder.GetTx()) + return txBuilder.GetTx(), nil } func (s *Signer) CreatePayForBlob(blobs []*blob.Blob, opts ...TxOption) ([]byte, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + blobTx, _, err := s.createPayForBlobs(blobs, opts...) + return blobTx, err +} + +func (s *Signer) createPayForBlobs(blobs []*blob.Blob, opts ...TxOption) ([]byte, uint64, error) { msg, err := blobtypes.NewMsgPayForBlobs(s.address.String(), s.appVersion, blobs...) if err != nil { - return nil, err + return nil, 0, err } - txBytes, err := s.CreateTx([]sdktypes.Msg{msg}, opts...) + tx, err := s.createTx([]sdktypes.Msg{msg}, opts...) if err != nil { - return nil, err + return nil, 0, err + } + + seqNum, err := getSequenceNumber(tx) + if err != nil { + panic(err) } - return blob.MarshalBlobTx(txBytes, blobs...) + txBytes, err := s.EncodeTx(tx) + if err != nil { + return nil, 0, err + } + + blobTx, err := blob.MarshalBlobTx(txBytes, blobs...) + return blobTx, seqNum, err +} + +func (s *Signer) EncodeTx(tx sdktypes.Tx) ([]byte, error) { + return s.enc.TxEncoder()(tx) +} + +func (s *Signer) DecodeTx(txBytes []byte) (authsigning.Tx, error) { + tx, err := s.enc.TxDecoder()(txBytes) + if err != nil { + return nil, err + } + authTx, ok := tx.(authsigning.Tx) + if !ok { + return nil, errors.New("not an authsigning transaction") + } + return authTx, nil } // BroadcastTx submits the provided transaction bytes to the chain and returns the response. -func (s *Signer) BroadcastTx(ctx context.Context, txBytes []byte) (*sdktypes.TxResponse, error) { - txClient := tx.NewServiceClient(s.grpc) +func (s *Signer) BroadcastTx(ctx context.Context, tx authsigning.Tx) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + txBytes, err := s.EncodeTx(tx) + if err != nil { + return nil, err + } + sequence, err := getSequenceNumber(tx) + if err != nil { + return nil, err + } + return s.broadcastTx(ctx, txBytes, sequence) +} + +// CONTRACT: assumes the caller has the lock +func (s *Signer) broadcastTx(ctx context.Context, txBytes []byte, sequence uint64) (*sdktypes.TxResponse, error) { + if _, exists := s.outboundSequences[sequence]; exists { + return s.retryBroadcastingTx(ctx, txBytes, sequence+1) + } + + if sequence < s.networkSequence { + s.localSequence = s.networkSequence + return s.retryBroadcastingTx(ctx, txBytes, s.localSequence) + } - // TODO (@cmwaters): handle nonce mismatch errors + txClient := sdktx.NewServiceClient(s.grpc) resp, err := txClient.BroadcastTx( ctx, - &tx.BroadcastTxRequest{ - Mode: tx.BroadcastMode_BROADCAST_MODE_SYNC, + &sdktx.BroadcastTxRequest{ + Mode: sdktx.BroadcastMode_BROADCAST_MODE_SYNC, TxBytes: txBytes, }, ) if err != nil { return nil, err } - return resp.TxResponse, nil + if apperrors.IsNonceMismatchCode(resp.TxResponse.Code) { + // extract what the lastCommittedNonce on chain is + nextSequence, err := apperrors.ParseExpectedSequence(resp.TxResponse.RawLog) + if err != nil { + return nil, fmt.Errorf("parsing nonce mismatch upon retry: %w", err) + } + s.networkSequence = nextSequence + s.localSequence = nextSequence + return s.retryBroadcastingTx(ctx, txBytes, nextSequence) + } + if resp.TxResponse.Code == abci.CodeTypeOK { + s.outboundSequences[sequence] = struct{}{} + s.reverseTxHashSequenceMap[resp.TxResponse.TxHash] = sequence + return resp.TxResponse, nil + } + return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) +} + +// retryBroadcastingTx creates a new transaction by copying over an existing transaction but creates a new signature with the +// new sequence number. It then calls `broadcastTx` and attempts to submit the transaction +func (s *Signer) retryBroadcastingTx(ctx context.Context, txBytes []byte, newSequenceNumber uint64) (*sdktypes.TxResponse, error) { + blobTx, isBlobTx := blob.UnmarshalBlobTx(txBytes) + if isBlobTx { + txBytes = blobTx.Tx + } + tx, err := s.DecodeTx(txBytes) + if err != nil { + return nil, err + } + txBuilder := s.txBuilder() + if err := txBuilder.SetMsgs(tx.GetMsgs()...); err != nil { + return nil, err + } + if granter := tx.FeeGranter(); granter != nil { + txBuilder.SetFeeGranter(granter) + } + if payer := tx.FeePayer(); payer != nil { + txBuilder.SetFeePayer(payer) + } + if memo := tx.GetMemo(); memo != "" { + txBuilder.SetMemo(memo) + } + if fee := tx.GetFee(); fee != nil { + txBuilder.SetFeeAmount(fee) + } + if gas := tx.GetGas(); gas > 0 { + txBuilder.SetGasLimit(gas) + } + + if err := s.signTransaction(txBuilder, newSequenceNumber); err != nil { + return nil, fmt.Errorf("resigning transaction: %w", err) + } + + newTxBytes, err := s.EncodeTx(txBuilder.GetTx()) + if err != nil { + return nil, err + } + + // rewrap the blob tx if it was originally a blob tx + if isBlobTx { + newTxBytes, err = blob.MarshalBlobTx(newTxBytes, blobTx.Blobs...) + if err != nil { + return nil, err + } + } + + return s.broadcastTx(ctx, newTxBytes, newSequenceNumber) } // ConfirmTx periodically pings the provided node for the commitment of a transaction by its // hash. It will continually loop until the context is cancelled, the tx is found or an error // is encountered. func (s *Signer) ConfirmTx(ctx context.Context, txHash string) (*sdktypes.TxResponse, error) { - txClient := tx.NewServiceClient(s.grpc) - - s.mtx.RLock() - pollTime := s.pollTime - s.mtx.RUnlock() + txClient := sdktx.NewServiceClient(s.grpc) - pollTicker := time.NewTicker(pollTime) + pollTicker := time.NewTicker(s.getPollTime()) defer pollTicker.Stop() for { - resp, err := txClient.GetTx(ctx, &tx.GetTxRequest{Hash: txHash}) + resp, err := txClient.GetTx(ctx, &sdktx.GetTxRequest{Hash: txHash}) if err == nil { if resp.TxResponse.Code != 0 { - return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) + s.updateNetworkSequence(txHash, false) + return resp.TxResponse, fmt.Errorf("tx was included but failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) } + s.updateNetworkSequence(txHash, true) return resp.TxResponse, nil } + // FIXME: this is a relatively brittle of working out whether to retry or not. The tx might be not found for other + // reasons. It may have been removed from the mempool at a later point. We should build an endpoint that gives the + // signer more information on the status of their transaction and then update the logic here if !strings.Contains(err.Error(), "not found") { return &sdktypes.TxResponse{}, err } @@ -252,7 +396,7 @@ func (s *Signer) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...T return 0, err } - if err := s.signTransaction(txBuilder, s.Sequence()); err != nil { + if err := s.signTransaction(txBuilder, s.LocalSequence()); err != nil { return 0, err } @@ -261,7 +405,7 @@ func (s *Signer) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...T return 0, err } - resp, err := tx.NewServiceClient(s.grpc).Simulate(ctx, &tx.SimulateRequest{ + resp, err := sdktx.NewServiceClient(s.grpc).Simulate(ctx, &sdktx.SimulateRequest{ TxBytes: txBytes, }) if err != nil { @@ -293,42 +437,62 @@ func (s *Signer) SetPollTime(pollTime time.Duration) { s.pollTime = pollTime } +func (s *Signer) getPollTime() time.Duration { + s.mtx.Lock() + defer s.mtx.Unlock() + return s.pollTime +} + // PubKey returns the public key of the signer func (s *Signer) PubKey() cryptotypes.PubKey { return s.pk } -// Sequence returns the last signed sequence number of the signer -func (s *Signer) Sequence() uint64 { +// LocalSequence returns the next sequence number of the signers +// locally saved +func (s *Signer) LocalSequence() uint64 { s.mtx.RLock() defer s.mtx.RUnlock() - return s.lastSignedSequence + return s.localSequence } -// GetSequence gets the latest signed sequence and increments the local sequence number -// Deprecated: Use Sequence if you want to get the latest signed sequence number -func (s *Signer) GetSequence() uint64 { - s.mtx.Lock() - defer s.mtx.Unlock() - defer func() { s.lastSignedSequence++ }() - return s.lastSignedSequence +func (s *Signer) NetworkSequence() uint64 { + s.mtx.RLock() + defer s.mtx.RUnlock() + return s.networkSequence } -// getAndIncrementSequence gets the latest signed sequence and increments the local sequence number +// getAndIncrementSequence gets the latest signed sequence and increments the +// local sequence number func (s *Signer) getAndIncrementSequence() uint64 { + defer func() { s.localSequence++ }() + return s.localSequence +} + +// ForceSetSequence manually overrides the current local and network level +// sequence number. Be careful when invoking this as it may cause the +// transactions to reject the sequence if it doesn't match the one in state +func (s *Signer) ForceSetSequence(seq uint64) { s.mtx.Lock() defer s.mtx.Unlock() - defer func() { s.lastSignedSequence++ }() - return s.lastSignedSequence + s.localSequence = seq + s.networkSequence = seq } -// ForceSetSequence manually overrides the current sequence number. Be careful when -// invoking this as it may cause the transactions to reject the sequence if -// it doesn't match the one in state -func (s *Signer) ForceSetSequence(seq uint64) { +// updateNetworkSequence is called once a transaction is confirmed +// and updates the chains last known sequence number +func (s *Signer) updateNetworkSequence(txHash string, success bool) { s.mtx.Lock() defer s.mtx.Unlock() - s.lastSignedSequence = seq + sequence, exists := s.reverseTxHashSequenceMap[txHash] + if !exists { + return + } + if success && sequence >= s.networkSequence { + s.networkSequence = sequence + 1 + } + delete(s.outboundSequences, sequence) + delete(s.reverseTxHashSequenceMap, txHash) } // Keyring exposes the signers underlying keyring @@ -436,3 +600,15 @@ func (s *Signer) getSignatureV2(sequence uint64, signature []byte) signing.Signa } return sigV2 } + +func getSequenceNumber(tx authsigning.Tx) (uint64, error) { + sigs, err := tx.GetSignaturesV2() + if err != nil { + return 0, err + } + if len(sigs) > 1 { + return 0, fmt.Errorf("only a signle signature is supported, got %d", len(sigs)) + } + + return sigs[0].Sequence, nil +} diff --git a/pkg/user/signer_test.go b/pkg/user/signer_test.go index 3513d388d7..c0887f775a 100644 --- a/pkg/user/signer_test.go +++ b/pkg/user/signer_test.go @@ -13,7 +13,6 @@ import ( "github.com/celestiaorg/celestia-app/test/util/testnode" sdk "github.com/cosmos/cosmos-sdk/types" bank "github.com/cosmos/cosmos-sdk/x/bank/types" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" abci "github.com/tendermint/tendermint/abci/types" @@ -77,28 +76,30 @@ func (s *SignerTestSuite) TestConfirmTx() { gas := user.SetGasLimit(1e6) t.Run("deadline exceeded when the context times out", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), time.Second) defer cancel() _, err := s.signer.ConfirmTx(ctx, "E32BD15CAF57AF15D17B0D63CF4E63A9835DD1CEBB059C335C79586BC3013728") - assert.Error(t, err) - assert.Contains(t, err.Error(), context.DeadlineExceeded.Error()) + require.Error(t, err) + require.Contains(t, err.Error(), context.DeadlineExceeded.Error()) }) t.Run("should error when tx is not found", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), 5*time.Second) defer cancel() _, err := s.signer.ConfirmTx(ctx, "not found tx") - assert.Error(t, err) + require.Error(t, err) }) t.Run("should success when tx is found immediately", func(t *testing.T) { msg := bank.NewMsgSend(s.signer.Address(), testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) resp, err := s.submitTxWithoutConfirm([]sdk.Msg{msg}, fee, gas) - assert.NoError(t, err) - assert.NotNil(t, resp) - resp, err = s.signer.ConfirmTx(s.ctx.GoContext(), resp.TxHash) - assert.NoError(t, err) - assert.Equal(t, abci.CodeTypeOK, resp.Code) + require.NoError(t, err) + require.NotNil(t, resp) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), 30*time.Second) + defer cancel() + resp, err = s.signer.ConfirmTx(ctx, resp.TxHash) + require.NoError(t, err) + require.Equal(t, abci.CodeTypeOK, resp.Code) }) t.Run("should error when tx is found with a non-zero error code", func(t *testing.T) { @@ -106,17 +107,17 @@ func (s *SignerTestSuite) TestConfirmTx() { // Create a msg send with out of balance, ensure this tx fails msg := bank.NewMsgSend(s.signer.Address(), testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 1+balance))) resp, err := s.submitTxWithoutConfirm([]sdk.Msg{msg}, fee, gas) - assert.NoError(t, err) - assert.NotNil(t, resp) + require.NoError(t, err) + require.NotNil(t, resp) resp, err = s.signer.ConfirmTx(s.ctx.GoContext(), resp.TxHash) - assert.Error(t, err) - assert.NotEqual(t, abci.CodeTypeOK, resp.Code) + require.Error(t, err) + require.NotEqual(t, abci.CodeTypeOK, resp.Code) }) } func (s *SignerTestSuite) TestGasEstimation() { msg := bank.NewMsgSend(s.signer.Address(), testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) - gas, err := s.signer.EstimateGas(context.Background(), []sdk.Msg{msg}) + gas, err := s.signer.EstimateGas(s.ctx.GoContext(), []sdk.Msg{msg}) require.NoError(s.T(), err) require.Greater(s.T(), gas, uint64(0)) } @@ -147,13 +148,13 @@ func (s *SignerTestSuite) TestGasConsumption() { // verify that the amount deducted depends on the fee set in the tx. amountDeducted := balanceBefore - balanceAfter - utiaToSend - assert.Equal(t, int64(fee), amountDeducted) + require.Equal(t, int64(fee), amountDeducted) // verify that the amount deducted does not depend on the actual gas used. gasUsedBasedDeduction := resp.GasUsed * gasPrice - assert.NotEqual(t, gasUsedBasedDeduction, amountDeducted) + require.NotEqual(t, gasUsedBasedDeduction, amountDeducted) // The gas used based deduction should be less than the fee because the fee is 1 TIA. - assert.Less(t, gasUsedBasedDeduction, int64(fee)) + require.Less(t, gasUsedBasedDeduction, int64(fee)) } func (s *SignerTestSuite) queryCurrentBalance(t *testing.T) int64 { diff --git a/test/util/blobfactory/payforblob_factory.go b/test/util/blobfactory/payforblob_factory.go index ffd76b674b..f288f3e318 100644 --- a/test/util/blobfactory/payforblob_factory.go +++ b/test/util/blobfactory/payforblob_factory.go @@ -251,7 +251,10 @@ func IndexWrappedTxWithInvalidNamespace( require.NoError(t, err) msg.Namespaces[0] = bytes.Repeat([]byte{1}, 33) // invalid namespace - rawTx, err := signer.CreateTx([]sdk.Msg{msg}, DefaultTxOpts()...) + tx, err := signer.CreateTx([]sdk.Msg{msg}, DefaultTxOpts()...) + require.NoError(t, err) + + rawTx, err := signer.EncodeTx(tx) require.NoError(t, err) cTx, err := coretypes.MarshalIndexWrapper(rawTx, index) @@ -285,7 +288,10 @@ func ComplexBlobTxWithOtherMsgs(t *testing.T, rand *tmrand.Rand, signer *user.Si pfb, blobs := RandMsgPayForBlobsWithSigner(rand, signer.Address().String(), 100, 1) msgs = append(msgs, pfb) - rawTx, err := signer.CreateTx(msgs, DefaultTxOpts()...) + tx, err := signer.CreateTx(msgs, DefaultTxOpts()...) + require.NoError(t, err) + + rawTx, err := signer.EncodeTx(tx) require.NoError(t, err) btx, err := blob.MarshalBlobTx(rawTx, blobs...) diff --git a/test/util/blobfactory/test_util.go b/test/util/blobfactory/test_util.go index 8ba9c06f50..8723762c09 100644 --- a/test/util/blobfactory/test_util.go +++ b/test/util/blobfactory/test_util.go @@ -44,7 +44,12 @@ func GenerateRawSendTx(signer *user.Signer, amount int64) []byte { addr := signer.Address() msg := banktypes.NewMsgSend(addr, addr, sdk.NewCoins(amountCoin)) - rawTx, err := signer.CreateTx([]sdk.Msg{msg}, opts...) + tx, err := signer.CreateTx([]sdk.Msg{msg}, opts...) + if err != nil { + panic(err) + } + + rawTx, err := signer.EncodeTx(tx) if err != nil { panic(err) } diff --git a/test/util/direct_tx_gen.go b/test/util/direct_tx_gen.go index fa0b84f055..bc128e6543 100644 --- a/test/util/direct_tx_gen.go +++ b/test/util/direct_tx_gen.go @@ -136,14 +136,17 @@ func RandBlobTxsWithManualSequence( SignMode: signing.SignMode_SIGN_MODE_DIRECT, Signature: []byte("invalid signature"), }, - Sequence: signer.GetSequence(), + Sequence: signer.LocalSequence(), }) require.NoError(t, err) - tx, err = cfg.TxEncoder()(builder.GetTx()) + tx = builder.GetTx() require.NoError(t, err) } - cTx, err := blob.MarshalBlobTx(tx, blobs...) + rawTx, err := signer.EncodeTx(tx) + require.NoError(t, err) + + cTx, err := blob.MarshalBlobTx(rawTx, blobs...) if err != nil { panic(err) } @@ -215,7 +218,10 @@ func SendTxWithManualSequence( msg := banktypes.NewMsgSend(fromAddr, toAddr, sdk.NewCoins(sdk.NewCoin(app.BondDenom, sdk.NewIntFromUint64(amount)))) stx, err := signer.CreateTx([]sdk.Msg{msg}, opts...) require.NoError(t, err) - return stx + + rawTx, err := signer.EncodeTx(stx) + require.NoError(t, err) + return rawTx } func getAddress(account string, kr keyring.Keyring) sdk.AccAddress { diff --git a/x/blob/types/blob_tx_test.go b/x/blob/types/blob_tx_test.go index c64be16e3b..15292c769a 100644 --- a/x/blob/types/blob_tx_test.go +++ b/x/blob/types/blob_tx_test.go @@ -127,7 +127,10 @@ func TestValidateBlobTx(t *testing.T) { msg.ShareCommitments[0] = badCommit - rawTx, err := signer.CreateTx([]sdk.Msg{msg}) + tx, err := signer.CreateTx([]sdk.Msg{msg}) + require.NoError(t, err) + + rawTx, err := signer.EncodeTx(tx) require.NoError(t, err) btx := &blob.BlobTx{