Skip to content

Commit

Permalink
Added ChainWriter unit tests for GetFeeComponents and GetTransactionS…
Browse files Browse the repository at this point in the history
…tatus
  • Loading branch information
amit-momin authored and silaslenihan committed Nov 27, 2024
1 parent 843eff7 commit 77bca0f
Show file tree
Hide file tree
Showing 4 changed files with 347 additions and 6 deletions.
6 changes: 3 additions & 3 deletions pkg/solana/chainwriter/chain_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (

type SolanaChainWriterService struct {
reader client.Reader
txm txm.Txm
txm txm.TxManager
ge fees.Estimator
config ChainWriterConfig
codecs map[string]types.Codec
Expand All @@ -46,7 +46,7 @@ type MethodConfig struct {
DebugIDLocation string
}

func NewSolanaChainWriterService(reader client.Reader, txm txm.Txm, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) {
func NewSolanaChainWriterService(reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) {
codecs, err := parseIDLCodecs(config)
if err != nil {
return nil, fmt.Errorf("failed to parse IDL codecs: %w", err)
Expand Down Expand Up @@ -275,7 +275,7 @@ var (

// GetTransactionStatus returns the current status of a transaction in the underlying chain's TXM.
func (s *SolanaChainWriterService) GetTransactionStatus(ctx context.Context, transactionID string) (types.TransactionStatus, error) {
return types.Unknown, nil
return s.txm.GetTransactionStatus(ctx, transactionID)
}

// GetFeeComponents retrieves the associated gas costs for executing a transaction.
Expand Down
334 changes: 334 additions & 0 deletions pkg/solana/chainwriter/chain_writer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,334 @@
package chainwriter_test

import (
"context"
"crypto/rand"
"math/big"
"sync"
"testing"
"time"

"github.com/gagliardetto/solana-go"
"github.com/gagliardetto/solana-go/programs/system"
"github.com/gagliardetto/solana-go/rpc"
"github.com/google/uuid"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/types"
"github.com/smartcontractkit/chainlink-common/pkg/utils"
"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/client"
clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/fees"
feemocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees/mocks"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/txm"
keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks"
)

func TestChainWriter_GetAddresses(t *testing.T) {}

func TestChainWriter_FilterLookupTableAddresses(t *testing.T) {}

func TestChainWriter_SubmitTransaction(t *testing.T) {}

func TestChainWriter_GetTransactionStatus(t *testing.T) {
t.Parallel()

ctx := tests.Context(t)
lggr := logger.Test(t)
cfg := config.NewDefault()
// Retain transactions after finality or error to maintain their status in memory
cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(5 * time.Second)
// Disable bumping to avoid issues with send tx mocking
cfg.Chain.FeeBumpPeriod = relayconfig.MustNewDuration(0 * time.Second)
rw := clientmocks.NewReaderWriter(t)
rw.On("GetLatestBlock", mock.Anything).Return(&rpc.GetBlockResult{}, nil).Maybe()
rw.On("SlotHeight", mock.Anything).Return(uint64(0), nil).Maybe()
loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return rw, nil })
ge := feemocks.NewEstimator(t)
// mock solana keystore
keystore := keyMocks.NewSimpleKeystore(t)
keystore.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil).Maybe()

// initialize and start TXM
txm := txm.NewTxm(uuid.NewString(), loader, nil, cfg, keystore, lggr)
require.NoError(t, txm.Start(ctx))
t.Cleanup(func() { require.NoError(t, txm.Close()) })

// initialize chain writer
cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{})
require.NoError(t, err)

computeUnitLimitDefault := fees.ComputeUnitLimit(cfg.ComputeUnitLimitDefault())

// mock signature statuses calls
statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{}
rw.On("SignatureStatuses", mock.Anything, mock.AnythingOfType("[]solana.Signature")).Return(
func(_ context.Context, sigs []solana.Signature) (out []*rpc.SignatureStatusesResult) {
for i := range sigs {
get, exists := statuses[sigs[i]]
if !exists {
out = append(out, nil)
continue
}
out = append(out, get())
}
return out
}, nil,
)

t.Run("returns unknown with error if ID not found", func(t *testing.T) {
status, err := cw.GetTransactionStatus(ctx, uuid.NewString())
require.Error(t, err)
require.Equal(t, types.Unknown, status)
})

t.Run("returns pending when transaction is broadcasted", func(t *testing.T) {
tx, signed := getTx(t, 1, keystore)
signedTx := signed(0, true, computeUnitLimitDefault)
for _, ins := range signedTx.Message.Instructions {
if cuprice, err := fees.ParseComputeUnitPrice(ins.Data); err == nil {
t.Log("compute unit price", cuprice)
}
}
sig := randomSignature(t)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe()

// mock transaction in broadcasted state
var wg sync.WaitGroup
wg.Add(1)
count := 0
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
defer func() { count++ }()
if count == 0 {
wg.Done()
}
return nil
}

txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is broadcasted
wg.Wait()
// wait for next confirm cycle to ensure transaction had enough time to update in storage
time.Sleep(cfg.ConfirmPollPeriod())

status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Pending, status)
})

t.Run("returns unconfirmed when transaction is processed", func(t *testing.T) {
tx, signed := getTx(t, 2, keystore)
sig := randomSignature(t)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe()

// mock transaction in processed state
var wg sync.WaitGroup
wg.Add(1)
count := 0
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
defer func() { count++ }()
if count == 0 {
wg.Done()
}
return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusProcessed}
}

txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is processed
wg.Wait()
// wait for next confirm cycle to ensure transaction had enough time to update in storage
time.Sleep(cfg.ConfirmPollPeriod())

status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Unconfirmed, status)
})

