diff --git a/.mockery.yaml b/.mockery.yaml index 1df96bfec..347d69c58 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -37,6 +37,7 @@ packages: config: filename: simple_keystore.go case: underscore + TxManager: github.com/smartcontractkit/chainlink-solana/pkg/solana/logpoller: interfaces: RPCClient: diff --git a/contracts/Anchor.toml b/contracts/Anchor.toml index 78a2222ad..0f8fd5b99 100644 --- a/contracts/Anchor.toml +++ b/contracts/Anchor.toml @@ -27,6 +27,7 @@ test = "pnpm run test" [programs.localnet] access_controller = "9xi644bRR8birboDGdTiwBq3C7VEeR7VuamRYYXCubUW" -log-read-test = "J1zQwrBNBngz26jRPNWsUSZMHJwBwpkoDitXRV95LdK4" +log_read_test = "J1zQwrBNBngz26jRPNWsUSZMHJwBwpkoDitXRV95LdK4" ocr_2 = "cjg3oHmg9uuPsP8D6g29NWvhySJkdYdAo9D25PRbKXJ" # need to rename the idl to satisfy anchor.js... -store = "HEvSKofvBgfaexv23kMabbYqxasxU3mQ4ibBMEmJWHny" \ No newline at end of file +store = "HEvSKofvBgfaexv23kMabbYqxasxU3mQ4ibBMEmJWHny" +write_test = "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU" \ No newline at end of file diff --git a/contracts/Cargo.lock b/contracts/Cargo.lock index 953b2b81e..43a9fa6c6 100644 --- a/contracts/Cargo.lock +++ b/contracts/Cargo.lock @@ -2664,6 +2664,13 @@ dependencies = [ "memchr", ] +[[package]] +name = "write-test" +version = "0.1.0" +dependencies = [ + "anchor-lang", +] + [[package]] name = "zerocopy" version = "0.7.32" diff --git a/contracts/artifacts/localnet/write_test-keypair.json b/contracts/artifacts/localnet/write_test-keypair.json new file mode 100644 index 000000000..dfb18e9c4 --- /dev/null +++ b/contracts/artifacts/localnet/write_test-keypair.json @@ -0,0 +1,6 @@ +[ + 26, 39, 164, 161, 246, 97, 149, 0, 58, 187, 146, 162, 53, 35, 107, 2, 117, + 242, 83, 171, 48, 7, 63, 240, 69, 221, 239, 45, 97, 55, 112, 106, 192, 228, + 214, 205, 123, 71, 58, 23, 62, 229, 166, 213, 149, 122, 96, 145, 35, 150, 16, + 156, 247, 199, 242, 108, 173, 80, 62, 231, 39, 196, 27, 192 +] diff --git a/contracts/pnpm-lock.yaml b/contracts/pnpm-lock.yaml index 860108de1..b7cec1551 100644 --- a/contracts/pnpm-lock.yaml +++ b/contracts/pnpm-lock.yaml @@ -13,13 +13,13 @@ importers: version: link:../ts '@coral-xyz/anchor': specifier: ^0.29.0 - version: 0.29.0 + version: 0.29.0(bufferutil@4.0.8)(utf-8-validate@5.0.10) '@solana/spl-token': specifier: ^0.3.5 - version: 0.3.11(@solana/web3.js@1.92.3)(fastestsmallesttextencoderdecoder@1.0.22) + version: 0.3.11(@solana/web3.js@1.92.3(bufferutil@4.0.8)(utf-8-validate@5.0.10))(bufferutil@4.0.8)(fastestsmallesttextencoderdecoder@1.0.22)(utf-8-validate@5.0.10) '@solana/web3.js': specifier: ^1.50.1 <=1.92.3 - version: 1.92.3 + version: 1.92.3(bufferutil@4.0.8)(utf-8-validate@5.0.10) '@types/chai': specifier: ^4.2.22 version: 4.3.12 @@ -893,11 +893,11 @@ snapshots: dependencies: regenerator-runtime: 0.14.1 - '@coral-xyz/anchor@0.29.0': + '@coral-xyz/anchor@0.29.0(bufferutil@4.0.8)(utf-8-validate@5.0.10)': dependencies: - '@coral-xyz/borsh': 0.29.0(@solana/web3.js@1.95.3) + '@coral-xyz/borsh': 0.29.0(@solana/web3.js@1.95.3(bufferutil@4.0.8)(utf-8-validate@5.0.10)) '@noble/hashes': 1.5.0 - '@solana/web3.js': 1.95.3 + '@solana/web3.js': 1.95.3(bufferutil@4.0.8)(utf-8-validate@5.0.10) bn.js: 5.2.1 bs58: 4.0.1 buffer-layout: 1.2.2 @@ -914,9 +914,9 @@ snapshots: - encoding - utf-8-validate - '@coral-xyz/borsh@0.29.0(@solana/web3.js@1.95.3)': + '@coral-xyz/borsh@0.29.0(@solana/web3.js@1.95.3(bufferutil@4.0.8)(utf-8-validate@5.0.10))': dependencies: - '@solana/web3.js': 1.95.3 + '@solana/web3.js': 1.95.3(bufferutil@4.0.8)(utf-8-validate@5.0.10) bn.js: 5.2.1 buffer-layout: 1.2.2 @@ -926,10 +926,10 @@ snapshots: '@noble/hashes@1.5.0': {} - '@solana/buffer-layout-utils@0.2.0': + '@solana/buffer-layout-utils@0.2.0(bufferutil@4.0.8)(utf-8-validate@5.0.10)': dependencies: '@solana/buffer-layout': 4.0.1 - '@solana/web3.js': 1.95.3 + '@solana/web3.js': 1.95.3(bufferutil@4.0.8)(utf-8-validate@5.0.10) bigint-buffer: 1.1.5 bignumber.js: 9.1.2 transitivePeerDependencies: @@ -963,7 +963,7 @@ snapshots: '@solana/codecs-core': 2.0.0-experimental.8618508 '@solana/codecs-numbers': 2.0.0-experimental.8618508 - '@solana/spl-token-metadata@0.1.2(@solana/web3.js@1.92.3)(fastestsmallesttextencoderdecoder@1.0.22)': + '@solana/spl-token-metadata@0.1.2(@solana/web3.js@1.92.3(bufferutil@4.0.8)(utf-8-validate@5.0.10))(fastestsmallesttextencoderdecoder@1.0.22)': dependencies: '@solana/codecs-core': 2.0.0-experimental.8618508 '@solana/codecs-data-structures': 2.0.0-experimental.8618508 @@ -971,16 +971,16 @@ snapshots: '@solana/codecs-strings': 2.0.0-experimental.8618508(fastestsmallesttextencoderdecoder@1.0.22) '@solana/options': 2.0.0-experimental.8618508 '@solana/spl-type-length-value': 0.1.0 - '@solana/web3.js': 1.92.3 + '@solana/web3.js': 1.92.3(bufferutil@4.0.8)(utf-8-validate@5.0.10) transitivePeerDependencies: - fastestsmallesttextencoderdecoder - '@solana/spl-token@0.3.11(@solana/web3.js@1.92.3)(fastestsmallesttextencoderdecoder@1.0.22)': + '@solana/spl-token@0.3.11(@solana/web3.js@1.92.3(bufferutil@4.0.8)(utf-8-validate@5.0.10))(bufferutil@4.0.8)(fastestsmallesttextencoderdecoder@1.0.22)(utf-8-validate@5.0.10)': dependencies: '@solana/buffer-layout': 4.0.1 - '@solana/buffer-layout-utils': 0.2.0 - '@solana/spl-token-metadata': 0.1.2(@solana/web3.js@1.92.3)(fastestsmallesttextencoderdecoder@1.0.22) - '@solana/web3.js': 1.92.3 + '@solana/buffer-layout-utils': 0.2.0(bufferutil@4.0.8)(utf-8-validate@5.0.10) + '@solana/spl-token-metadata': 0.1.2(@solana/web3.js@1.92.3(bufferutil@4.0.8)(utf-8-validate@5.0.10))(fastestsmallesttextencoderdecoder@1.0.22) + '@solana/web3.js': 1.92.3(bufferutil@4.0.8)(utf-8-validate@5.0.10) buffer: 6.0.3 transitivePeerDependencies: - bufferutil @@ -992,7 +992,7 @@ snapshots: dependencies: buffer: 6.0.3 - '@solana/web3.js@1.92.3': + '@solana/web3.js@1.92.3(bufferutil@4.0.8)(utf-8-validate@5.0.10)': dependencies: '@babel/runtime': 7.25.6 '@noble/curves': 1.6.0 @@ -1005,7 +1005,7 @@ snapshots: bs58: 4.0.1 buffer: 6.0.3 fast-stable-stringify: 1.0.0 - jayson: 4.1.2 + jayson: 4.1.2(bufferutil@4.0.8)(utf-8-validate@5.0.10) node-fetch: 2.7.0 rpc-websockets: 8.0.1 superstruct: 1.0.4 @@ -1014,7 +1014,7 @@ snapshots: - encoding - utf-8-validate - '@solana/web3.js@1.95.3': + '@solana/web3.js@1.95.3(bufferutil@4.0.8)(utf-8-validate@5.0.10)': dependencies: '@babel/runtime': 7.25.6 '@noble/curves': 1.6.0 @@ -1027,7 +1027,7 @@ snapshots: bs58: 4.0.1 buffer: 6.0.3 fast-stable-stringify: 1.0.0 - jayson: 4.1.2 + jayson: 4.1.2(bufferutil@4.0.8)(utf-8-validate@5.0.10) node-fetch: 2.7.0 rpc-websockets: 9.0.2 superstruct: 2.0.2 @@ -1185,6 +1185,7 @@ snapshots: bufferutil@4.0.8: dependencies: node-gyp-build: 4.8.2 + optional: true camelcase@6.3.0: {} @@ -1268,6 +1269,7 @@ snapshots: debug@4.3.3(supports-color@8.1.1): dependencies: ms: 2.1.2 + optionalDependencies: supports-color: 8.1.1 decamelize@4.0.0: {} @@ -1433,11 +1435,11 @@ snapshots: isexe@2.0.0: {} - isomorphic-ws@4.0.1(ws@7.5.10): + isomorphic-ws@4.0.1(ws@7.5.10(bufferutil@4.0.8)(utf-8-validate@5.0.10)): dependencies: - ws: 7.5.10 + ws: 7.5.10(bufferutil@4.0.8)(utf-8-validate@5.0.10) - jayson@4.1.2: + jayson@4.1.2(bufferutil@4.0.8)(utf-8-validate@5.0.10): dependencies: '@types/connect': 3.4.38 '@types/node': 12.20.55 @@ -1447,10 +1449,10 @@ snapshots: delay: 5.0.0 es6-promisify: 5.0.0 eyes: 0.1.8 - isomorphic-ws: 4.0.1(ws@7.5.10) + isomorphic-ws: 4.0.1(ws@7.5.10(bufferutil@4.0.8)(utf-8-validate@5.0.10)) json-stringify-safe: 5.0.1 uuid: 8.3.2 - ws: 7.5.10 + ws: 7.5.10(bufferutil@4.0.8)(utf-8-validate@5.0.10) transitivePeerDependencies: - bufferutil - utf-8-validate @@ -1767,6 +1769,7 @@ snapshots: utf-8-validate@5.0.10: dependencies: node-gyp-build: 4.8.2 + optional: true util-deprecate@1.0.2: {} @@ -1793,10 +1796,13 @@ snapshots: wrappy@1.0.2: {} - ws@7.5.10: {} + ws@7.5.10(bufferutil@4.0.8)(utf-8-validate@5.0.10): + optionalDependencies: + bufferutil: 4.0.8 + utf-8-validate: 5.0.10 ws@8.18.0(bufferutil@4.0.8)(utf-8-validate@5.0.10): - dependencies: + optionalDependencies: bufferutil: 4.0.8 utf-8-validate: 5.0.10 diff --git a/contracts/programs/write_test/Cargo.toml b/contracts/programs/write_test/Cargo.toml new file mode 100644 index 000000000..ee46888c6 --- /dev/null +++ b/contracts/programs/write_test/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "write-test" +version = "0.1.0" +description = "Created with Anchor" +edition = "2021" + +[lib] +crate-type = ["cdylib", "lib"] +name = "write_test" + +[features] +no-entrypoint = [] +no-idl = [] +no-log-ix-name = [] +cpi = ["no-entrypoint"] +default = [] + +[dependencies] +anchor-lang = "0.29.0" diff --git a/contracts/programs/write_test/Xargo.toml b/contracts/programs/write_test/Xargo.toml new file mode 100644 index 000000000..475fb71ed --- /dev/null +++ b/contracts/programs/write_test/Xargo.toml @@ -0,0 +1,2 @@ +[target.bpfel-unknown-unknown.dependencies.std] +features = [] diff --git a/contracts/programs/write_test/src/lib.rs b/contracts/programs/write_test/src/lib.rs new file mode 100644 index 000000000..8d8fa3cac --- /dev/null +++ b/contracts/programs/write_test/src/lib.rs @@ -0,0 +1,51 @@ +use anchor_lang::prelude::*; + +declare_id!("39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU"); + +#[program] +pub mod write_test { + use super::*; + + pub fn initialize(ctx: Context, lookup_table: Pubkey) -> Result<()> { + let data = &mut ctx.accounts.data_account; + data.version = 1; + data.administrator = ctx.accounts.admin.key(); + data.pending_administrator = Pubkey::default(); + data.lookup_table = lookup_table; + + Ok(()) + } +} + +#[derive(Accounts)] +pub struct Initialize<'info> { + /// PDA account, derived from seeds and created by the System Program in this instruction + #[account( + init, // Initialize the account + payer = admin, // Specify the payer + space = DataAccount::SIZE, // Specify the account size + seeds = [b"data"], // Define the PDA seeds + bump // Use the bump seed + )] + pub data_account: Account<'info, DataAccount>, + + /// Admin account that pays for PDA creation and signs the transaction + #[account(mut)] + pub admin: Signer<'info>, + + /// System Program is required for PDA creation + pub system_program: Program<'info, System>, +} + +#[account] +pub struct DataAccount { + pub version: u8, + pub administrator: Pubkey, + pub pending_administrator: Pubkey, + pub lookup_table: Pubkey, +} + +impl DataAccount { + /// The total size of the `DataAccount` struct, including the discriminator + pub const SIZE: usize = 8 + 1 + 32 * 3; // 8 bytes for discriminator + 1 byte for version + 32 bytes * 3 pubkeys +} diff --git a/pkg/solana/chain.go b/pkg/solana/chain.go index 630248aff..f8c49a4cc 100644 --- a/pkg/solana/chain.go +++ b/pkg/solana/chain.go @@ -30,6 +30,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" "github.com/smartcontractkit/chainlink-solana/pkg/solana/monitor" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) type Chain interface { @@ -576,12 +577,12 @@ func (c *chain) sendTx(ctx context.Context, from, to string, amount *big.Int, ba chainTxm := c.TxManager() err = chainTxm.Enqueue(ctx, "", tx, nil, - txm.SetComputeUnitLimit(500), // reduce from default 200K limit - should only take 450 compute units + txmutils.SetComputeUnitLimit(500), // reduce from default 200K limit - should only take 450 compute units // no fee bumping and no additional fee - makes validating balance accurate - txm.SetComputeUnitPriceMax(0), - txm.SetComputeUnitPriceMin(0), - txm.SetBaseComputeUnitPrice(0), - txm.SetFeeBumpPeriod(0), + txmutils.SetComputeUnitPriceMax(0), + txmutils.SetComputeUnitPriceMin(0), + txmutils.SetBaseComputeUnitPrice(0), + txmutils.SetFeeBumpPeriod(0), ) if err != nil { return fmt.Errorf("transaction failed: %w", err) diff --git a/pkg/solana/chainwriter/ccip_example_config.go b/pkg/solana/chainwriter/ccip_example_config.go new file mode 100644 index 000000000..89038fd6a --- /dev/null +++ b/pkg/solana/chainwriter/ccip_example_config.go @@ -0,0 +1,340 @@ +package chainwriter + +import ( + "fmt" + + commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" +) + +func TestConfig() { + // Fake constant addresses for the purpose of this example. + registryAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6A" + routerProgramAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6B" + routerAccountConfigAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6C" + cpiSignerAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6D" + systemProgramAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6E" + computeBudgetProgramAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6F" + sysvarProgramAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6G" + commonAddressesLookupTable := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6H" + routerLookupTable := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6I" + userAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6J" + + executionReportSingleChainIDL := `{"name":"ExecutionReportSingleChain","type":{"kind":"struct","fields":[{"name":"source_chain_selector","type":"u64"},{"name":"message","type":{"defined":"Any2SolanaRampMessage"}},{"name":"root","type":{"array":["u8",32]}},{"name":"proofs","type":{"vec":{"array":["u8",32]}}}]}},{"name":"Any2SolanaRampMessage","type":{"kind":"struct","fields":[{"name":"header","type":{"defined":"RampMessageHeader"}},{"name":"sender","type":{"vec":"u8"}},{"name":"data","type":{"vec":"u8"}},{"name":"receiver","type":{"array":["u8",32]}},{"name":"extra_args","type":{"defined":"SolanaExtraArgs"}}]}},{"name":"RampMessageHeader","type":{"kind":"struct","fields":[{"name":"message_id","type":{"array":["u8",32]}},{"name":"source_chain_selector","type":"u64"},{"name":"dest_chain_selector","type":"u64"},{"name":"sequence_number","type":"u64"},{"name":"nonce","type":"u64"}]}},{"name":"SolanaExtraArgs","type":{"kind":"struct","fields":[{"name":"compute_units","type":"u32"},{"name":"allow_out_of_order_execution","type":"bool"}]}}` + + executeConfig := MethodConfig{ + FromAddress: userAddress, + InputModifications: commoncodec.ModifiersConfig{ + // remove merkle root since it isn't a part of the on-chain type + &commoncodec.DropModifierConfig{ + Fields: []string{"Message.ExtraArgs.MerkleRoot"}, + }, + }, + ChainSpecificName: "execute", + // LookupTables are on-chain stores of accounts. They can be used in two ways: + // 1. As a way to store a list of accounts that are all associated together (i.e. Token State registry) + // 2. To compress the transactions in a TX and reduce the size of the TX. (The traditional way) + LookupTables: LookupTables{ + // DerivedLookupTables are useful in both the ways described above. + // a. The user can configure any type of look up to get a list of lookupTables to read from. + // b. The ChainWriter reads from this lookup table and store the internal addresses in memory + // c. Later, in the []Accounts the user can specify which accounts to include in the TX with an AccountsFromLookupTable lookup. + // d. Lastly, the lookup table is used to compress the size of the transaction. + DerivedLookupTables: []DerivedLookupTable{ + { + Name: "RegistryTokenState", + // In this case, the user configured the lookup table accounts to use a PDALookup, which + // generates a list of one of more PDA accounts based on the input parameters. Specifically, + // there will be multple PDA accounts if there are multiple addresses in the message, otherwise, + // there will only be one PDA account to read from. The PDA account corresponds to the lookup table. + Accounts: PDALookups{ + Name: "RegistryTokenState", + PublicKey: AccountConstant{ + Address: registryAddress, + IsSigner: false, + IsWritable: false, + }, + // Seeds would be used if the user needed to look up addresses to use as seeds, which isn't the case here. + Seeds: []Lookup{ + AccountLookup{Location: "Message.TokenAmounts.DestTokenAddress"}, + }, + IsSigner: false, + IsWritable: false, + }, + }, + }, + // Static lookup tables are the traditional use case (point 2 above) of Lookup tables. These are lookup + // tables which contain commonly used addresses in all CCIP execute transactions. The ChainWriter reads + // these lookup tables and appends them to the transaction to reduce the size of the transaction. + StaticLookupTables: []string{ + commonAddressesLookupTable, + routerLookupTable, + }, + }, + // The Accounts field is where the user specifies which accounts to include in the transaction. Each Lookup + // resolves to one or more on-chain addresses. + Accounts: []Lookup{ + // The accounts can be of any of the following types: + // 1. Account constant + // 2. Account Lookup - Based on data from input parameters + // 3. Lookup Table content - Get all the accounts from a lookup table + // 4. PDA Account Lookup - Based on another account and a seed/s + // Nested PDA Account with seeds from: + // -> input parameters + // -> constant + // PDALookups can resolve to multiple addresses if: + // A) The PublicKey lookup resolves to multiple addresses (i.e. multiple token addresses) + // B) The Seeds or ValueSeeds resolve to multiple values + PDALookups{ + Name: "PerChainConfig", + // PublicKey is a constant account in this case, not a lookup. + PublicKey: AccountConstant{ + Address: registryAddress, + IsSigner: false, + IsWritable: false, + }, + // Similar to the RegistryTokenState above, the user is looking up PDA accounts based on the dest tokens. + Seeds: []Lookup{ + AccountLookup{Location: "Message.TokenAmounts.DestTokenAddress"}, + AccountLookup{Location: "Message.Header.DestChainSelector"}, + }, + IsSigner: false, + IsWritable: false, + }, + // Lookup Table content - Get the accounts from the derived lookup table above + AccountsFromLookupTable{ + LookupTableName: "RegistryTokenState", + IncludeIndexes: []int{}, // If left empty, all addresses will be included. Otherwise, only the specified indexes will be included. + }, + // Account Lookup - Based on data from input parameters + // In this case, the user wants to add the destination token addresses to the transaction. + // Once again, this can be one or multiple addresses. + AccountLookup{ + Name: "TokenAccount", + Location: "Message.TokenAmounts.DestTokenAddress", + IsSigner: false, + IsWritable: false, + }, + // PDA Account Lookup - Based on an account lookup and an address lookup + PDALookups{ + // In this case, the token address is the public key, and the receiver is the seed. + // Again, there could be multiple token addresses, in which case this would resolve to + // multiple PDA accounts. + Name: "ReceiverAssociatedTokenAccount", + PublicKey: AccountLookup{ + Name: "TokenAccount", + Location: "Message.TokenAmounts.DestTokenAddress", + IsSigner: false, + IsWritable: false, + }, + // The seed is the receiver address. + Seeds: []Lookup{ + AccountLookup{ + Name: "Receiver", + Location: "Message.Receiver", + IsSigner: false, + IsWritable: false, + }, + }, + }, + // Account constant + AccountConstant{ + Name: "Registry", + Address: registryAddress, + IsSigner: false, + IsWritable: false, + }, + // PDA Lookup for the RegistryTokenConfig. + PDALookups{ + Name: "RegistryTokenConfig", + // constant public key + PublicKey: AccountConstant{ + Address: registryAddress, + IsSigner: false, + IsWritable: false, + }, + // The seed, once again, is the destination token address. + Seeds: []Lookup{ + AccountLookup{Location: "Message.TokenAmounts.DestTokenAddress"}, + }, + IsSigner: false, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "RouterProgram", + Address: routerProgramAddress, + IsSigner: false, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "RouterAccountConfig", + Address: routerAccountConfigAddress, + IsSigner: false, + IsWritable: false, + }, + // PDA lookup to get the Router Chain Config + PDALookups{ + Name: "RouterChainConfig", + // The public key is a constant Router address. + PublicKey: AccountConstant{ + Address: routerProgramAddress, + IsSigner: false, + IsWritable: false, + }, + Seeds: []Lookup{ + AccountLookup{Location: "Message.Header.DestChainSelector"}, + AccountLookup{Location: "Message.Header.SourceChainSelector"}, + }, + IsSigner: false, + IsWritable: false, + }, + // PDA lookup to get the Router Report Accounts. + PDALookups{ + Name: "RouterReportAccount", + // The public key is a constant Router address. + PublicKey: AccountConstant{ + Address: routerProgramAddress, + IsSigner: false, + IsWritable: false, + }, + Seeds: []Lookup{ + AccountLookup{ + // The seed is the merkle root of the report, as passed into the input params. + Location: "args.MerkleRoot", + }, + }, + IsSigner: false, + IsWritable: false, + }, + // PDA lookup to get UserNoncePerChain + PDALookups{ + Name: "UserNoncePerChain", + // The public key is a constant Router address. + PublicKey: AccountConstant{ + Address: routerProgramAddress, + IsSigner: false, + IsWritable: false, + }, + // In this case, the user configured multiple seeds. These will be used in conjunction + // with the public key to generate one or multiple PDA accounts. + Seeds: []Lookup{ + AccountLookup{Location: "Message.Receiver"}, + AccountLookup{Location: "Message.Header.DestChainSelector"}, + }, + }, + // Account constant + AccountConstant{ + Name: "CPISigner", + Address: cpiSignerAddress, + IsSigner: true, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "SystemProgram", + Address: systemProgramAddress, + IsSigner: true, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "ComputeBudgetProgram", + Address: computeBudgetProgramAddress, + IsSigner: true, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "SysvarProgram", + Address: sysvarProgramAddress, + IsSigner: true, + IsWritable: false, + }, + }, + // TBD where this will be in the report + // This will be appended to every error message + DebugIDLocation: "Message.MessageID", + } + + commitConfig := MethodConfig{ + FromAddress: userAddress, + InputModifications: nil, + ChainSpecificName: "commit", + LookupTables: LookupTables{ + StaticLookupTables: []string{ + commonAddressesLookupTable, + routerLookupTable, + }, + }, + Accounts: []Lookup{ + // Account constant + AccountConstant{ + Name: "RouterProgram", + Address: routerProgramAddress, + IsSigner: false, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "RouterAccountConfig", + Address: routerAccountConfigAddress, + IsSigner: false, + IsWritable: false, + }, + // PDA lookup to get the Router Report Accounts. + PDALookups{ + Name: "RouterReportAccount", + // The public key is a constant Router address. + PublicKey: AccountConstant{ + Address: routerProgramAddress, + IsSigner: false, + IsWritable: false, + }, + Seeds: []Lookup{ + AccountLookup{ + // The seed is the merkle root of the report, as passed into the input params. + Location: "args.MerkleRoots", + }, + }, + IsSigner: false, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "SystemProgram", + Address: systemProgramAddress, + IsSigner: true, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "ComputeBudgetProgram", + Address: computeBudgetProgramAddress, + IsSigner: true, + IsWritable: false, + }, + // Account constant + AccountConstant{ + Name: "SysvarProgram", + Address: sysvarProgramAddress, + IsSigner: true, + IsWritable: false, + }, + }, + DebugIDLocation: "", + } + + chainWriterConfig := ChainWriterConfig{ + Programs: map[string]ProgramConfig{ + "ccip-router": { + Methods: map[string]MethodConfig{ + "execute": executeConfig, + "commit": commitConfig, + }, + IDL: executionReportSingleChainIDL, + }, + }, + } + fmt.Println(chainWriterConfig) +} diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go new file mode 100644 index 000000000..4fcc5caa0 --- /dev/null +++ b/pkg/solana/chainwriter/chain_writer.go @@ -0,0 +1,312 @@ +package chainwriter + +import ( + "context" + "encoding/json" + "fmt" + "math/big" + + "github.com/gagliardetto/solana-go" + + commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" + "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary" + "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/types" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" +) + +type SolanaChainWriterService struct { + reader client.Reader + txm txm.TxManager + ge fees.Estimator + config ChainWriterConfig + codecs map[string]types.Codec +} + +// nolint // ignoring naming suggestion +type ChainWriterConfig struct { + Programs map[string]ProgramConfig +} + +type ProgramConfig struct { + Methods map[string]MethodConfig + IDL string +} + +type MethodConfig struct { + FromAddress string + InputModifications commoncodec.ModifiersConfig + ChainSpecificName string + LookupTables LookupTables + Accounts []Lookup + // Location in the args where the debug ID is stored + DebugIDLocation string +} + +func NewSolanaChainWriterService(reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { + codecs, err := parseIDLCodecs(config) + if err != nil { + return nil, fmt.Errorf("failed to parse IDL codecs: %w", err) + } + + return &SolanaChainWriterService{ + reader: reader, + txm: txm, + ge: ge, + config: config, + codecs: codecs, + }, nil +} + +func parseIDLCodecs(config ChainWriterConfig) (map[string]types.Codec, error) { + codecs := make(map[string]types.Codec) + for program, programConfig := range config.Programs { + var idl codec.IDL + if err := json.Unmarshal([]byte(programConfig.IDL), &idl); err != nil { + return nil, fmt.Errorf("failed to unmarshal IDL: %w", err) + } + idlCodec, err := codec.NewIDLInstructionsCodec(idl, binary.LittleEndian()) + if err != nil { + return nil, fmt.Errorf("failed to create codec from IDL: %w", err) + } + for method, methodConfig := range programConfig.Methods { + if methodConfig.InputModifications != nil { + modConfig, err := methodConfig.InputModifications.ToModifier(codec.DecoderHooks...) + if err != nil { + return nil, fmt.Errorf("failed to create input modifications: %w", err) + } + // add mods to codec + idlCodec, err = codec.NewNamedModifierCodec(idlCodec, method, modConfig) + if err != nil { + return nil, fmt.Errorf("failed to create named codec: %w", err) + } + } + } + codecs[program] = idlCodec + } + return codecs, nil +} + +/* +GetAddresses resolves account addresses from various `Lookup` configurations to build the required `solana.AccountMeta` list +for Solana transactions. It handles constant addresses, dynamic lookups, program-derived addresses (PDAs), and lookup tables. + +### Parameters: +- `ctx`: Context for request lifecycle management. +- `args`: Input arguments used for dynamic lookups. +- `accounts`: List of `Lookup` configurations specifying how addresses are derived. +- `derivedTableMap`: Map of pre-loaded lookup table addresses. +- `debugID`: Debug identifier for tracing errors. + +### Return: +- A slice of `solana.AccountMeta` containing derived addresses and associated metadata. + +### Account Types: +1. **AccountConstant**: + - A fixed address, provided in Base58 format, converted into a `solana.PublicKey`. + - Example: A pre-defined fee payer or system account. + +2. **AccountLookup**: + - Dynamically derived from input args using a specified location path (e.g., `user.walletAddress`). + - If the lookup table is pre-loaded, the address is fetched from `derivedTableMap`. + +3. **PDALookups**: + - Generates Program Derived Addresses (PDA) by combining a derived public key with one or more seeds. + - Seeds can be `AddressSeeds` (public keys from the input args) or `ValueSeeds` (byte arrays). + - Ensures there is only one public key if multiple seeds are provided. + +### Error Handling: +- Errors are wrapped with the `debugID` for easier tracing. +*/ +// GetAddresses resolves account addresses from various `Lookup` configurations to build the required `solana.AccountMeta` list +// for Solana transactions. +func GetAddresses(ctx context.Context, args any, accounts []Lookup, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([]*solana.AccountMeta, error) { + var addresses []*solana.AccountMeta + for _, accountConfig := range accounts { + meta, err := accountConfig.Resolve(ctx, args, derivedTableMap, reader) + if err != nil { + return nil, err + } + addresses = append(addresses, meta...) + } + return addresses, nil +} + +func (s *SolanaChainWriterService) FilterLookupTableAddresses( + accounts []*solana.AccountMeta, + derivedTableMap map[string]map[string][]*solana.AccountMeta, + staticTableMap map[solana.PublicKey]solana.PublicKeySlice, +) map[solana.PublicKey]solana.PublicKeySlice { + filteredLookupTables := make(map[solana.PublicKey]solana.PublicKeySlice) + + // Build a hash set of account public keys for fast lookup + usedAccounts := make(map[string]struct{}) + for _, account := range accounts { + usedAccounts[account.PublicKey.String()] = struct{}{} + } + + // Filter derived lookup tables + for _, innerMap := range derivedTableMap { + for innerIdentifier, metas := range innerMap { + tableKey, err := solana.PublicKeyFromBase58(innerIdentifier) + if err != nil { + continue + } + + // Collect public keys that are actually used + var usedAddresses solana.PublicKeySlice + for _, meta := range metas { + if _, exists := usedAccounts[meta.PublicKey.String()]; exists { + usedAddresses = append(usedAddresses, meta.PublicKey) + } + } + + // Add to the filtered map if there are any used addresses + if len(usedAddresses) > 0 { + filteredLookupTables[tableKey] = usedAddresses + } + } + } + + // Filter static lookup tables + for tableKey, addresses := range staticTableMap { + var usedAddresses solana.PublicKeySlice + for _, staticAddress := range addresses { + if _, exists := usedAccounts[staticAddress.String()]; exists { + usedAddresses = append(usedAddresses, staticAddress) + } + } + + // Add to the filtered map if there are any used addresses + if len(usedAddresses) > 0 { + filteredLookupTables[tableKey] = usedAddresses + } + } + + return filteredLookupTables +} + +func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contractName, method string, args any, transactionID string, toAddress string, meta *types.TxMeta, value *big.Int) error { + programConfig, exists := s.config.Programs[contractName] + if !exists { + return fmt.Errorf("failed to find program config for contract name: %s", contractName) + } + methodConfig, exists := programConfig.Methods[method] + if !exists { + return fmt.Errorf("failed to find method config for method: %s", method) + } + + // Configure debug ID + debugID := "" + if methodConfig.DebugIDLocation != "" { + var err error + debugID, err = GetDebugIDAtLocation(args, methodConfig.DebugIDLocation) + if err != nil { + return errorWithDebugID(fmt.Errorf("error getting debug ID from input args: %w", err), debugID) + } + } + + codec := s.codecs[contractName] + encodedPayload, err := codec.Encode(ctx, args, method) + if err != nil { + return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID) + } + + // Fetch derived and static table maps + derivedTableMap, staticTableMap, err := s.ResolveLookupTables(ctx, args, methodConfig.LookupTables) + if err != nil { + return errorWithDebugID(fmt.Errorf("error getting lookup tables: %w", err), debugID) + } + + // Resolve account metas + accounts, err := GetAddresses(ctx, args, methodConfig.Accounts, derivedTableMap, s.reader) + if err != nil { + return errorWithDebugID(fmt.Errorf("error resolving account addresses: %w", err), debugID) + } + + // Filter the lookup table addresses based on which accounts are actually used + filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) + + // Fetch latest blockhash + blockhash, err := s.reader.LatestBlockhash(ctx) + if err != nil { + return errorWithDebugID(fmt.Errorf("error fetching latest blockhash: %w", err), debugID) + } + + // Prepare transaction + programID, err := solana.PublicKeyFromBase58(contractName) + if err != nil { + return errorWithDebugID(fmt.Errorf("error parsing program ID: %w", err), debugID) + } + + feePayer, err := solana.PublicKeyFromBase58(methodConfig.FromAddress) + if err != nil { + return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID) + } + + tx, err := solana.NewTransaction( + []solana.Instruction{ + solana.NewInstruction(programID, accounts, encodedPayload), + }, + blockhash.Value.Blockhash, + solana.TransactionPayer(feePayer), + solana.TransactionAddressTables(filteredLookupTableMap), + ) + if err != nil { + return errorWithDebugID(fmt.Errorf("error constructing transaction: %w", err), debugID) + } + + // Enqueue transaction + if err = s.txm.Enqueue(ctx, accounts[0].PublicKey.String(), tx, &transactionID); err != nil { + return errorWithDebugID(fmt.Errorf("error enqueuing transaction: %w", err), debugID) + } + + return nil +} + +var ( + _ services.Service = &SolanaChainWriterService{} + _ types.ContractWriter = &SolanaChainWriterService{} +) + +// GetTransactionStatus returns the current status of a transaction in the underlying chain's TXM. +func (s *SolanaChainWriterService) GetTransactionStatus(ctx context.Context, transactionID string) (types.TransactionStatus, error) { + return s.txm.GetTransactionStatus(ctx, transactionID) +} + +// GetFeeComponents retrieves the associated gas costs for executing a transaction. +func (s *SolanaChainWriterService) GetFeeComponents(ctx context.Context) (*types.ChainFeeComponents, error) { + if s.ge == nil { + return nil, fmt.Errorf("gas estimator not available") + } + + fee := s.ge.BaseComputeUnitPrice() + return &types.ChainFeeComponents{ + ExecutionFee: new(big.Int).SetUint64(fee), + DataAvailabilityFee: nil, + }, nil +} + +func (s *SolanaChainWriterService) Start(context.Context) error { + return nil +} + +func (s *SolanaChainWriterService) Close() error { + return nil +} + +func (s *SolanaChainWriterService) HealthReport() map[string]error { + return nil +} + +func (s *SolanaChainWriterService) Name() string { + return "" +} + +func (s *SolanaChainWriterService) Ready() error { + return nil +} diff --git a/pkg/solana/chainwriter/chain_writer_test.go b/pkg/solana/chainwriter/chain_writer_test.go new file mode 100644 index 000000000..d931fb6d8 --- /dev/null +++ b/pkg/solana/chainwriter/chain_writer_test.go @@ -0,0 +1,690 @@ +package chainwriter_test + +import ( + "bytes" + "errors" + "math/big" + "reflect" + "testing" + + ag_binary "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table" + "github.com/gagliardetto/solana-go/rpc" + "github.com/google/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter" + clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks" + feemocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees/mocks" + txmMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" +) + +var writeTestIdlJSON = `{"version": "0.1.0","name": "write_test","instructions": [{"name": "initialize","accounts": [{"name": "dataAccount","isMut": true,"isSigner": false,"docs": ["PDA account, derived from seeds and created by the System Program in this instruction"]},{"name": "admin","isMut": true,"isSigner": true,"docs": ["Admin account that pays for PDA creation and signs the transaction"]},{"name": "systemProgram","isMut": false,"isSigner": false,"docs": ["System Program is required for PDA creation"]}],"args": [{"name": "lookupTable","type": "publicKey"}]}],"accounts": [{"name": "DataAccount","type": {"kind": "struct","fields": [{"name": "version","type": "u8"},{"name": "administrator","type": "publicKey"},{"name": "pendingAdministrator","type": "publicKey"},{"name": "lookupTable","type": "publicKey"}]}}]}` + +func TestChainWriter_GetAddresses(t *testing.T) { + ctx := tests.Context(t) + + // mock client + rw := clientmocks.NewReaderWriter(t) + // mock estimator + ge := feemocks.NewEstimator(t) + // mock txm + txm := txmMocks.NewTxManager(t) + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + // expected account meta for constant account + constantAccountMeta := &solana.AccountMeta{ + IsSigner: true, + IsWritable: true, + } + + // expected account meta for account lookup + accountLookupMeta := &solana.AccountMeta{ + IsSigner: true, + IsWritable: false, + } + + // setup pda account address + seed1 := []byte("seed1") + pda1 := mustFindPdaProgramAddress(t, [][]byte{seed1}, solana.SystemProgramID) + // expected account meta for pda lookup + pdaLookupMeta := &solana.AccountMeta{ + PublicKey: pda1, + IsSigner: false, + IsWritable: false, + } + + // setup pda account with inner field lookup + programID := chainwriter.GetRandomPubKey(t) + seed2 := []byte("seed2") + pda2 := mustFindPdaProgramAddress(t, [][]byte{seed2}, programID) + // mock data account response from program + lookupTablePubkey := mockDataAccountLookupTable(t, rw, pda2) + // mock fetch lookup table addresses call + storedPubKeys := chainwriter.CreateTestPubKeys(t, 3) + mockFetchLookupTableAddresses(t, rw, lookupTablePubkey, storedPubKeys) + // expected account meta for derived table lookup + derivedTablePdaLookupMeta := &solana.AccountMeta{ + IsSigner: false, + IsWritable: true, + } + + lookupTableConfig := chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed2 for PDA lookup + chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}, + }, + IsSigner: derivedTablePdaLookupMeta.IsSigner, + IsWritable: derivedTablePdaLookupMeta.IsWritable, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + }, + StaticLookupTables: nil, + } + + t.Run("resolve addresses from different types of lookups", func(t *testing.T) { + constantAccountMeta.PublicKey = chainwriter.GetRandomPubKey(t) + accountLookupMeta.PublicKey = chainwriter.GetRandomPubKey(t) + // correlates to DerivedTable index in account lookup config + derivedTablePdaLookupMeta.PublicKey = storedPubKeys[0] + + args := map[string]interface{}{ + "lookup_table": accountLookupMeta.PublicKey.Bytes(), + "seed1": seed1, + "seed2": seed2, + } + + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountConstant{ + Name: "Constant", + Address: constantAccountMeta.PublicKey.String(), + IsSigner: constantAccountMeta.IsSigner, + IsWritable: constantAccountMeta.IsWritable, + }, + chainwriter.AccountLookup{ + Name: "LookupTable", + Location: "lookup_table", + IsSigner: accountLookupMeta.IsSigner, + IsWritable: accountLookupMeta.IsWritable, + }, + chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed1 for PDA lookup + chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}, + }, + IsSigner: pdaLookupMeta.IsSigner, + IsWritable: pdaLookupMeta.IsWritable, + // Just get the address of the account, nothing internal. + InternalField: chainwriter.InternalField{}, + }, + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + IncludeIndexes: []int{0}, + }, + } + + // Fetch derived table map + derivedTableMap, _, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + // account metas should be returned in the same order as the provided account lookup configs + require.Len(t, accounts, 4) + + // Validate account constant + require.Equal(t, constantAccountMeta.PublicKey, accounts[0].PublicKey) + require.Equal(t, constantAccountMeta.IsSigner, accounts[0].IsSigner) + require.Equal(t, constantAccountMeta.IsWritable, accounts[0].IsWritable) + + // Validate account lookup + require.Equal(t, accountLookupMeta.PublicKey, accounts[1].PublicKey) + require.Equal(t, accountLookupMeta.IsSigner, accounts[1].IsSigner) + require.Equal(t, accountLookupMeta.IsWritable, accounts[1].IsWritable) + + // Validate pda lookup + require.Equal(t, pdaLookupMeta.PublicKey, accounts[2].PublicKey) + require.Equal(t, pdaLookupMeta.IsSigner, accounts[2].IsSigner) + require.Equal(t, pdaLookupMeta.IsWritable, accounts[2].IsWritable) + + // Validate pda lookup with inner field from derived table + require.Equal(t, derivedTablePdaLookupMeta.PublicKey, accounts[3].PublicKey) + require.Equal(t, derivedTablePdaLookupMeta.IsSigner, accounts[3].IsSigner) + require.Equal(t, derivedTablePdaLookupMeta.IsWritable, accounts[3].IsWritable) + }) + + t.Run("resolve addresses for multiple indices from derived lookup table", func(t *testing.T) { + args := map[string]interface{}{ + "seed2": seed2, + } + + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + IncludeIndexes: []int{0, 2}, + }, + } + + // Fetch derived table map + derivedTableMap, _, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + require.Len(t, accounts, 2) + require.Equal(t, storedPubKeys[0], accounts[0].PublicKey) + require.Equal(t, storedPubKeys[2], accounts[1].PublicKey) + }) + + t.Run("resolve all addresses from derived lookup table if indices not specified", func(t *testing.T) { + args := map[string]interface{}{ + "seed2": seed2, + } + + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + }, + } + + // Fetch derived table map + derivedTableMap, _, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + require.Len(t, accounts, 3) + for i, storedPubkey := range storedPubKeys { + require.Equal(t, storedPubkey, accounts[i].PublicKey) + } + }) +} + +func TestChainWriter_FilterLookupTableAddresses(t *testing.T) { + ctx := tests.Context(t) + + // mock client + rw := clientmocks.NewReaderWriter(t) + // mock estimator + ge := feemocks.NewEstimator(t) + // mock txm + txm := txmMocks.NewTxManager(t) + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + programID := chainwriter.GetRandomPubKey(t) + seed1 := []byte("seed1") + pda1 := mustFindPdaProgramAddress(t, [][]byte{seed1}, programID) + // mock data account response from program + lookupTablePubkey := mockDataAccountLookupTable(t, rw, pda1) + // mock fetch lookup table addresses call + storedPubKey := chainwriter.GetRandomPubKey(t) + mockFetchLookupTableAddresses(t, rw, lookupTablePubkey, []solana.PublicKey{storedPubKey}) + + unusedProgramID := chainwriter.GetRandomPubKey(t) + seed2 := []byte("seed2") + unusedPda := mustFindPdaProgramAddress(t, [][]byte{seed2}, unusedProgramID) + // mock data account response from program + unusedLookupTable := mockDataAccountLookupTable(t, rw, unusedPda) + // mock fetch lookup table addresses call + unusedKeys := chainwriter.GetRandomPubKey(t) + mockFetchLookupTableAddresses(t, rw, unusedLookupTable, []solana.PublicKey{unusedKeys}) + + // mock static lookup table calls + staticLookupTablePubkey1 := chainwriter.GetRandomPubKey(t) + mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey1, chainwriter.CreateTestPubKeys(t, 2)) + staticLookupTablePubkey2 := chainwriter.GetRandomPubKey(t) + mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey2, chainwriter.CreateTestPubKeys(t, 2)) + + lookupTableConfig := chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed2 for PDA lookup + chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}, + }, + IsSigner: true, + IsWritable: true, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + { + Name: "MiscDerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "MiscPDA", + PublicKey: chainwriter.AccountConstant{Name: "UnusedAccount", Address: unusedProgramID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed2 for PDA lookup + chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}, + }, + IsSigner: true, + IsWritable: true, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + }, + StaticLookupTables: []string{staticLookupTablePubkey1.String(), staticLookupTablePubkey2.String()}, + } + + args := map[string]interface{}{ + "seed1": seed1, + "seed2": seed2, + } + + t.Run("returns filtered map with only relevant addresses required by account lookup config", func(t *testing.T) { + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + IncludeIndexes: []int{0}, + }, + } + + // Fetch derived table map + derivedTableMap, staticTableMap, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + // Filter the lookup table addresses based on which accounts are actually used + filteredLookupTableMap := cw.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) + + // Filter map should only contain the address for the DerivedTable lookup defined in the account lookup config + require.Len(t, filteredLookupTableMap, len(accounts)) + entry, exists := filteredLookupTableMap[lookupTablePubkey] + require.True(t, exists) + require.Len(t, entry, 1) + require.Equal(t, storedPubKey, entry[0]) + }) + + t.Run("returns empty map if empty account lookup config provided", func(t *testing.T) { + accountLookupConfig := []chainwriter.Lookup{} + + // Fetch derived table map + derivedTableMap, staticTableMap, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + // Filter the lookup table addresses based on which accounts are actually used + filteredLookupTableMap := cw.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) + require.Empty(t, filteredLookupTableMap) + }) + + t.Run("returns empty map if only constant account lookup required", func(t *testing.T) { + accountLookupConfig := []chainwriter.Lookup{ + chainwriter.AccountConstant{ + Name: "Constant", + Address: chainwriter.GetRandomPubKey(t).String(), + IsSigner: false, + IsWritable: false, + }, + } + + // Fetch derived table map + derivedTableMap, staticTableMap, err := cw.ResolveLookupTables(ctx, args, lookupTableConfig) + require.NoError(t, err) + + // Resolve account metas + accounts, err := chainwriter.GetAddresses(ctx, args, accountLookupConfig, derivedTableMap, rw) + require.NoError(t, err) + + // Filter the lookup table addresses based on which accounts are actually used + filteredLookupTableMap := cw.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) + require.Empty(t, filteredLookupTableMap) + }) +} + +func TestChainWriter_SubmitTransaction(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + // mock client + rw := clientmocks.NewReaderWriter(t) + // mock estimator + ge := feemocks.NewEstimator(t) + // mock txm + txm := txmMocks.NewTxManager(t) + + // setup admin key + adminPk, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + admin := adminPk.PublicKey() + + account1 := chainwriter.GetRandomPubKey(t) + account2 := chainwriter.GetRandomPubKey(t) + + seed1 := []byte("seed1") + account3 := mustFindPdaProgramAddress(t, [][]byte{seed1}, solana.SystemProgramID) + + // create lookup table addresses + seed2 := []byte("seed2") + programID := chainwriter.GetRandomPubKey(t) + derivedTablePda := mustFindPdaProgramAddress(t, [][]byte{seed2}, programID) + // mock data account response from program + derivedLookupTablePubkey := mockDataAccountLookupTable(t, rw, derivedTablePda) + // mock fetch lookup table addresses call + derivedLookupKeys := chainwriter.CreateTestPubKeys(t, 1) + mockFetchLookupTableAddresses(t, rw, derivedLookupTablePubkey, derivedLookupKeys) + + // mock static lookup table call + staticLookupTablePubkey := chainwriter.GetRandomPubKey(t) + staticLookupKeys := chainwriter.CreateTestPubKeys(t, 2) + mockFetchLookupTableAddresses(t, rw, staticLookupTablePubkey, staticLookupKeys) + + cwConfig := chainwriter.ChainWriterConfig{ + Programs: map[string]chainwriter.ProgramConfig{ + "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU": { + Methods: map[string]chainwriter.MethodConfig{ + "initialize": { + FromAddress: admin.String(), + ChainSpecificName: "initialize", + LookupTables: chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed2 for PDA lookup + chainwriter.AccountLookup{Name: "seed2", Location: "seed2"}, + }, + IsSigner: false, + IsWritable: false, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + }, + StaticLookupTables: []string{staticLookupTablePubkey.String()}, + }, + Accounts: []chainwriter.Lookup{ + chainwriter.AccountConstant{ + Name: "Constant", + Address: account1.String(), + IsSigner: false, + IsWritable: false, + }, + chainwriter.AccountLookup{ + Name: "LookupTable", + Location: "lookup_table", + IsSigner: false, + IsWritable: false, + }, + chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: solana.SystemProgramID.String()}, + Seeds: []chainwriter.Lookup{ + // extract seed1 for PDA lookup + chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}, + }, + IsSigner: false, + IsWritable: false, + // Just get the address of the account, nothing internal. + InternalField: chainwriter.InternalField{}, + }, + chainwriter.AccountsFromLookupTable{ + LookupTableName: "DerivedTable", + IncludeIndexes: []int{0}, + }, + }, + }, + }, + IDL: writeTestIdlJSON, + }, + }, + } + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, cwConfig) + require.NoError(t, err) + + t.Run("fails with invalid ABI", func(t *testing.T) { + invalidCWConfig := chainwriter.ChainWriterConfig{ + Programs: map[string]chainwriter.ProgramConfig{ + "write_test": { + Methods: map[string]chainwriter.MethodConfig{ + "invalid": { + ChainSpecificName: "invalid", + }, + }, + IDL: "", + }, + }, + } + + _, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, invalidCWConfig) + require.Error(t, err) + }) + + t.Run("fails to encode payload if args with missing values provided", func(t *testing.T) { + txID := uuid.NewString() + args := map[string]interface{}{} + submitErr := cw.SubmitTransaction(ctx, "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU", "initialize", args, txID, programID.String(), nil, nil) + require.Error(t, submitErr) + }) + + t.Run("fails if invalid contract name provided", func(t *testing.T) { + txID := uuid.NewString() + args := map[string]interface{}{} + submitErr := cw.SubmitTransaction(ctx, "write_test", "initialize", args, txID, programID.String(), nil, nil) + require.Error(t, submitErr) + }) + + t.Run("fails if invalid method provided", func(t *testing.T) { + txID := uuid.NewString() + args := map[string]interface{}{} + submitErr := cw.SubmitTransaction(ctx, "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU", "badMethod", args, txID, programID.String(), nil, nil) + require.Error(t, submitErr) + }) + + t.Run("submits transaction successfully", func(t *testing.T) { + recentBlockHash := solana.Hash{} + rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once() + txID := uuid.NewString() + configProgramID := solana.MustPublicKeyFromBase58("39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU") + + txm.On("Enqueue", mock.Anything, account1.String(), mock.MatchedBy(func(tx *solana.Transaction) bool { + // match transaction fields to ensure it was built as expected + require.Equal(t, recentBlockHash, tx.Message.RecentBlockhash) + require.Len(t, tx.Message.Instructions, 1) + require.Len(t, tx.Message.AccountKeys, 5) // fee payer + derived accounts + require.Equal(t, admin, tx.Message.AccountKeys[0]) // fee payer + require.Equal(t, account1, tx.Message.AccountKeys[1]) // account constant + require.Equal(t, account2, tx.Message.AccountKeys[2]) // account lookup + require.Equal(t, account3, tx.Message.AccountKeys[3]) // pda lookup + require.Equal(t, configProgramID, tx.Message.AccountKeys[4]) // instruction program ID + require.Len(t, tx.Message.AddressTableLookups, 1) // address table look contains entry + require.Equal(t, derivedLookupTablePubkey, tx.Message.AddressTableLookups[0].AccountKey) // address table + return true + }), &txID).Return(nil).Once() + + args := map[string]interface{}{ + "lookupTable": chainwriter.GetRandomPubKey(t).Bytes(), + "lookup_table": account2.Bytes(), + "seed1": seed1, + "seed2": seed2, + } + submitErr := cw.SubmitTransaction(ctx, "39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU", "initialize", args, txID, programID.String(), nil, nil) + require.NoError(t, submitErr) + }) +} + +func TestChainWriter_GetTransactionStatus(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + rw := clientmocks.NewReaderWriter(t) + ge := feemocks.NewEstimator(t) + + // mock txm + txm := txmMocks.NewTxManager(t) + + // initialize chain writer + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + t.Run("returns unknown with error if ID not found", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Unknown, errors.New("tx not found")).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.Error(t, err) + require.Equal(t, types.Unknown, status) + }) + + t.Run("returns pending when transaction is pending", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Pending, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Pending, status) + }) + + t.Run("returns unconfirmed when transaction is unconfirmed", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Unconfirmed, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Unconfirmed, status) + }) + + t.Run("returns finalized when transaction is finalized", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Finalized, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + }) + + t.Run("returns failed when transaction error classfied as failed", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Failed, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + }) + + t.Run("returns fatal when transaction error classfied as fatal", func(t *testing.T) { + txID := uuid.NewString() + txm.On("GetTransactionStatus", mock.Anything, txID).Return(types.Fatal, nil).Once() + status, err := cw.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Fatal, status) + }) +} + +func TestChainWriter_GetFeeComponents(t *testing.T) { + t.Parallel() + + ctx := tests.Context(t) + rw := clientmocks.NewReaderWriter(t) + ge := feemocks.NewEstimator(t) + ge.On("BaseComputeUnitPrice").Return(uint64(100)) + + // mock txm + txm := txmMocks.NewTxManager(t) + + cw, err := chainwriter.NewSolanaChainWriterService(rw, txm, ge, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + + t.Run("returns valid compute unit price", func(t *testing.T) { + feeComponents, err := cw.GetFeeComponents(ctx) + require.NoError(t, err) + require.Equal(t, big.NewInt(100), feeComponents.ExecutionFee) + require.Nil(t, feeComponents.DataAvailabilityFee) // always nil for Solana + }) + + t.Run("fails if gas estimator not set", func(t *testing.T) { + cwNoEstimator, err := chainwriter.NewSolanaChainWriterService(rw, txm, nil, chainwriter.ChainWriterConfig{}) + require.NoError(t, err) + _, err = cwNoEstimator.GetFeeComponents(ctx) + require.Error(t, err) + }) +} + +func mustBorshEncodeStruct(t *testing.T, data interface{}) []byte { + buf := new(bytes.Buffer) + err := ag_binary.NewBorshEncoder(buf).Encode(data) + require.NoError(t, err) + return buf.Bytes() +} + +func mustFindPdaProgramAddress(t *testing.T, seeds [][]byte, programID solana.PublicKey) solana.PublicKey { + pda, _, err := solana.FindProgramAddress(seeds, programID) + require.NoError(t, err) + return pda +} + +func mockDataAccountLookupTable(t *testing.T, rw *clientmocks.ReaderWriter, pda solana.PublicKey) solana.PublicKey { + lookupTablePubkey := chainwriter.GetRandomPubKey(t) + dataAccount := DataAccount{ + Discriminator: [8]byte{}, + Version: 1, + Administrator: chainwriter.GetRandomPubKey(t), + PendingAdministrator: chainwriter.GetRandomPubKey(t), + LookupTable: lookupTablePubkey, + } + dataAccountBytes := mustBorshEncodeStruct(t, dataAccount) + rw.On("GetAccountInfoWithOpts", mock.Anything, pda, mock.Anything).Return(&rpc.GetAccountInfoResult{ + RPCContext: rpc.RPCContext{}, + Value: &rpc.Account{Data: rpc.DataBytesOrJSONFromBytes(dataAccountBytes)}, + }, nil) + return lookupTablePubkey +} + +func mockFetchLookupTableAddresses(t *testing.T, rw *clientmocks.ReaderWriter, lookupTablePubkey solana.PublicKey, storedPubkeys []solana.PublicKey) { + var lookupTablePubkeySlice solana.PublicKeySlice + lookupTablePubkeySlice.Append(storedPubkeys...) + lookupTableState := addresslookuptable.AddressLookupTableState{ + Addresses: lookupTablePubkeySlice, + } + lookupTableStateBytes := mustBorshEncodeStruct(t, lookupTableState) + rw.On("GetAccountInfoWithOpts", mock.Anything, lookupTablePubkey, mock.Anything).Return(&rpc.GetAccountInfoResult{ + RPCContext: rpc.RPCContext{}, + Value: &rpc.Account{Data: rpc.DataBytesOrJSONFromBytes(lookupTableStateBytes)}, + }, nil) +} diff --git a/pkg/solana/chainwriter/helpers.go b/pkg/solana/chainwriter/helpers.go new file mode 100644 index 000000000..bc256c60a --- /dev/null +++ b/pkg/solana/chainwriter/helpers.go @@ -0,0 +1,202 @@ +package chainwriter + +import ( + "context" + "crypto/sha256" + "errors" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils" +) + +// GetValuesAtLocation parses through nested types and arrays to find all locations of values +func GetValuesAtLocation(args any, location string) ([][]byte, error) { + var vals [][]byte + path := strings.Split(location, ".") + + addressList, err := traversePath(args, path) + if err != nil { + return nil, err + } + + for _, value := range addressList { + if byteArray, ok := value.([]byte); ok { + vals = append(vals, byteArray) + } else if address, ok := value.(solana.PublicKey); ok { + vals = append(vals, address.Bytes()) + } else { + return nil, fmt.Errorf("invalid value format at path: %s", location) + } + } + + return vals, nil +} + +func GetDebugIDAtLocation(args any, location string) (string, error) { + debugIDList, err := GetValueAtLocation(args, location) + if err != nil { + return "", err + } + + // there should only be one debug ID, others will be ignored. + debugID := string(debugIDList[0]) + + return debugID, nil +} + +func GetValueAtLocation(args any, location string) ([][]byte, error) { + path := strings.Split(location, ".") + + valueList, err := traversePath(args, path) + if err != nil { + return nil, err + } + + var values [][]byte + for _, value := range valueList { + byteArray, ok := value.([]byte) + if !ok { + return nil, fmt.Errorf("invalid value format at path: %s", location) + } + values = append(values, byteArray) + } + + return values, nil +} + +func errorWithDebugID(err error, debugID string) error { + if debugID == "" { + return err + } + return fmt.Errorf("Debug ID: %s: Error: %s", debugID, err) +} + +// traversePath recursively traverses the given structure based on the provided path. +func traversePath(data any, path []string) ([]any, error) { + if len(path) == 0 { + return []any{data}, nil + } + + var result []any + + val := reflect.ValueOf(data) + + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + + switch val.Kind() { + case reflect.Struct: + field := val.FieldByName(path[0]) + if !field.IsValid() { + return nil, errors.New("field not found: " + path[0]) + } + return traversePath(field.Interface(), path[1:]) + + case reflect.Slice, reflect.Array: + for i := 0; i < val.Len(); i++ { + element := val.Index(i).Interface() + elements, err := traversePath(element, path) + if err == nil { + result = append(result, elements...) + } + } + if len(result) > 0 { + return result, nil + } + return nil, errors.New("no matching field found in array") + + case reflect.Map: + key := reflect.ValueOf(path[0]) + value := val.MapIndex(key) + if !value.IsValid() { + return nil, errors.New("key not found: " + path[0]) + } + return traversePath(value.Interface(), path[1:]) + default: + if len(path) == 1 && val.Kind() == reflect.Slice && val.Type().Elem().Kind() == reflect.Uint8 { + return []any{val.Interface()}, nil + } + return nil, errors.New("unexpected type encountered at path: " + path[0]) + } +} + +func InitializeDataAccount( + ctx context.Context, + t *testing.T, + client *rpc.Client, + programID solana.PublicKey, + admin solana.PrivateKey, + lookupTable solana.PublicKey, +) { + pda, _, err := solana.FindProgramAddress([][]byte{[]byte("data")}, programID) + require.NoError(t, err) + + discriminator := GetDiscriminator("initialize") + + instructionData := append(discriminator[:], lookupTable.Bytes()...) + + instruction := solana.NewInstruction( + programID, + solana.AccountMetaSlice{ + solana.Meta(pda).WRITE(), + solana.Meta(admin.PublicKey()).SIGNER().WRITE(), + solana.Meta(solana.SystemProgramID), + }, + instructionData, + ) + + // Send and confirm the transaction + utils.SendAndConfirm(ctx, t, client, []solana.Instruction{instruction}, admin, rpc.CommitmentFinalized) +} + +func GetDiscriminator(instruction string) [8]byte { + fullHash := sha256.Sum256([]byte("global:" + instruction)) + var discriminator [8]byte + copy(discriminator[:], fullHash[:8]) + return discriminator +} + +func GetRandomPubKey(t *testing.T) solana.PublicKey { + privKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + return privKey.PublicKey() +} + +func CreateTestPubKeys(t *testing.T, num int) solana.PublicKeySlice { + addresses := make([]solana.PublicKey, num) + for i := 0; i < num; i++ { + addresses[i] = GetRandomPubKey(t) + } + return addresses +} + +func CreateTestLookupTable(ctx context.Context, t *testing.T, c *rpc.Client, sender solana.PrivateKey, addresses []solana.PublicKey) solana.PublicKey { + // Create lookup tables + slot, serr := c.GetSlot(ctx, rpc.CommitmentFinalized) + require.NoError(t, serr) + table, instruction, ierr := utils.NewCreateLookupTableInstruction( + sender.PublicKey(), + sender.PublicKey(), + slot, + ) + require.NoError(t, ierr) + utils.SendAndConfirm(ctx, t, c, []solana.Instruction{instruction}, sender, rpc.CommitmentConfirmed) + + // add entries to lookup table + utils.SendAndConfirm(ctx, t, c, []solana.Instruction{ + utils.NewExtendLookupTableInstruction( + table, sender.PublicKey(), sender.PublicKey(), + addresses, + ), + }, sender, rpc.CommitmentConfirmed) + + return table +} diff --git a/pkg/solana/chainwriter/lookups.go b/pkg/solana/chainwriter/lookups.go new file mode 100644 index 000000000..1947b060d --- /dev/null +++ b/pkg/solana/chainwriter/lookups.go @@ -0,0 +1,357 @@ +package chainwriter + +import ( + "context" + "fmt" + "reflect" + + ag_binary "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + addresslookuptable "github.com/gagliardetto/solana-go/programs/address-lookup-table" + "github.com/gagliardetto/solana-go/rpc" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" +) + +type Lookup interface { + Resolve(ctx context.Context, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([]*solana.AccountMeta, error) +} + +// AccountConstant represents a fixed address, provided in Base58 format, converted into a `solana.PublicKey`. +type AccountConstant struct { + Name string + Address string + IsSigner bool + IsWritable bool +} + +// AccountLookup dynamically derives an account address from args using a specified location path. +type AccountLookup struct { + Name string + Location string + IsSigner bool + IsWritable bool +} + +// PDALookups generates Program Derived Addresses (PDA) by combining a derived public key with one or more seeds. +type PDALookups struct { + Name string + // The public key of the PDA to be combined with seeds. If there are multiple PublicKeys + // there will be multiple PDAs generated by combining each PublicKey with the seeds. + PublicKey Lookup + // Seeds to be derived from an additional lookup + Seeds []Lookup + IsSigner bool + IsWritable bool + // OPTIONAL: On-chain location and type of desired data from PDA (e.g. a sub-account of the data account) + InternalField InternalField +} + +type InternalField struct { + Type reflect.Type + Location string +} + +type ValueLookup struct { + Location string +} + +// LookupTables represents a list of lookup tables that are used to derive addresses for a program. +type LookupTables struct { + DerivedLookupTables []DerivedLookupTable + StaticLookupTables []string +} + +// DerivedLookupTable represents a lookup table that is used to derive addresses for a program. +type DerivedLookupTable struct { + Name string + Accounts Lookup +} + +// AccountsFromLookupTable extracts accounts from a lookup table that was previously read and stored in memory. +type AccountsFromLookupTable struct { + LookupTableName string + IncludeIndexes []int +} + +func (ac AccountConstant) Resolve(_ context.Context, _ any, _ map[string]map[string][]*solana.AccountMeta, _ client.Reader) ([]*solana.AccountMeta, error) { + address, err := solana.PublicKeyFromBase58(ac.Address) + if err != nil { + return nil, fmt.Errorf("error getting account from constant: %w", err) + } + return []*solana.AccountMeta{ + { + PublicKey: address, + IsSigner: ac.IsSigner, + IsWritable: ac.IsWritable, + }, + }, nil +} + +func (al AccountLookup) Resolve(_ context.Context, args any, _ map[string]map[string][]*solana.AccountMeta, _ client.Reader) ([]*solana.AccountMeta, error) { + derivedValues, err := GetValuesAtLocation(args, al.Location) + if err != nil { + return nil, fmt.Errorf("error getting account from lookup: %w", err) + } + + var metas []*solana.AccountMeta + for _, address := range derivedValues { + metas = append(metas, &solana.AccountMeta{ + PublicKey: solana.PublicKeyFromBytes(address), + IsSigner: al.IsSigner, + IsWritable: al.IsWritable, + }) + } + return metas, nil +} + +func (alt AccountsFromLookupTable) Resolve(_ context.Context, _ any, derivedTableMap map[string]map[string][]*solana.AccountMeta, _ client.Reader) ([]*solana.AccountMeta, error) { + // Fetch the inner map for the specified lookup table name + innerMap, ok := derivedTableMap[alt.LookupTableName] + if !ok { + return nil, fmt.Errorf("lookup table not found: %s", alt.LookupTableName) + } + + var result []*solana.AccountMeta + + // If no indices are specified, include all addresses + if len(alt.IncludeIndexes) == 0 { + for _, metas := range innerMap { + result = append(result, metas...) + } + return result, nil + } + + // Otherwise, include only addresses at the specified indices + for publicKey, metas := range innerMap { + for _, index := range alt.IncludeIndexes { + if index < 0 || index >= len(metas) { + return nil, fmt.Errorf("invalid index %d for account %s in lookup table %s", index, publicKey, alt.LookupTableName) + } + result = append(result, metas[index]) + } + } + + return result, nil +} + +func (pda PDALookups) Resolve(ctx context.Context, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([]*solana.AccountMeta, error) { + publicKeys, err := GetAddresses(ctx, args, []Lookup{pda.PublicKey}, derivedTableMap, reader) + if err != nil { + return nil, fmt.Errorf("error getting public key for PDALookups: %w", err) + } + + seeds, err := getSeedBytes(ctx, pda, args, derivedTableMap, reader) + if err != nil { + return nil, fmt.Errorf("error getting seeds for PDALookups: %w", err) + } + + pdas, err := generatePDAs(publicKeys, seeds, pda) + if err != nil { + return nil, fmt.Errorf("error generating PDAs: %w", err) + } + + if pda.InternalField.Location == "" { + return pdas, nil + } + + // If a decoded location is specified, fetch the data at that location + var result []*solana.AccountMeta + for _, accountMeta := range pdas { + accountInfo, err := reader.GetAccountInfoWithOpts(ctx, accountMeta.PublicKey, &rpc.GetAccountInfoOpts{ + Encoding: "base64", + Commitment: rpc.CommitmentFinalized, + }) + fmt.Printf("Accounts Info: %+v", accountInfo) + + if err != nil || accountInfo == nil || accountInfo.Value == nil { + return nil, fmt.Errorf("error fetching account info for PDA account: %s, error: %w", accountMeta.PublicKey.String(), err) + } + + decoded, err := decodeBorshIntoType(accountInfo.GetBinary(), pda.InternalField.Type) + if err != nil { + return nil, fmt.Errorf("error decoding Borsh data dynamically: %w", err) + } + + value, err := GetValuesAtLocation(decoded, pda.InternalField.Location) + if err != nil { + return nil, fmt.Errorf("error getting value at location: %w", err) + } + if len(value) > 1 { + return nil, fmt.Errorf("multiple values found at location: %s", pda.InternalField.Location) + } + + result = append(result, &solana.AccountMeta{ + PublicKey: solana.PublicKeyFromBytes(value[0]), + IsSigner: accountMeta.IsSigner, + IsWritable: accountMeta.IsWritable, + }) + } + return result, nil +} + +func decodeBorshIntoType(data []byte, typ reflect.Type) (interface{}, error) { + // Ensure the type is a struct + if typ.Kind() != reflect.Struct { + return nil, fmt.Errorf("provided type is not a struct: %s", typ.Kind()) + } + + // Create a new instance of the type + instance := reflect.New(typ).Interface() + + // Decode using Borsh + err := ag_binary.NewBorshDecoder(data).Decode(instance) + if err != nil { + return nil, fmt.Errorf("error decoding Borsh data: %w", err) + } + + // Return the underlying value (not a pointer) + return reflect.ValueOf(instance).Elem().Interface(), nil +} + +// getSeedBytes extracts the seeds for the PDALookups. +// It handles both AddressSeeds (which are public keys) and ValueSeeds (which are byte arrays from input args). +func getSeedBytes(ctx context.Context, lookup PDALookups, args any, derivedTableMap map[string]map[string][]*solana.AccountMeta, reader client.Reader) ([][]byte, error) { + var seedBytes [][]byte + + for _, seed := range lookup.Seeds { + if lookupSeed, ok := seed.(AccountLookup); ok { + // Get value from a location (This doens't have to be an address, it can be any value) + bytes, err := GetValuesAtLocation(args, lookupSeed.Location) + if err != nil { + return nil, fmt.Errorf("error getting address seed: %w", err) + } + seedBytes = append(seedBytes, bytes...) + } else { + // Get address seeds from the lookup + seedAddresses, err := GetAddresses(ctx, args, []Lookup{seed}, derivedTableMap, reader) + if err != nil { + return nil, fmt.Errorf("error getting address seed: %w", err) + } + + // Add each address seed as bytes + for _, address := range seedAddresses { + seedBytes = append(seedBytes, address.PublicKey.Bytes()) + } + } + } + + return seedBytes, nil +} + +// generatePDAs generates program-derived addresses (PDAs) from public keys and seeds. +func generatePDAs(publicKeys []*solana.AccountMeta, seeds [][]byte, lookup PDALookups) ([]*solana.AccountMeta, error) { + if len(seeds) > 1 && len(publicKeys) > 1 { + return nil, fmt.Errorf("multiple public keys and multiple seeds are not allowed") + } + + var addresses []*solana.AccountMeta + for _, publicKeyMeta := range publicKeys { + address, _, err := solana.FindProgramAddress(seeds, publicKeyMeta.PublicKey) + if err != nil { + return nil, fmt.Errorf("error finding program address: %w", err) + } + addresses = append(addresses, &solana.AccountMeta{ + PublicKey: address, + IsSigner: lookup.IsSigner, + IsWritable: lookup.IsWritable, + }) + } + return addresses, nil +} + +func (s *SolanaChainWriterService) ResolveLookupTables(ctx context.Context, args any, lookupTables LookupTables) (map[string]map[string][]*solana.AccountMeta, map[solana.PublicKey]solana.PublicKeySlice, error) { + derivedTableMap := make(map[string]map[string][]*solana.AccountMeta) + staticTableMap := make(map[solana.PublicKey]solana.PublicKeySlice) + + // Read derived lookup tables + for _, derivedLookup := range lookupTables.DerivedLookupTables { + lookupTableMap, _, err := s.LoadTable(ctx, args, derivedLookup, s.reader, derivedTableMap) + if err != nil { + return nil, nil, fmt.Errorf("error loading derived lookup table: %w", err) + } + + // Merge the loaded table map into the result + for tableName, innerMap := range lookupTableMap { + if derivedTableMap[tableName] == nil { + derivedTableMap[tableName] = make(map[string][]*solana.AccountMeta) + } + for accountKey, metas := range innerMap { + derivedTableMap[tableName][accountKey] = metas + } + } + } + + // Read static lookup tables + for _, staticTable := range lookupTables.StaticLookupTables { + // Parse the static table address + tableAddress, err := solana.PublicKeyFromBase58(staticTable) + if err != nil { + return nil, nil, fmt.Errorf("invalid static lookup table address: %s, error: %w", staticTable, err) + } + + addressses, err := getLookupTableAddress(ctx, s.reader, tableAddress) + if err != nil { + return nil, nil, fmt.Errorf("error fetching static lookup table address: %w", err) + } + staticTableMap[tableAddress] = addressses + } + + return derivedTableMap, staticTableMap, nil +} + +func (s *SolanaChainWriterService) LoadTable(ctx context.Context, args any, rlt DerivedLookupTable, reader client.Reader, derivedTableMap map[string]map[string][]*solana.AccountMeta) (map[string]map[string][]*solana.AccountMeta, []*solana.AccountMeta, error) { + // Resolve all addresses specified by the identifier + lookupTableAddresses, err := GetAddresses(ctx, args, []Lookup{rlt.Accounts}, nil, reader) + if err != nil { + return nil, nil, fmt.Errorf("error resolving addresses for lookup table: %w", err) + } + + resultMap := make(map[string]map[string][]*solana.AccountMeta) + var lookupTableMetas []*solana.AccountMeta + + // Iterate over each address of the lookup table + for _, addressMeta := range lookupTableAddresses { + // Fetch account info + addresses, err := getLookupTableAddress(ctx, reader, addressMeta.PublicKey) + if err != nil { + return nil, nil, fmt.Errorf("error fetching lookup table address: %w", err) + } + + // Create the inner map for this lookup table + if resultMap[rlt.Name] == nil { + resultMap[rlt.Name] = make(map[string][]*solana.AccountMeta) + } + + // Populate the inner map (keyed by the account public key) + for _, addr := range addresses { + resultMap[rlt.Name][addressMeta.PublicKey.String()] = append(resultMap[rlt.Name][addressMeta.PublicKey.String()], &solana.AccountMeta{ + PublicKey: addr, + IsSigner: addressMeta.IsSigner, + IsWritable: addressMeta.IsWritable, + }) + } + + // Add the current lookup table address to the list of metas + lookupTableMetas = append(lookupTableMetas, addressMeta) + } + + return resultMap, lookupTableMetas, nil +} + +func getLookupTableAddress(ctx context.Context, reader client.Reader, tableAddress solana.PublicKey) (solana.PublicKeySlice, error) { + // Fetch the account info for the static table + accountInfo, err := reader.GetAccountInfoWithOpts(ctx, tableAddress, &rpc.GetAccountInfoOpts{ + Encoding: "base64", + Commitment: rpc.CommitmentFinalized, + }) + + if err != nil || accountInfo == nil || accountInfo.Value == nil { + return nil, fmt.Errorf("error fetching account info for table: %s, error: %w", tableAddress.String(), err) + } + alt, err := addresslookuptable.DecodeAddressLookupTableState(accountInfo.GetBinary()) + if err != nil { + return nil, fmt.Errorf("error decoding address lookup table state: %w", err) + } + return alt.Addresses, nil +} diff --git a/pkg/solana/chainwriter/lookups_test.go b/pkg/solana/chainwriter/lookups_test.go new file mode 100644 index 000000000..53972feac --- /dev/null +++ b/pkg/solana/chainwriter/lookups_test.go @@ -0,0 +1,447 @@ +package chainwriter_test + +import ( + "context" + "reflect" + "testing" + "time" + + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + commonutils "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/chainwriter" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils" +) + +type TestArgs struct { + Inner []InnerArgs +} + +type InnerArgs struct { + Address []byte +} + +type DataAccount struct { + Discriminator [8]byte + Version uint8 + Administrator solana.PublicKey + PendingAdministrator solana.PublicKey + LookupTable solana.PublicKey +} + +func TestAccountContant(t *testing.T) { + t.Run("AccountConstant resolves valid address", func(t *testing.T) { + expectedAddr := chainwriter.GetRandomPubKey(t) + expectedMeta := []*solana.AccountMeta{ + { + PublicKey: expectedAddr, + IsSigner: true, + IsWritable: true, + }, + } + constantConfig := chainwriter.AccountConstant{ + Name: "TestAccount", + Address: expectedAddr.String(), + IsSigner: true, + IsWritable: true, + } + result, err := constantConfig.Resolve(tests.Context(t), nil, nil, nil) + require.NoError(t, err) + require.Equal(t, expectedMeta, result) + }) +} +func TestAccountLookups(t *testing.T) { + ctx := tests.Context(t) + t.Run("AccountLookup resolves valid address with just one address", func(t *testing.T) { + expectedAddr := chainwriter.GetRandomPubKey(t) + testArgs := TestArgs{ + Inner: []InnerArgs{ + {Address: expectedAddr.Bytes()}, + }, + } + expectedMeta := []*solana.AccountMeta{ + { + PublicKey: expectedAddr, + IsSigner: true, + IsWritable: true, + }, + } + + lookupConfig := chainwriter.AccountLookup{ + Name: "TestAccount", + Location: "Inner.Address", + IsSigner: true, + IsWritable: true, + } + result, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) + require.NoError(t, err) + require.Equal(t, expectedMeta, result) + }) + + t.Run("AccountLookup resolves valid address with just multiple addresses", func(t *testing.T) { + expectedAddr1 := chainwriter.GetRandomPubKey(t) + expectedAddr2 := chainwriter.GetRandomPubKey(t) + + testArgs := TestArgs{ + Inner: []InnerArgs{ + {Address: expectedAddr1.Bytes()}, + {Address: expectedAddr2.Bytes()}, + }, + } + expectedMeta := []*solana.AccountMeta{ + { + PublicKey: expectedAddr1, + IsSigner: true, + IsWritable: true, + }, + { + PublicKey: expectedAddr2, + IsSigner: true, + IsWritable: true, + }, + } + + lookupConfig := chainwriter.AccountLookup{ + Name: "TestAccount", + Location: "Inner.Address", + IsSigner: true, + IsWritable: true, + } + result, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) + require.NoError(t, err) + for i, meta := range result { + require.Equal(t, expectedMeta[i], meta) + } + }) + + t.Run("AccountLookup fails when address isn't in args", func(t *testing.T) { + expectedAddr := chainwriter.GetRandomPubKey(t) + + testArgs := TestArgs{ + Inner: []InnerArgs{ + {Address: expectedAddr.Bytes()}, + }, + } + lookupConfig := chainwriter.AccountLookup{ + Name: "InvalidAccount", + Location: "Invalid.Directory", + IsSigner: true, + IsWritable: true, + } + _, err := lookupConfig.Resolve(ctx, testArgs, nil, nil) + require.Error(t, err) + }) +} + +func TestPDALookups(t *testing.T) { + programID := solana.SystemProgramID + + t.Run("PDALookup resolves valid PDA with constant address seeds", func(t *testing.T) { + seed := chainwriter.GetRandomPubKey(t) + + pda, _, err := solana.FindProgramAddress([][]byte{seed.Bytes()}, programID) + require.NoError(t, err) + + expectedMeta := []*solana.AccountMeta{ + { + PublicKey: pda, + IsSigner: false, + IsWritable: true, + }, + } + + pdaLookup := chainwriter.PDALookups{ + Name: "TestPDA", + PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + chainwriter.AccountConstant{Name: "seed", Address: seed.String()}, + }, + IsSigner: false, + IsWritable: true, + } + + ctx := context.Background() + result, err := pdaLookup.Resolve(ctx, nil, nil, nil) + require.NoError(t, err) + require.Equal(t, expectedMeta, result) + }) + t.Run("PDALookup resolves valid PDA with non-address lookup seeds", func(t *testing.T) { + seed1 := []byte("test_seed") + seed2 := []byte("another_seed") + + pda, _, err := solana.FindProgramAddress([][]byte{seed1, seed2}, programID) + require.NoError(t, err) + + expectedMeta := []*solana.AccountMeta{ + { + PublicKey: pda, + IsSigner: false, + IsWritable: true, + }, + } + + pdaLookup := chainwriter.PDALookups{ + Name: "TestPDA", + PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + chainwriter.AccountLookup{Name: "seed1", Location: "test_seed"}, + chainwriter.AccountLookup{Name: "seed2", Location: "another_seed"}, + }, + IsSigner: false, + IsWritable: true, + } + + ctx := context.Background() + args := map[string]interface{}{ + "test_seed": seed1, + "another_seed": seed2, + } + + result, err := pdaLookup.Resolve(ctx, args, nil, nil) + require.NoError(t, err) + require.Equal(t, expectedMeta, result) + }) + + t.Run("PDALookup fails with missing seeds", func(t *testing.T) { + pdaLookup := chainwriter.PDALookups{ + Name: "TestPDA", + PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + chainwriter.AccountLookup{Name: "seed1", Location: "MissingSeed"}, + }, + IsSigner: false, + IsWritable: true, + } + + ctx := context.Background() + args := map[string]interface{}{ + "test_seed": []byte("data"), + } + + _, err := pdaLookup.Resolve(ctx, args, nil, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "key not found") + }) + + t.Run("PDALookup resolves valid PDA with address lookup seeds", func(t *testing.T) { + seed1 := chainwriter.GetRandomPubKey(t) + seed2 := chainwriter.GetRandomPubKey(t) + + pda, _, err := solana.FindProgramAddress([][]byte{seed1.Bytes(), seed2.Bytes()}, programID) + require.NoError(t, err) + + expectedMeta := []*solana.AccountMeta{ + { + PublicKey: pda, + IsSigner: false, + IsWritable: true, + }, + } + + pdaLookup := chainwriter.PDALookups{ + Name: "TestPDA", + PublicKey: chainwriter.AccountConstant{Name: "ProgramID", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + chainwriter.AccountLookup{Name: "seed1", Location: "test_seed"}, + chainwriter.AccountLookup{Name: "seed2", Location: "another_seed"}, + }, + IsSigner: false, + IsWritable: true, + } + + ctx := context.Background() + args := map[string]interface{}{ + "test_seed": seed1, + "another_seed": seed2, + } + + result, err := pdaLookup.Resolve(ctx, args, nil, nil) + require.NoError(t, err) + require.Equal(t, expectedMeta, result) + }) +} + +func TestLookupTables(t *testing.T) { + ctx := tests.Context(t) + + sender, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + + url := utils.SetupTestValidatorWithAnchorPrograms(t, utils.PathToAnchorConfig, sender.PublicKey().String()) + rpcClient := rpc.New(url) + + utils.FundAccounts(ctx, []solana.PrivateKey{sender}, rpcClient, t) + + cfg := config.NewDefault() + solanaClient, err := client.NewClient(url, cfg, 5*time.Second, nil) + require.NoError(t, err) + + loader := commonutils.NewLazyLoad(func() (client.ReaderWriter, error) { return solanaClient, nil }) + mkey := keyMocks.NewSimpleKeystore(t) + lggr := logger.Test(t) + + txm := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) + + cw, err := chainwriter.NewSolanaChainWriterService(solanaClient, txm, nil, chainwriter.ChainWriterConfig{}) + + t.Run("StaticLookup table resolves properly", func(t *testing.T) { + pubKeys := chainwriter.CreateTestPubKeys(t, 8) + table := chainwriter.CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) + lookupConfig := chainwriter.LookupTables{ + DerivedLookupTables: nil, + StaticLookupTables: []string{table.String()}, + } + _, staticTableMap, resolveErr := cw.ResolveLookupTables(ctx, nil, lookupConfig) + require.NoError(t, resolveErr) + require.Equal(t, pubKeys, staticTableMap[table]) + }) + t.Run("Derived lookup table resolves properly with constant address", func(t *testing.T) { + pubKeys := chainwriter.CreateTestPubKeys(t, 8) + table := chainwriter.CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) + lookupConfig := chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.AccountConstant{ + Name: "TestLookupTable", + Address: table.String(), + IsSigner: true, + IsWritable: true, + }, + }, + }, + StaticLookupTables: nil, + } + derivedTableMap, _, resolveErr := cw.ResolveLookupTables(ctx, nil, lookupConfig) + require.NoError(t, resolveErr) + + addresses, ok := derivedTableMap["DerivedTable"][table.String()] + require.True(t, ok) + for i, address := range addresses { + require.Equal(t, pubKeys[i], address.PublicKey) + } + }) + + t.Run("Derived lookup table fails with invalid address", func(t *testing.T) { + invalidTable := chainwriter.GetRandomPubKey(t) + + lookupConfig := chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.AccountConstant{ + Name: "InvalidTable", + Address: invalidTable.String(), + IsSigner: true, + IsWritable: true, + }, + }, + }, + StaticLookupTables: nil, + } + + _, _, err = cw.ResolveLookupTables(ctx, nil, lookupConfig) + require.Error(t, err) + require.Contains(t, err.Error(), "error fetching account info for table") // Example error message + }) + + t.Run("Static lookup table fails with invalid address", func(t *testing.T) { + invalidTable := chainwriter.GetRandomPubKey(t) + + lookupConfig := chainwriter.LookupTables{ + DerivedLookupTables: nil, + StaticLookupTables: []string{invalidTable.String()}, + } + + _, _, err = cw.ResolveLookupTables(ctx, nil, lookupConfig) + require.Error(t, err) + require.Contains(t, err.Error(), "error fetching account info for table") // Example error message + }) + + t.Run("Derived lookup table resolves properly with account lookup address", func(t *testing.T) { + pubKeys := chainwriter.CreateTestPubKeys(t, 8) + table := chainwriter.CreateTestLookupTable(ctx, t, rpcClient, sender, pubKeys) + lookupConfig := chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.AccountLookup{ + Name: "TestLookupTable", + Location: "Inner.Address", + IsSigner: true, + }, + }, + }, + StaticLookupTables: nil, + } + + testArgs := TestArgs{ + Inner: []InnerArgs{ + {Address: table.Bytes()}, + }, + } + + derivedTableMap, _, err := cw.ResolveLookupTables(ctx, testArgs, lookupConfig) + require.NoError(t, err) + + addresses, ok := derivedTableMap["DerivedTable"][table.String()] + require.True(t, ok) + for i, address := range addresses { + require.Equal(t, pubKeys[i], address.PublicKey) + } + }) + + t.Run("Derived lookup table resolves properly with PDALookup address", func(t *testing.T) { + // Deployed write_test contract + programID := solana.MustPublicKeyFromBase58("39vbQVpEMtZtg3e6ZSE7nBSzmNZptmW45WnLkbqEe4TU") + + lookupKeys := chainwriter.CreateTestPubKeys(t, 5) + lookupTable := chainwriter.CreateTestLookupTable(ctx, t, rpcClient, sender, lookupKeys) + + chainwriter.InitializeDataAccount(ctx, t, rpcClient, programID, sender, lookupTable) + + args := map[string]interface{}{ + "seed1": []byte("data"), + } + + lookupConfig := chainwriter.LookupTables{ + DerivedLookupTables: []chainwriter.DerivedLookupTable{ + { + Name: "DerivedTable", + Accounts: chainwriter.PDALookups{ + Name: "DataAccountPDA", + PublicKey: chainwriter.AccountConstant{Name: "WriteTest", Address: programID.String()}, + Seeds: []chainwriter.Lookup{ + chainwriter.AccountLookup{Name: "seed1", Location: "seed1"}, + }, + IsSigner: false, + IsWritable: false, + InternalField: chainwriter.InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + }, + StaticLookupTables: nil, + } + + derivedTableMap, _, err := cw.ResolveLookupTables(ctx, args, lookupConfig) + require.NoError(t, err) + + addresses, ok := derivedTableMap["DerivedTable"][lookupTable.String()] + require.True(t, ok) + for i, address := range addresses { + require.Equal(t, lookupKeys[i], address.PublicKey) + } + }) +} diff --git a/pkg/solana/client/test_helpers.go b/pkg/solana/client/test_helpers.go index 5bb8b1cde..8d5ab4f88 100644 --- a/pkg/solana/client/test_helpers.go +++ b/pkg/solana/client/test_helpers.go @@ -66,6 +66,7 @@ func SetupLocalSolNodeWithFlags(t *testing.T, flags ...string) (string, string) out, err := client.GetHealth(tests.Context(t)) if err != nil || out != rpc.HealthOk { t.Logf("API server not ready yet (attempt %d)\n", i+1) + t.Logf("Error from API server: %v\n", err) continue } ready = true diff --git a/pkg/solana/codec/solana.go b/pkg/solana/codec/solana.go index e48c6022f..71e2f7f06 100644 --- a/pkg/solana/codec/solana.go +++ b/pkg/solana/codec/solana.go @@ -60,6 +60,28 @@ func NewNamedModifierCodec(original types.RemoteCodec, itemType string, modifier return modCodec, err } +func NewIDLInstructionsCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) { + typeCodecs := make(encodings.LenientCodecFromTypeCodec) + caser := cases.Title(language.English) + refs := &codecRefs{ + builder: builder, + codecs: make(map[string]encodings.TypeCodec), + typeDefs: idl.Types, + dependencies: make(map[string][]string), + } + + for _, instruction := range idl.Instructions { + name, instCodec, err := asStruct(instruction.Args, refs, instruction.Name, caser, false) + if err != nil { + return nil, err + } + + typeCodecs[name] = instCodec + } + + return typeCodecs, nil +} + // NewIDLAccountCodec is for Anchor custom types func NewIDLAccountCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) { return newIDLCoded(idl, builder, idl.Accounts, true) @@ -115,7 +137,7 @@ func createNamedCodec( switch def.Type.Kind { case IdlTypeDefTyKindStruct: - return asStruct(def, refs, name, caser, includeDiscriminator) + return asStruct(*def.Type.Fields, refs, name, caser, includeDiscriminator) case IdlTypeDefTyKindEnum: variants := def.Type.Variants if !variants.IsAllUint8() { @@ -129,7 +151,7 @@ func createNamedCodec( } func asStruct( - def IdlTypeDef, + fields []IdlField, refs *codecRefs, name string, // name is the struct name and can be used in dependency checks caser cases.Caser, @@ -139,13 +161,13 @@ func asStruct( if includeDiscriminator { desLen = 1 } - named := make([]encodings.NamedTypeCodec, len(*def.Type.Fields)+desLen) + named := make([]encodings.NamedTypeCodec, len(fields)+desLen) if includeDiscriminator { named[0] = encodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name)} } - for idx, field := range *def.Type.Fields { + for idx, field := range fields { fieldName := field.Name typedCodec, err := processFieldType(name, field.Type, refs) diff --git a/pkg/solana/relay.go b/pkg/solana/relay.go index 1f2fbdffd..f925434d2 100644 --- a/pkg/solana/relay.go +++ b/pkg/solana/relay.go @@ -19,12 +19,13 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) var _ TxManager = (*txm.Txm)(nil) type TxManager interface { - Enqueue(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...txm.SetTxConfig) error + Enqueue(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...txmutils.SetTxConfig) error } var _ relaytypes.Relayer = &Relayer{} //nolint:staticcheck diff --git a/pkg/solana/transmitter_test.go b/pkg/solana/transmitter_test.go index 1d058d36a..b4372515a 100644 --- a/pkg/solana/transmitter_test.go +++ b/pkg/solana/transmitter_test.go @@ -17,7 +17,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" clientmocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/client/mocks" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" - "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) // custom mock txm instead of mockery generated because SetTxConfig causes circular imports @@ -27,7 +27,7 @@ type verifyTxSize struct { s *solana.PrivateKey } -func (txm verifyTxSize) Enqueue(_ context.Context, _ string, tx *solana.Transaction, txID *string, _ ...txm.SetTxConfig) error { +func (txm verifyTxSize) Enqueue(_ context.Context, _ string, tx *solana.Transaction, txID *string, _ ...txmutils.SetTxConfig) error { // additional components that transaction manager adds to the transaction require.NoError(txm.t, fees.SetComputeUnitPrice(tx, 0)) require.NoError(txm.t, fees.SetComputeUnitLimit(tx, 0)) diff --git a/pkg/solana/txm/mocks/tx_manager.go b/pkg/solana/txm/mocks/tx_manager.go new file mode 100644 index 000000000..50806a4da --- /dev/null +++ b/pkg/solana/txm/mocks/tx_manager.go @@ -0,0 +1,390 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + solana "github.com/gagliardetto/solana-go" + mock "github.com/stretchr/testify/mock" + + types "github.com/smartcontractkit/chainlink-common/pkg/types" + + utils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" +) + +// TxManager is an autogenerated mock type for the TxManager type +type TxManager struct { + mock.Mock +} + +type TxManager_Expecter struct { + mock *mock.Mock +} + +func (_m *TxManager) EXPECT() *TxManager_Expecter { + return &TxManager_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *TxManager) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TxManager_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type TxManager_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *TxManager_Expecter) Close() *TxManager_Close_Call { + return &TxManager_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *TxManager_Close_Call) Run(run func()) *TxManager_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TxManager_Close_Call) Return(_a0 error) *TxManager_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Close_Call) RunAndReturn(run func() error) *TxManager_Close_Call { + _c.Call.Return(run) + return _c +} + +// Enqueue provides a mock function with given fields: ctx, accountID, tx, txID, txCfgs +func (_m *TxManager) Enqueue(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...utils.SetTxConfig) error { + _va := make([]interface{}, len(txCfgs)) + for _i := range txCfgs { + _va[_i] = txCfgs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, accountID, tx, txID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Enqueue") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, *solana.Transaction, *string, ...utils.SetTxConfig) error); ok { + r0 = rf(ctx, accountID, tx, txID, txCfgs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TxManager_Enqueue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Enqueue' +type TxManager_Enqueue_Call struct { + *mock.Call +} + +// Enqueue is a helper method to define mock.On call +// - ctx context.Context +// - accountID string +// - tx *solana.Transaction +// - txID *string +// - txCfgs ...utils.SetTxConfig +func (_e *TxManager_Expecter) Enqueue(ctx interface{}, accountID interface{}, tx interface{}, txID interface{}, txCfgs ...interface{}) *TxManager_Enqueue_Call { + return &TxManager_Enqueue_Call{Call: _e.mock.On("Enqueue", + append([]interface{}{ctx, accountID, tx, txID}, txCfgs...)...)} +} + +func (_c *TxManager_Enqueue_Call) Run(run func(ctx context.Context, accountID string, tx *solana.Transaction, txID *string, txCfgs ...utils.SetTxConfig)) *TxManager_Enqueue_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]utils.SetTxConfig, len(args)-4) + for i, a := range args[4:] { + if a != nil { + variadicArgs[i] = a.(utils.SetTxConfig) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(*solana.Transaction), args[3].(*string), variadicArgs...) + }) + return _c +} + +func (_c *TxManager_Enqueue_Call) Return(_a0 error) *TxManager_Enqueue_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Enqueue_Call) RunAndReturn(run func(context.Context, string, *solana.Transaction, *string, ...utils.SetTxConfig) error) *TxManager_Enqueue_Call { + _c.Call.Return(run) + return _c +} + +// GetTransactionStatus provides a mock function with given fields: ctx, transactionID +func (_m *TxManager) GetTransactionStatus(ctx context.Context, transactionID string) (types.TransactionStatus, error) { + ret := _m.Called(ctx, transactionID) + + if len(ret) == 0 { + panic("no return value specified for GetTransactionStatus") + } + + var r0 types.TransactionStatus + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (types.TransactionStatus, error)); ok { + return rf(ctx, transactionID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) types.TransactionStatus); ok { + r0 = rf(ctx, transactionID) + } else { + r0 = ret.Get(0).(types.TransactionStatus) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, transactionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// TxManager_GetTransactionStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTransactionStatus' +type TxManager_GetTransactionStatus_Call struct { + *mock.Call +} + +// GetTransactionStatus is a helper method to define mock.On call +// - ctx context.Context +// - transactionID string +func (_e *TxManager_Expecter) GetTransactionStatus(ctx interface{}, transactionID interface{}) *TxManager_GetTransactionStatus_Call { + return &TxManager_GetTransactionStatus_Call{Call: _e.mock.On("GetTransactionStatus", ctx, transactionID)} +} + +func (_c *TxManager_GetTransactionStatus_Call) Run(run func(ctx context.Context, transactionID string)) *TxManager_GetTransactionStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *TxManager_GetTransactionStatus_Call) Return(_a0 types.TransactionStatus, _a1 error) *TxManager_GetTransactionStatus_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *TxManager_GetTransactionStatus_Call) RunAndReturn(run func(context.Context, string) (types.TransactionStatus, error)) *TxManager_GetTransactionStatus_Call { + _c.Call.Return(run) + return _c +} + +// HealthReport provides a mock function with given fields: +func (_m *TxManager) HealthReport() map[string]error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for HealthReport") + } + + var r0 map[string]error + if rf, ok := ret.Get(0).(func() map[string]error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]error) + } + } + + return r0 +} + +// TxManager_HealthReport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HealthReport' +type TxManager_HealthReport_Call struct { + *mock.Call +} + +// HealthReport is a helper method to define mock.On call +func (_e *TxManager_Expecter) HealthReport() *TxManager_HealthReport_Call { + return &TxManager_HealthReport_Call{Call: _e.mock.On("HealthReport")} +} + +func (_c *TxManager_HealthReport_Call) Run(run func()) *TxManager_HealthReport_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TxManager_HealthReport_Call) Return(_a0 map[string]error) *TxManager_HealthReport_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_HealthReport_Call) RunAndReturn(run func() map[string]error) *TxManager_HealthReport_Call { + _c.Call.Return(run) + return _c +} + +// Name provides a mock function with given fields: +func (_m *TxManager) Name() string { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Name") + } + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// TxManager_Name_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Name' +type TxManager_Name_Call struct { + *mock.Call +} + +// Name is a helper method to define mock.On call +func (_e *TxManager_Expecter) Name() *TxManager_Name_Call { + return &TxManager_Name_Call{Call: _e.mock.On("Name")} +} + +func (_c *TxManager_Name_Call) Run(run func()) *TxManager_Name_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TxManager_Name_Call) Return(_a0 string) *TxManager_Name_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Name_Call) RunAndReturn(run func() string) *TxManager_Name_Call { + _c.Call.Return(run) + return _c +} + +// Ready provides a mock function with given fields: +func (_m *TxManager) Ready() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Ready") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TxManager_Ready_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ready' +type TxManager_Ready_Call struct { + *mock.Call +} + +// Ready is a helper method to define mock.On call +func (_e *TxManager_Expecter) Ready() *TxManager_Ready_Call { + return &TxManager_Ready_Call{Call: _e.mock.On("Ready")} +} + +func (_c *TxManager_Ready_Call) Run(run func()) *TxManager_Ready_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *TxManager_Ready_Call) Return(_a0 error) *TxManager_Ready_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Ready_Call) RunAndReturn(run func() error) *TxManager_Ready_Call { + _c.Call.Return(run) + return _c +} + +// Start provides a mock function with given fields: _a0 +func (_m *TxManager) Start(_a0 context.Context) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Start") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// TxManager_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start' +type TxManager_Start_Call struct { + *mock.Call +} + +// Start is a helper method to define mock.On call +// - _a0 context.Context +func (_e *TxManager_Expecter) Start(_a0 interface{}) *TxManager_Start_Call { + return &TxManager_Start_Call{Call: _e.mock.On("Start", _a0)} +} + +func (_c *TxManager_Start_Call) Run(run func(_a0 context.Context)) *TxManager_Start_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *TxManager_Start_Call) Return(_a0 error) *TxManager_Start_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *TxManager_Start_Call) RunAndReturn(run func(context.Context) error) *TxManager_Start_Call { + _c.Call.Return(run) + return _c +} + +// NewTxManager creates a new instance of TxManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTxManager(t interface { + mock.TestingT + Cleanup(func()) +}) *TxManager { + mock := &TxManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/solana/txm/pendingtx.go b/pkg/solana/txm/pendingtx.go index ecae7243b..033b0c16f 100644 --- a/pkg/solana/txm/pendingtx.go +++ b/pkg/solana/txm/pendingtx.go @@ -9,6 +9,8 @@ import ( "github.com/gagliardetto/solana-go" "golang.org/x/exp/maps" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) var ( @@ -37,11 +39,11 @@ type PendingTxContext interface { // OnFinalized marks transaction as Finalized, moves it from the broadcasted or confirmed map to finalized map, removes signatures from signature map to stop confirmation checks OnFinalized(sig solana.Signature, retentionTimeout time.Duration) (string, error) // OnPrebroadcastError adds transaction that has not yet been broadcasted to the finalized/errored map as errored, matches err type using enum - OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, errType TxErrType) error + OnPrebroadcastError(id string, retentionTimeout time.Duration, txState utils.TxState, errType TxErrType) error // OnError marks transaction as errored, matches err type using enum, moves it from the broadcasted or confirmed map to finalized/errored map, removes signatures from signature map to stop confirmation checks - OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, errType TxErrType) (string, error) + OnError(sig solana.Signature, retentionTimeout time.Duration, txState utils.TxState, errType TxErrType) (string, error) // GetTxState returns the transaction state for the provided ID if it exists - GetTxState(id string) (TxState, error) + GetTxState(id string) (utils.TxState, error) // TrimFinalizedErroredTxs removes transactions that have reached their retention time TrimFinalizedErroredTxs() int } @@ -49,17 +51,17 @@ type PendingTxContext interface { // finishedTx is used to store info required to track transactions to finality or error type pendingTx struct { tx solana.Transaction - cfg TxConfig + cfg utils.TxConfig signatures []solana.Signature id string createTs time.Time - state TxState + state utils.TxState } // finishedTx is used to store minimal info specifically for finalized or errored transactions for external status checks type finishedTx struct { retentionTs time.Time - state TxState + state utils.TxState } var _ PendingTxContext = &pendingTxContext{} @@ -116,7 +118,7 @@ func (c *pendingTxContext) New(tx pendingTx, sig solana.Signature, cancel contex // add signature to tx tx.signatures = append(tx.signatures, sig) tx.createTs = time.Now() - tx.state = Broadcasted + tx.state = utils.Broadcasted // save to the broadcasted map since transaction was just broadcasted c.broadcastedTxs[tx.id] = tx return "", nil @@ -251,7 +253,7 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { return ErrTransactionNotFound } // Check if tranasction already in processed state - if tx.state == Processed { + if tx.state == utils.Processed { return ErrAlreadyInExpectedState } return nil @@ -271,7 +273,7 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { return id, ErrTransactionNotFound } // update tx state to Processed - tx.state = Processed + tx.state = utils.Processed // save updated tx back to the broadcasted map c.broadcastedTxs[id] = tx return id, nil @@ -286,7 +288,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { return ErrSigDoesNotExist } // Check if transaction already in confirmed state - if tx, exists := c.confirmedTxs[id]; exists && tx.state == Confirmed { + if tx, exists := c.confirmedTxs[id]; exists && tx.state == utils.Confirmed { return ErrAlreadyInExpectedState } // Transactions should only move to confirmed from broadcasted/processed @@ -315,7 +317,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { delete(c.cancelBy, id) } // update tx state to Confirmed - tx.state = Confirmed + tx.state = utils.Confirmed // move tx to confirmed map c.confirmedTxs[id] = tx // remove tx from broadcasted map @@ -379,7 +381,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti return id, nil } finalizedTx := finishedTx{ - state: Finalized, + state: utils.Finalized, retentionTs: time.Now().Add(retentionTimeout), } // move transaction from confirmed to finalized map @@ -388,7 +390,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti }) } -func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, _ TxErrType) error { +func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState utils.TxState, _ TxErrType) error { // nothing to do if retention timeout is 0 since transaction is not stored yet. if retentionTimeout == 0 { return nil @@ -429,7 +431,7 @@ func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time. return err } -func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, _ TxErrType) (string, error) { +func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.Duration, txState utils.TxState, _ TxErrType) (string, error) { err := c.withReadLock(func() error { id, sigExists := c.sigToID[sig] if !sigExists { @@ -494,7 +496,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D }) } -func (c *pendingTxContext) GetTxState(id string) (TxState, error) { +func (c *pendingTxContext) GetTxState(id string) (utils.TxState, error) { c.lock.RLock() defer c.lock.RUnlock() if tx, exists := c.broadcastedTxs[id]; exists { @@ -506,7 +508,7 @@ func (c *pendingTxContext) GetTxState(id string) (TxState, error) { if tx, exists := c.finalizedErroredTxs[id]; exists { return tx.state, nil } - return NotFound, fmt.Errorf("failed to find transaction for id: %s", id) + return utils.NotFound, fmt.Errorf("failed to find transaction for id: %s", id) } // TrimFinalizedErroredTxs deletes transactions from the finalized/errored map and the allTxs map after the retention period has passed @@ -617,7 +619,7 @@ func (c *pendingTxContextWithProm) OnFinalized(sig solana.Signature, retentionTi return id, err } -func (c *pendingTxContextWithProm) OnError(sig solana.Signature, retentionTimeout time.Duration, txState TxState, errType TxErrType) (string, error) { +func (c *pendingTxContextWithProm) OnError(sig solana.Signature, retentionTimeout time.Duration, txState utils.TxState, errType TxErrType) (string, error) { id, err := c.pendingTx.OnError(sig, retentionTimeout, txState, errType) // err indicates transaction not found so may already be removed if err == nil { incrementErrorMetrics(errType, c.chainID) @@ -625,7 +627,7 @@ func (c *pendingTxContextWithProm) OnError(sig solana.Signature, retentionTimeou return id, err } -func (c *pendingTxContextWithProm) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState TxState, errType TxErrType) error { +func (c *pendingTxContextWithProm) OnPrebroadcastError(id string, retentionTimeout time.Duration, txState utils.TxState, errType TxErrType) error { err := c.pendingTx.OnPrebroadcastError(id, retentionTimeout, txState, errType) // err indicates transaction not found so may already be removed if err == nil { incrementErrorMetrics(errType, c.chainID) @@ -652,7 +654,7 @@ func incrementErrorMetrics(errType TxErrType, chainID string) { promSolTxmErrorTxs.WithLabelValues(chainID).Inc() } -func (c *pendingTxContextWithProm) GetTxState(id string) (TxState, error) { +func (c *pendingTxContextWithProm) GetTxState(id string) (utils.TxState, error) { return c.pendingTx.GetTxState(id) } diff --git a/pkg/solana/txm/pendingtx_test.go b/pkg/solana/txm/pendingtx_test.go index e7b7fc51e..759f54ca3 100644 --- a/pkg/solana/txm/pendingtx_test.go +++ b/pkg/solana/txm/pendingtx_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) func TestPendingTxContext_add_remove_multiple(t *testing.T) { @@ -90,7 +92,7 @@ func TestPendingTxContext_new(t *testing.T) { require.Equal(t, sig, tx.signatures[0]) // Check status is Broadcasted - require.Equal(t, Broadcasted, tx.state) + require.Equal(t, utils.Broadcasted, tx.state) // Check it does not exist in confirmed map _, exists = txs.confirmedTxs[msg.id] @@ -222,7 +224,7 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.Equal(t, sig, tx.signatures[0]) // Check status is Processed - require.Equal(t, Processed, tx.state) + require.Equal(t, utils.Processed, tx.state) // Check it does not exist in confirmed map _, exists = txs.confirmedTxs[msg.id] @@ -293,7 +295,7 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, Errored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -361,7 +363,7 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.Equal(t, sig, tx.signatures[0]) // Check status is Confirmed - require.Equal(t, Confirmed, tx.state) + require.Equal(t, utils.Confirmed, tx.state) // Check it does not exist in finalized map _, exists = txs.finalizedErroredTxs[msg.id] @@ -405,7 +407,7 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, Errored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -475,7 +477,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.True(t, exists) // Check status is Finalized - require.Equal(t, Finalized, tx.state) + require.Equal(t, utils.Finalized, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig1] @@ -525,7 +527,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.True(t, exists) // Check status is Finalized - require.Equal(t, Finalized, tx.state) + require.Equal(t, utils.Finalized, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig1] @@ -583,7 +585,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, Errored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -608,7 +610,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.NoError(t, err) // Transition to errored state - id, err := txs.OnError(sig, retentionTimeout, Errored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -625,7 +627,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.True(t, exists) // Check status is Finalized - require.Equal(t, Errored, tx.state) + require.Equal(t, utils.Errored, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig] @@ -646,7 +648,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition to errored state - id, err = txs.OnError(sig, retentionTimeout, Errored, 0) + id, err = txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -663,7 +665,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.True(t, exists) // Check status is Finalized - require.Equal(t, Errored, tx.state) + require.Equal(t, utils.Errored, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig] @@ -679,7 +681,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.NoError(t, err) // Transition to fatally errored state - id, err := txs.OnError(sig, retentionTimeout, FatallyErrored, 0) + id, err := txs.OnError(sig, retentionTimeout, utils.FatallyErrored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -692,7 +694,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.True(t, exists) // Check status is Errored - require.Equal(t, FatallyErrored, tx.state) + require.Equal(t, utils.FatallyErrored, tx.state) // Check sigs do no exist in signature map _, exists = txs.sigToID[sig] @@ -713,7 +715,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition to errored state - id, err = txs.OnError(sig, 0*time.Second, Errored, 0) + id, err = txs.OnError(sig, 0*time.Second, utils.Errored, 0) require.NoError(t, err) require.Equal(t, msg.id, id) @@ -748,7 +750,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Transition back to confirmed state - id, err = txs.OnError(sig, retentionTimeout, Errored, 0) + id, err = txs.OnError(sig, retentionTimeout, utils.Errored, 0) require.Error(t, err) require.Equal(t, "", id) }) @@ -764,7 +766,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} // Transition to errored state - err := txs.OnPrebroadcastError(msg.id, retentionTimeout, Errored, 0) + err := txs.OnPrebroadcastError(msg.id, retentionTimeout, utils.Errored, 0) require.NoError(t, err) // Check it exists in errored map @@ -772,7 +774,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { require.True(t, exists) // Check status is Errored - require.Equal(t, Errored, tx.state) + require.Equal(t, utils.Errored, tx.state) }) t.Run("successfully adds transaction with fatally errored state", func(t *testing.T) { @@ -780,7 +782,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { msg := pendingTx{id: uuid.NewString()} // Transition to fatally errored state - err := txs.OnPrebroadcastError(msg.id, retentionTimeout, FatallyErrored, 0) + err := txs.OnPrebroadcastError(msg.id, retentionTimeout, utils.FatallyErrored, 0) require.NoError(t, err) // Check it exists in errored map @@ -788,7 +790,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { require.True(t, exists) // Check status is Errored - require.Equal(t, FatallyErrored, tx.state) + require.Equal(t, utils.FatallyErrored, tx.state) }) t.Run("fails to add transaction to errored map if id exists in another map already", func(t *testing.T) { @@ -801,7 +803,7 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { require.NoError(t, err) // Transition to errored state - err = txs.OnPrebroadcastError(msg.id, retentionTimeout, FatallyErrored, 0) + err = txs.OnPrebroadcastError(msg.id, retentionTimeout, utils.FatallyErrored, 0) require.ErrorIs(t, err, ErrIDAlreadyExists) }) @@ -809,11 +811,11 @@ func TestPendingTxContext_on_prebroadcast_error(t *testing.T) { txID := uuid.NewString() // Transition to errored state - err := txs.OnPrebroadcastError(txID, retentionTimeout, Errored, 0) + err := txs.OnPrebroadcastError(txID, retentionTimeout, utils.Errored, 0) require.NoError(t, err) // Transition back to errored state - err = txs.OnPrebroadcastError(txID, retentionTimeout, Errored, 0) + err = txs.OnPrebroadcastError(txID, retentionTimeout, utils.Errored, 0) require.ErrorIs(t, err, ErrAlreadyInExpectedState) }) } @@ -867,7 +869,7 @@ func TestPendingTxContext_remove(t *testing.T) { erroredMsg := pendingTx{id: uuid.NewString()} err = txs.New(erroredMsg, erroredSig, cancel) require.NoError(t, err) - id, err = txs.OnError(erroredSig, retentionTimeout, Errored, 0) + id, err = txs.OnError(erroredSig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, erroredMsg.id, id) @@ -1062,7 +1064,7 @@ func TestGetTxState(t *testing.T) { err := txs.New(broadcastedMsg, broadcastedSig, cancel) require.NoError(t, err) - var state TxState + var state utils.TxState // Create new processed transaction processedMsg := pendingTx{id: uuid.NewString()} err = txs.New(processedMsg, processedSig, cancel) @@ -1073,7 +1075,7 @@ func TestGetTxState(t *testing.T) { // Check Processed state is returned state, err = txs.GetTxState(processedMsg.id) require.NoError(t, err) - require.Equal(t, Processed, state) + require.Equal(t, utils.Processed, state) // Create new confirmed transaction confirmedMsg := pendingTx{id: uuid.NewString()} @@ -1085,7 +1087,7 @@ func TestGetTxState(t *testing.T) { // Check Confirmed state is returned state, err = txs.GetTxState(confirmedMsg.id) require.NoError(t, err) - require.Equal(t, Confirmed, state) + require.Equal(t, utils.Confirmed, state) // Create new finalized transaction finalizedMsg := pendingTx{id: uuid.NewString()} @@ -1097,36 +1099,36 @@ func TestGetTxState(t *testing.T) { // Check Finalized state is returned state, err = txs.GetTxState(finalizedMsg.id) require.NoError(t, err) - require.Equal(t, Finalized, state) + require.Equal(t, utils.Finalized, state) // Create new errored transaction erroredMsg := pendingTx{id: uuid.NewString()} err = txs.New(erroredMsg, erroredSig, cancel) require.NoError(t, err) - id, err = txs.OnError(erroredSig, retentionTimeout, Errored, 0) + id, err = txs.OnError(erroredSig, retentionTimeout, utils.Errored, 0) require.NoError(t, err) require.Equal(t, erroredMsg.id, id) // Check Errored state is returned state, err = txs.GetTxState(erroredMsg.id) require.NoError(t, err) - require.Equal(t, Errored, state) + require.Equal(t, utils.Errored, state) // Create new fatally errored transaction fatallyErroredMsg := pendingTx{id: uuid.NewString()} err = txs.New(fatallyErroredMsg, fatallyErroredSig, cancel) require.NoError(t, err) - id, err = txs.OnError(fatallyErroredSig, retentionTimeout, FatallyErrored, 0) + id, err = txs.OnError(fatallyErroredSig, retentionTimeout, utils.FatallyErrored, 0) require.NoError(t, err) require.Equal(t, fatallyErroredMsg.id, id) // Check Errored state is returned state, err = txs.GetTxState(fatallyErroredMsg.id) require.NoError(t, err) - require.Equal(t, FatallyErrored, state) + require.Equal(t, utils.FatallyErrored, state) // Check NotFound state is returned if unknown id provided state, err = txs.GetTxState("unknown id") require.Error(t, err) - require.Equal(t, NotFound, state) + require.Equal(t, utils.NotFound, state) } func randomSignature(t *testing.T) solana.Signature { diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index 342f54dce..0d66aba63 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -25,6 +25,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" ) const ( @@ -36,8 +37,6 @@ const ( MaxComputeUnitLimit = 1_400_000 // max compute unit limit a transaction can have ) -var _ services.Service = (*Txm)(nil) - type SimpleKeystore interface { Sign(ctx context.Context, account string, data []byte) (signature []byte, err error) Accounts(ctx context.Context) (accounts []string, err error) @@ -45,6 +44,14 @@ type SimpleKeystore interface { var _ loop.Keystore = (SimpleKeystore)(nil) +type TxManager interface { + services.Service + Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...txmutils.SetTxConfig) error + GetTransactionStatus(ctx context.Context, transactionID string) (commontypes.TransactionStatus, error) +} + +var _ TxManager = (*Txm)(nil) + // Txm manages transactions for the solana blockchain. // simple implementation with no persistently stored txs type Txm struct { @@ -64,19 +71,6 @@ type Txm struct { sendTx func(ctx context.Context, tx *solanaGo.Transaction) (solanaGo.Signature, error) } -type TxConfig struct { - Timeout time.Duration // transaction broadcast timeout - - // compute unit price config - FeeBumpPeriod time.Duration // how often to bump fee - BaseComputeUnitPrice uint64 // starting price - ComputeUnitPriceMin uint64 // min price - ComputeUnitPriceMax uint64 // max price - - EstimateComputeUnitLimit bool // enable compute limit estimations using simulation - ComputeUnitLimit uint32 // compute unit limit -} - // NewTxm creates a txm. Uses simulation so should only be used to send txes to trusted contracts i.e. OCR. func NewTxm(chainID string, client internal.Loader[client.ReaderWriter], sendTx func(ctx context.Context, tx *solanaGo.Transaction) (solanaGo.Signature, error), @@ -240,7 +234,7 @@ func (txm *Txm) sendWithRetry(ctx context.Context, msg pendingTx) (solanaGo.Tran sig, initSendErr := txm.sendTx(ctx, &initTx) if initSendErr != nil { cancel() // cancel context when exiting early - stateTransitionErr := txm.txs.OnPrebroadcastError(msg.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject) + stateTransitionErr := txm.txs.OnPrebroadcastError(msg.id, txm.cfg.TxRetentionTimeout(), txmutils.Errored, TxFailReject) return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("tx failed initial transmit: %w", errors.Join(initSendErr, stateTransitionErr)) } @@ -252,7 +246,7 @@ func (txm *Txm) sendWithRetry(ctx context.Context, msg pendingTx) (solanaGo.Tran } // used for tracking rebroadcasting only in SendWithRetry - var sigs signatureList + var sigs txmutils.SignatureList sigs.Allocate() if initSetErr := sigs.Set(0, sig); initSetErr != nil { return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save initial signature in signature list: %w", initSetErr) @@ -402,7 +396,7 @@ func (txm *Txm) confirm() { // process signatures processSigs := func(s []solanaGo.Signature, res []*rpc.SignatureStatusesResult) { // sort signatures and results process successful first - s, res, err := SortSignaturesAndResults(s, res) + s, res, err := txmutils.SortSignaturesAndResults(s, res) if err != nil { txm.lggr.Errorw("sorting error", "error", err) return @@ -418,7 +412,7 @@ func (txm *Txm) confirm() { // check confirm timeout exceeded if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) + id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), txmutils.Errored, TxFailDrop) if err != nil { txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) } else { @@ -454,7 +448,7 @@ func (txm *Txm) confirm() { } // check confirm timeout exceeded if TxConfirmTimeout set if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) + id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), txmutils.Errored, TxFailDrop) if err != nil { txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) } else { @@ -580,7 +574,7 @@ func (txm *Txm) reap() { } // Enqueue enqueues a msg destined for the solana chain. -func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...SetTxConfig) error { +func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...txmutils.SetTxConfig) error { if err := txm.Ready(); err != nil { return fmt.Errorf("error in soltxm.Enqueue: %w", err) } @@ -650,15 +644,15 @@ func (txm *Txm) GetTransactionStatus(ctx context.Context, transactionID string) } switch state { - case Broadcasted: + case txmutils.Broadcasted: return commontypes.Pending, nil - case Processed, Confirmed: + case txmutils.Processed, txmutils.Confirmed: return commontypes.Unconfirmed, nil - case Finalized: + case txmutils.Finalized: return commontypes.Finalized, nil - case Errored: + case txmutils.Errored: return commontypes.Failed, nil - case FatallyErrored: + case txmutils.FatallyErrored: return commontypes.Fatal, nil default: return commontypes.Unknown, fmt.Errorf("found unknown transaction state: %s", state.String()) @@ -746,7 +740,7 @@ func (txm *Txm) simulateTx(ctx context.Context, tx *solanaGo.Transaction) (res * } // processError parses and handles relevant errors found in simulation results -func (txm *Txm) processError(sig solanaGo.Signature, resErr interface{}, simulation bool) (txState TxState, errType TxErrType) { +func (txm *Txm) processError(sig solanaGo.Signature, resErr interface{}, simulation bool) (txState txmutils.TxState, errType TxErrType) { if resErr != nil { // handle various errors // https://github.com/solana-labs/solana/blob/master/sdk/src/transaction/error.rs @@ -772,11 +766,11 @@ func (txm *Txm) processError(sig solanaGo.Signature, resErr interface{}, simulat if simulation { return txState, NoFailure } - return Errored, errType + return txmutils.Errored, errType // transaction will encounter execution error/revert case strings.Contains(errStr, "InstructionError"): txm.lggr.Debugw("InstructionError", logValues...) - return Errored, errType + return txmutils.Errored, errType // transaction is already processed in the chain case strings.Contains(errStr, "AlreadyProcessed"): txm.lggr.Debugw("AlreadyProcessed", logValues...) @@ -785,7 +779,7 @@ func (txm *Txm) processError(sig solanaGo.Signature, resErr interface{}, simulat if simulation { return txState, NoFailure } - return Errored, errType + return txmutils.Errored, errType // unrecognized errors (indicates more concerning failures) default: // if simulating, return TxFailSimOther if error unknown @@ -793,7 +787,7 @@ func (txm *Txm) processError(sig solanaGo.Signature, resErr interface{}, simulat errType = TxFailSimOther } txm.lggr.Errorw("unrecognized error", logValues...) - return Errored, errType + return txmutils.Errored, errType } } return @@ -815,8 +809,8 @@ func (txm *Txm) Name() string { return txm.lggr.Name() } func (txm *Txm) HealthReport() map[string]error { return map[string]error{txm.Name(): txm.Healthy()} } -func (txm *Txm) defaultTxConfig() TxConfig { - return TxConfig{ +func (txm *Txm) defaultTxConfig() txmutils.TxConfig { + return txmutils.TxConfig{ Timeout: txm.cfg.TxRetryTimeout(), FeeBumpPeriod: txm.cfg.FeeBumpPeriod(), BaseComputeUnitPrice: txm.fee.BaseComputeUnitPrice(), diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index 418bdbec1..21df58833 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -24,6 +24,7 @@ import ( "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/fees" keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" + txmutils "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/utils" relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -676,7 +677,7 @@ func TestTxm(t *testing.T) { // send tx - with disabled fee bumping testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, SetFeeBumpPeriod(0))) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, txmutils.SetFeeBumpPeriod(0))) wg.Wait() // no transactions stored inflight txs list @@ -728,7 +729,7 @@ func TestTxm(t *testing.T) { // send tx - with disabled fee bumping and disabled compute unit limit testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, SetFeeBumpPeriod(0), SetComputeUnitLimit(0))) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, txmutils.SetFeeBumpPeriod(0), txmutils.SetComputeUnitLimit(0))) wg.Wait() // no transactions stored inflight txs list diff --git a/pkg/solana/txm/utils.go b/pkg/solana/txm/utils/utils.go similarity index 82% rename from pkg/solana/txm/utils.go rename to pkg/solana/txm/utils/utils.go index fef260e3d..7f3ffb9e2 100644 --- a/pkg/solana/txm/utils.go +++ b/pkg/solana/txm/utils/utils.go @@ -1,4 +1,4 @@ -package txm +package utils import ( "errors" @@ -111,39 +111,39 @@ func convertStatus(res *rpc.SignatureStatusesResult) TxState { return NotFound } -type signatureList struct { +type SignatureList struct { sigs []solana.Signature lock sync.RWMutex wg []*sync.WaitGroup } // internal function that should be called using the proper lock -func (s *signatureList) get(index int) (sig solana.Signature, err error) { +func (s *SignatureList) get(index int) (sig solana.Signature, err error) { if index >= len(s.sigs) { return sig, errors.New("invalid index") } return s.sigs[index], nil } -func (s *signatureList) Get(index int) (sig solana.Signature, err error) { +func (s *SignatureList) Get(index int) (sig solana.Signature, err error) { s.lock.RLock() defer s.lock.RUnlock() return s.get(index) } -func (s *signatureList) List() []solana.Signature { +func (s *SignatureList) List() []solana.Signature { s.lock.RLock() defer s.lock.RUnlock() return s.sigs } -func (s *signatureList) Length() int { +func (s *SignatureList) Length() int { s.lock.RLock() defer s.lock.RUnlock() return len(s.sigs) } -func (s *signatureList) Allocate() (index int) { +func (s *SignatureList) Allocate() (index int) { s.lock.Lock() defer s.lock.Unlock() @@ -156,7 +156,7 @@ func (s *signatureList) Allocate() (index int) { return len(s.sigs) - 1 } -func (s *signatureList) Set(index int, sig solana.Signature) error { +func (s *SignatureList) Set(index int, sig solana.Signature) error { s.lock.Lock() defer s.lock.Unlock() @@ -174,7 +174,7 @@ func (s *signatureList) Set(index int, sig solana.Signature) error { return nil } -func (s *signatureList) Wait(index int) { +func (s *SignatureList) Wait(index int) { wg := &sync.WaitGroup{} s.lock.RLock() if index < len(s.wg) { @@ -185,6 +185,19 @@ func (s *signatureList) Wait(index int) { wg.Wait() } +type TxConfig struct { + Timeout time.Duration // transaction broadcast timeout + + // compute unit price config + FeeBumpPeriod time.Duration // how often to bump fee + BaseComputeUnitPrice uint64 // starting price + ComputeUnitPriceMin uint64 // min price + ComputeUnitPriceMax uint64 // max price + + EstimateComputeUnitLimit bool // enable compute limit estimations using simulation + ComputeUnitLimit uint32 // compute unit limit +} + type SetTxConfig func(*TxConfig) func SetTimeout(t time.Duration) SetTxConfig { diff --git a/pkg/solana/txm/utils_test.go b/pkg/solana/txm/utils/utils_test.go similarity index 98% rename from pkg/solana/txm/utils_test.go rename to pkg/solana/txm/utils/utils_test.go index f4ac868ff..676f04202 100644 --- a/pkg/solana/txm/utils_test.go +++ b/pkg/solana/txm/utils/utils_test.go @@ -1,4 +1,4 @@ -package txm +package utils import ( "sync" @@ -42,7 +42,7 @@ func TestSortSignaturesAndResults(t *testing.T) { } func TestSignatureList_AllocateWaitSet(t *testing.T) { - sigs := signatureList{} + sigs := SignatureList{} assert.Equal(t, 0, sigs.Length()) // can't set without pre-allocating diff --git a/pkg/solana/utils.go b/pkg/solana/utils.go deleted file mode 100644 index a4387aea8..000000000 --- a/pkg/solana/utils.go +++ /dev/null @@ -1,5 +0,0 @@ -package solana - -import "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" - -func LamportsToSol(lamports uint64) float64 { return internal.LamportsToSol(lamports) } diff --git a/pkg/solana/utils/utils.go b/pkg/solana/utils/utils.go new file mode 100644 index 000000000..3353d40b3 --- /dev/null +++ b/pkg/solana/utils/utils.go @@ -0,0 +1,217 @@ +package utils + +import ( + "context" + "encoding/binary" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" + "github.com/pelletier/go-toml/v2" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/internal" +) + +var ( + _, b, _, _ = runtime.Caller(0) + // ProjectRoot Root folder of this project + ProjectRoot = filepath.Join(filepath.Dir(b), "/../../..") + // ContractsDir path to our contracts + ContractsDir = filepath.Join(ProjectRoot, "contracts", "target", "deploy") + PathToAnchorConfig = filepath.Join(ProjectRoot, "contracts", "Anchor.toml") +) + +func LamportsToSol(lamports uint64) float64 { return internal.LamportsToSol(lamports) } + +// TxModifier is a dynamic function used to flexibly add components to a transaction such as additional signers, and compute budget parameters +type TxModifier func(tx *solana.Transaction, signers map[solana.PublicKey]solana.PrivateKey) error + +func SendAndConfirm(ctx context.Context, t *testing.T, rpcClient *rpc.Client, instructions []solana.Instruction, + signer solana.PrivateKey, commitment rpc.CommitmentType, opts ...TxModifier) *rpc.GetTransactionResult { + txres := sendTransaction(ctx, rpcClient, t, instructions, signer, commitment, false, opts...) // do not skipPreflight when expected to pass, preflight can help debug + + require.NotNil(t, txres.Meta) + require.Nil(t, txres.Meta.Err, fmt.Sprintf("tx failed with: %+v", txres.Meta)) // tx should not err, print meta if it does (contains logs) + return txres +} + +func sendTransaction(ctx context.Context, rpcClient *rpc.Client, t *testing.T, instructions []solana.Instruction, + signerAndPayer solana.PrivateKey, commitment rpc.CommitmentType, skipPreflight bool, opts ...TxModifier) *rpc.GetTransactionResult { + hashRes, err := rpcClient.GetLatestBlockhash(ctx, rpc.CommitmentFinalized) + require.NoError(t, err) + + tx, err := solana.NewTransaction( + instructions, + hashRes.Value.Blockhash, + solana.TransactionPayer(signerAndPayer.PublicKey()), + ) + require.NoError(t, err) + + // build signers map + signers := map[solana.PublicKey]solana.PrivateKey{} + signers[signerAndPayer.PublicKey()] = signerAndPayer + + // set options before signing transaction + for _, o := range opts { + require.NoError(t, o(tx, signers)) + } + + _, err = tx.Sign(func(pub solana.PublicKey) *solana.PrivateKey { + priv, ok := signers[pub] + require.True(t, ok, fmt.Sprintf("Missing signer private key for %s", pub)) + return &priv + }) + require.NoError(t, err) + + txsig, err := rpcClient.SendTransactionWithOpts(ctx, tx, rpc.TransactionOpts{SkipPreflight: skipPreflight, PreflightCommitment: rpc.CommitmentProcessed}) + require.NoError(t, err) + + var txStatus rpc.ConfirmationStatusType + count := 0 + for txStatus != rpc.ConfirmationStatusType(commitment) && txStatus != rpc.ConfirmationStatusFinalized { + count++ + statusRes, sigErr := rpcClient.GetSignatureStatuses(ctx, true, txsig) + require.NoError(t, sigErr) + if statusRes != nil && len(statusRes.Value) > 0 && statusRes.Value[0] != nil { + txStatus = statusRes.Value[0].ConfirmationStatus + } + time.Sleep(100 * time.Millisecond) + if count > 500 { + require.NoError(t, fmt.Errorf("unable to find transaction within timeout")) + } + } + + txres, err := rpcClient.GetTransaction(ctx, txsig, &rpc.GetTransactionOpts{ + Commitment: commitment, + }) + require.NoError(t, err) + return txres +} + +var ( + AddressLookupTableProgram = solana.MustPublicKeyFromBase58("AddressLookupTab1e1111111111111111111111111") +) + +const ( + InstructionCreateLookupTable uint32 = iota + InstructionFreezeLookupTable + InstructionExtendLookupTable + InstructionDeactiveLookupTable + InstructionCloseLookupTable +) + +func NewCreateLookupTableInstruction( + authority, funder solana.PublicKey, + slot uint64, +) (solana.PublicKey, solana.Instruction, error) { + // https://github.com/solana-labs/solana-web3.js/blob/c1c98715b0c7900ce37c59bffd2056fa0037213d/src/programs/address-lookup-table/index.ts#L274 + slotLE := make([]byte, 8) + binary.LittleEndian.PutUint64(slotLE, slot) + account, bumpSeed, err := solana.FindProgramAddress([][]byte{authority.Bytes(), slotLE}, AddressLookupTableProgram) + if err != nil { + return solana.PublicKey{}, nil, err + } + + data := binary.LittleEndian.AppendUint32([]byte{}, InstructionCreateLookupTable) + data = binary.LittleEndian.AppendUint64(data, slot) + data = append(data, bumpSeed) + return account, solana.NewInstruction( + AddressLookupTableProgram, + solana.AccountMetaSlice{ + solana.Meta(account).WRITE(), + solana.Meta(authority).SIGNER(), + solana.Meta(funder).SIGNER().WRITE(), + solana.Meta(solana.SystemProgramID), + }, + data, + ), nil +} + +func NewExtendLookupTableInstruction( + table, authority, funder solana.PublicKey, + accounts []solana.PublicKey, +) solana.Instruction { + // https://github.com/solana-labs/solana-web3.js/blob/c1c98715b0c7900ce37c59bffd2056fa0037213d/src/programs/address-lookup-table/index.ts#L113 + + data := binary.LittleEndian.AppendUint32([]byte{}, InstructionExtendLookupTable) + data = binary.LittleEndian.AppendUint64(data, uint64(len(accounts))) // note: this is usually u32 + 8 byte buffer + for _, a := range accounts { + data = append(data, a.Bytes()...) + } + + return solana.NewInstruction( + AddressLookupTableProgram, + solana.AccountMetaSlice{ + solana.Meta(table).WRITE(), + solana.Meta(authority).SIGNER(), + solana.Meta(funder).SIGNER().WRITE(), + solana.Meta(solana.SystemProgramID), + }, + data, + ) +} + +func FundAccounts(ctx context.Context, accounts []solana.PrivateKey, solanaGoClient *rpc.Client, t *testing.T) { + sigs := []solana.Signature{} + for _, v := range accounts { + sig, err := solanaGoClient.RequestAirdrop(ctx, v.PublicKey(), 1000*solana.LAMPORTS_PER_SOL, rpc.CommitmentFinalized) + require.NoError(t, err) + sigs = append(sigs, sig) + } + + // wait for confirmation so later transactions don't fail + remaining := len(sigs) + count := 0 + for remaining > 0 { + count++ + statusRes, sigErr := solanaGoClient.GetSignatureStatuses(ctx, true, sigs...) + require.NoError(t, sigErr) + require.NotNil(t, statusRes) + require.NotNil(t, statusRes.Value) + + unconfirmedTxCount := 0 + for _, res := range statusRes.Value { + if res == nil || res.ConfirmationStatus == rpc.ConfirmationStatusProcessed || res.ConfirmationStatus == rpc.ConfirmationStatusConfirmed { + unconfirmedTxCount++ + } + } + remaining = unconfirmedTxCount + fmt.Printf("Waiting for finalized funding on %d addresses\n", remaining) + + time.Sleep(500 * time.Millisecond) + if count > 60 { + require.NoError(t, fmt.Errorf("unable to find transaction within timeout")) + } + } +} + +func DeployAllPrograms(t *testing.T, pathToAnchorConfig string, admin solana.PrivateKey) *rpc.Client { + return rpc.New(SetupTestValidatorWithAnchorPrograms(t, pathToAnchorConfig, admin.PublicKey().String())) +} + +func SetupTestValidatorWithAnchorPrograms(t *testing.T, pathToAnchorConfig string, upgradeAuthority string) string { + anchorData := struct { + Programs struct { + Localnet map[string]string + } + }{} + + // upload programs to validator + anchorBytes, err := os.ReadFile(pathToAnchorConfig) + require.NoError(t, err) + require.NoError(t, toml.Unmarshal(anchorBytes, &anchorData)) + + flags := []string{} + for k, v := range anchorData.Programs.Localnet { + flags = append(flags, "--upgradeable-program", v, filepath.Join(ContractsDir, k+".so"), upgradeAuthority) + } + url, _ := client.SetupLocalSolNodeWithFlags(t, flags...) + return url +} diff --git a/pkg/solana/utils_test.go b/pkg/solana/utils/utils_test.go similarity index 71% rename from pkg/solana/utils_test.go rename to pkg/solana/utils/utils_test.go index 67efc932b..0f41f80c9 100644 --- a/pkg/solana/utils_test.go +++ b/pkg/solana/utils/utils_test.go @@ -1,9 +1,11 @@ -package solana +package utils_test import ( "testing" "github.com/stretchr/testify/assert" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/utils" ) func TestLamportsToSol(t *testing.T) { @@ -19,7 +21,7 @@ func TestLamportsToSol(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - assert.Equal(t, test.out, LamportsToSol(test.in)) + assert.Equal(t, test.out, utils.LamportsToSol(test.in)) }) } }