t.Run("returns unconfirmed when transaction is confirmed", func(t *testing.T) {
tx, signed := getTx(t, 3, keystore)
sig := randomSignature(t)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe()

// mock transaction in processed state
var wg sync.WaitGroup
wg.Add(1)
count := 0
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
defer func() { count++ }()
if count == 0 {
wg.Done()
}
return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusConfirmed}
}

txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is confirmed
wg.Wait()
// wait for next confirm cycle to ensure transaction had enough time to update in storage
time.Sleep(cfg.ConfirmPollPeriod())

status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Unconfirmed, status)
})

t.Run("returns finalized when transaction is finalized", func(t *testing.T) {
tx, signed := getTx(t, 4, keystore)
sig := randomSignature(t)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe()

// mock transaction in processed state
var wg sync.WaitGroup
wg.Add(1)
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
defer wg.Done()
return &rpc.SignatureStatusesResult{ConfirmationStatus: rpc.ConfirmationStatusFinalized}
}

txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is finalized
wg.Wait()
// wait for next confirm cycle to ensure transaction had enough time to update in storage
time.Sleep(cfg.ConfirmPollPeriod())

status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Finalized, status)
})

t.Run("returns failed when error encountered", func(t *testing.T) {
tx, signed := getTx(t, 5, keystore)
sig := randomSignature(t)
var wg sync.WaitGroup
wg.Add(1)
rw.On("SendTx", mock.Anything, signed(0, true, computeUnitLimitDefault)).Return(sig, nil)
rw.On("SimulateTx", mock.Anything, signed(0, true, computeUnitLimitDefault), mock.Anything).Run(func(mock.Arguments) {
wg.Done()
}).Return(&rpc.SimulateTransactionResult{
Err: "FAIL",
}, nil).Maybe()

// mock transaction in processed state
statuses[sig] = func() (out *rpc.SignatureStatusesResult) {
return nil
}

txID := uuid.NewString()
err = txm.Enqueue(ctx, uuid.NewString(), tx, &txID)
require.NoError(t, err)
// wait till transaction is finalized
wg.Wait()

status, err := cw.GetTransactionStatus(ctx, txID)
require.NoError(t, err)
require.Equal(t, types.Failed, status)
})
}

func TestChainWriter_GetFeeComponents(t *testing.T) {
t.Parallel()

ctx := tests.Context(t)
cfg := config.NewDefault()
rw := clientmocks.NewReaderWriter(t)
ge := feemocks.NewEstimator(t)
ge.On("BaseComputeUnitPrice").Return(uint64(100))
cw := setupChainWriter(t, cfg, rw, ge)
t.Run("returns valid compute unit price", func(t *testing.T) {
feeComponents, err := cw.GetFeeComponents(ctx)
require.NoError(t, err)
require.Equal(t, big.NewInt(100), feeComponents.ExecutionFee)
require.Nil(t, feeComponents.DataAvailabilityFee) // always nil for Solana
})

t.Run("fails if gas estimator not set", func(t *testing.T) {
cwNoEstimator := setupChainWriter(t, cfg, rw, nil)
_, err := cwNoEstimator.GetFeeComponents(ctx)
require.Error(t, err)
})
}

func setupChainWriter(t *testing.T, cfg *config.TOMLConfig, rw client.ReaderWriter, ge fees.Estimator) *chainwriter.SolanaChainWriterService {
ctx := tests.Context(t)
lggr := logger.Test(t)
loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return rw, nil })
// mock solana keystore
keystore := keyMocks.NewSimpleKeystore(t)
keystore.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil).Maybe()
// initialize and start TXM
txm := txm.NewTxm(uuid.NewString(), loader, nil, cfg, keystore, lggr)
require.NoError(t, txm.Start(ctx))
t.Cleanup(func() { require.NoError(t, txm.Close()) })

cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{})
require.NoError(t, err)
return cw
}

func randomSignature(t *testing.T) solana.Signature {
// make random signature
sig := make([]byte, 64)
_, err := rand.Read(sig)
require.NoError(t, err)

return solana.SignatureFromBytes(sig)
}

// create placeholder transaction and returns func for signed tx with fee
func getTx(t *testing.T, val uint64, keystore txm.SimpleKeystore) (*solana.Transaction, func(fees.ComputeUnitPrice, bool, fees.ComputeUnitLimit) *solana.Transaction) {
pubkey := solana.PublicKey{}

// create transfer tx
tx, err := solana.NewTransaction(
[]solana.Instruction{
system.NewTransferInstruction(
val,
pubkey,
pubkey,
).Build(),
},
solana.Hash{},
solana.TransactionPayer(pubkey),
)
require.NoError(t, err)

base := *tx // tx to send to txm, txm will add fee & sign

return &base, func(price fees.ComputeUnitPrice, addLimit bool, limit fees.ComputeUnitLimit) *solana.Transaction {
tx := base
// add fee parameters
require.NoError(t, fees.SetComputeUnitPrice(&tx, price))
if addLimit {
require.NoError(t, fees.SetComputeUnitLimit(&tx, limit)) // default
}

// sign tx
txMsg, err := tx.Message.MarshalBinary()
require.NoError(t, err)
sigBytes, err := keystore.Sign(tests.Context(t), pubkey.String(), txMsg)
require.NoError(t, err)
var finalSig [64]byte
copy(finalSig[:], sigBytes)
tx.Signatures = append(tx.Signatures, finalSig)
return &tx
}
}
3 changes: 2 additions & 1 deletion pkg/solana/chainwriter/lookups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/gagliardetto/solana-go/rpc"
"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/utils/tests"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/client"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
Expand Down Expand Up @@ -292,7 +293,7 @@ func TestLookupTables(t *testing.T) {

txm := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr)

cw, err := chainwriter.NewSolanaChainWriterService(solanaClient, *txm, nil, chainwriter.ChainWriterConfig{})
cw, err := chainwriter.NewSolanaChainWriterService(solanaClient, txm, nil, chainwriter.ChainWriterConfig{})

t.Run("StaticLookup table resolves properly", func(t *testing.T) {
pubKeys := createTestPubKeys(t, 8)
Expand Down
Loading

0 comments on commit 77bca0f

Please sign in to comment.