diff --git a/programs/protocol-contracts-solana/src/lib.rs b/programs/protocol-contracts-solana/src/lib.rs index 0a56b58..e8bb085 100644 --- a/programs/protocol-contracts-solana/src/lib.rs +++ b/programs/protocol-contracts-solana/src/lib.rs @@ -1,6 +1,7 @@ use anchor_lang::prelude::*; use anchor_lang::system_program; -use anchor_spl::token::{transfer, Token, TokenAccount}; +use anchor_spl::associated_token::get_associated_token_address; +use anchor_spl::token::{transfer, transfer_checked, Mint, Token, TokenAccount}; use solana_program::keccak::hash; use solana_program::secp256k1_recover::secp256k1_recover; use std::mem::size_of; @@ -25,12 +26,15 @@ pub enum Errors { MemoLengthTooShort, #[msg("DepositPaused")] DepositPaused, + #[msg("SPLAtaAndMintAddressMismatch")] + SPLAtaAndMintAddressMismatch, } declare_id!("ZETAjseVjuFsxdRxo6MmTCvqFwb3ZHUx56Co3vCmGis"); #[program] pub mod gateway { + use super::*; pub fn initialize( @@ -235,6 +239,7 @@ pub mod gateway { // concatenated_buffer vec. pub fn withdraw_spl_token( ctx: Context, + decimals: u8, amount: u64, signature: [u8; 64], recovery_id: u8, @@ -253,7 +258,7 @@ pub mod gateway { concatenated_buffer.extend_from_slice(&pda.chain_id.to_be_bytes()); concatenated_buffer.extend_from_slice(&nonce.to_be_bytes()); concatenated_buffer.extend_from_slice(&amount.to_be_bytes()); - concatenated_buffer.extend_from_slice(&ctx.accounts.from.key().to_bytes()); + concatenated_buffer.extend_from_slice(&ctx.accounts.mint_account.key().to_bytes()); concatenated_buffer.extend_from_slice(&ctx.accounts.to.key().to_bytes()); require!( message_hash == hash(&concatenated_buffer[..]).to_bytes(), @@ -267,13 +272,23 @@ pub mod gateway { return err!(Errors::TSSAuthenticationFailed); } + // associated token address (ATA) of the program PDA + // the PDA is the "wallet" (owner) of the token account + // the token is stored in ATA account owned by the PDA + let pda_ata = get_associated_token_address(&pda.key(), &ctx.accounts.mint_account.key()); + require!( + pda_ata == ctx.accounts.pda_ata.to_account_info().key(), + Errors::SPLAtaAndMintAddressMismatch + ); + let token = &ctx.accounts.token_program; let signer_seeds: &[&[&[u8]]] = &[&[b"meta", &[ctx.bumps.pda]]]; let xfer_ctx = CpiContext::new_with_signer( token.to_account_info(), - anchor_spl::token::Transfer { - from: ctx.accounts.from.to_account_info(), + anchor_spl::token::TransferChecked { + from: ctx.accounts.pda_ata.to_account_info(), + mint: ctx.accounts.mint_account.to_account_info(), to: ctx.accounts.to.to_account_info(), authority: pda.to_account_info(), }, @@ -282,7 +297,7 @@ pub mod gateway { pda.nonce += 1; - transfer(xfer_ctx, amount)?; + transfer_checked(xfer_ctx, amount, decimals)?; msg!("withdraw spl token successfully"); Ok(()) @@ -364,7 +379,10 @@ pub struct WithdrawSPLToken<'info> { pub pda: Account<'info, Pda>, #[account(mut)] - pub from: Account<'info, TokenAccount>, + pub pda_ata: Account<'info, TokenAccount>, // associated token address of PDA + + #[account()] + pub mint_account: Account<'info, Mint>, #[account(mut)] pub to: Account<'info, TokenAccount>, diff --git a/tests/protocol-contracts-solana.ts b/tests/protocol-contracts-solana.ts index 96e58f8..9511e50 100644 --- a/tests/protocol-contracts-solana.ts +++ b/tests/protocol-contracts-solana.ts @@ -17,6 +17,8 @@ const ec = new EC('secp256k1'); // read private key from hex dump const keyPair = ec.keyFromPrivate('5b81cdf52ba0766983acf8dd0072904733d92afe4dd3499e83e879b43ccb73e8'); +const usdcDecimals = 6; + describe("some tests", () => { // Configure the client to use the local cluster. anchor.setProvider(anchor.AnchorProvider.env()); @@ -24,6 +26,8 @@ describe("some tests", () => { const gatewayProgram = anchor.workspace.Gateway as Program; const wallet = anchor.workspace.Gateway.provider.wallet.payer; const mint = anchor.web3.Keypair.generate(); + const mint_fake = anchor.web3.Keypair.generate(); // for testing purpose + let tokenAccount: spl.Account; let wallet_ata: anchor.web3.PublicKey; let pdaAccount: anchor.web3.PublicKey; @@ -75,7 +79,7 @@ describe("some tests", () => { // now deploying a fake USDC SPL Token // 1. create a mint account const mintRent = await spl.getMinimumBalanceForRentExemptMint(conn); - const tokenTransaction = new anchor.web3.Transaction(); + let tokenTransaction = new anchor.web3.Transaction(); tokenTransaction.add( anchor.web3.SystemProgram.createAccount({ fromPubkey: wallet.publicKey, @@ -86,7 +90,7 @@ describe("some tests", () => { }), spl.createInitializeMintInstruction( mint.publicKey, - 6, + usdcDecimals, wallet.publicKey, null, ) @@ -122,6 +126,26 @@ describe("some tests", () => { wallet.publicKey, ); console.log(`wallet_ata: ${wallet_ata.toString()}`); + + // create a fake USDC token account + tokenTransaction = new anchor.web3.Transaction(); + tokenTransaction.add( + anchor.web3.SystemProgram.createAccount({ + fromPubkey: wallet.publicKey, + newAccountPubkey: mint_fake.publicKey, + lamports: mintRent, + space: spl.MINT_SIZE, + programId: spl.TOKEN_PROGRAM_ID + }), + spl.createInitializeMintInstruction( + mint_fake.publicKey, + usdcDecimals, + wallet.publicKey, + null, + ) + ); + await anchor.web3.sendAndConfirmTransaction(conn, tokenTransaction, [wallet, mint_fake]); + console.log("fake mint account created!", mint_fake.publicKey.toString()); }) it("Deposit 1_000_000 USDC to Gateway", async () => { @@ -208,7 +232,7 @@ describe("some tests", () => { chain_id_bn.toArrayLike(Buffer, 'be', 8), nonce.toArrayLike(Buffer, 'be', 8), amount.toArrayLike(Buffer, 'be', 8), - pda_ata.address.toBuffer(), + mint.publicKey.toBuffer(), wallet_ata.toBuffer(), ]); const message_hash = keccak256(buffer); @@ -219,9 +243,10 @@ describe("some tests", () => { s.toArrayLike(Buffer, 'be', 32), ]); - await gatewayProgram.methods.withdrawSplToken(amount, Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce) + await gatewayProgram.methods.withdrawSplToken(usdcDecimals,amount, Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce) .accounts({ - from: pda_ata.address, + pdaAta: pda_ata.address, + mintAccount: mint.publicKey, to: wallet_ata, }).rpc(); @@ -230,9 +255,10 @@ describe("some tests", () => { try { - (await gatewayProgram.methods.withdrawSplToken(new anchor.BN(500_000), Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce) + (await gatewayProgram.methods.withdrawSplToken(usdcDecimals,new anchor.BN(500_000), Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce) .accounts({ - from: pda_ata.address, + pdaAta: pda_ata.address, + mintAccount: mint.publicKey, to: wallet_ata, }).rpc()); throw new Error("Expected error not thrown"); // This line will make the test fail if no error is thrown @@ -244,6 +270,40 @@ describe("some tests", () => { expect(account4.amount).to.be.eq(2_500_000n); } + + try { + const nonce2 = nonce.addn(1) + const buffer = Buffer.concat([ + Buffer.from("withdraw_spl_token","utf-8"), + chain_id_bn.toArrayLike(Buffer, 'be', 8), + nonce2.toArrayLike(Buffer, 'be', 8), + amount.toArrayLike(Buffer, 'be', 8), + mint_fake.publicKey.toBuffer(), + wallet_ata.toBuffer(), + ]); + const message_hash = keccak256(buffer); + const signature = keyPair.sign(message_hash, 'hex'); + const { r, s, recoveryParam } = signature; + const signatureBuffer = Buffer.concat([ + r.toArrayLike(Buffer, 'be', 32), + s.toArrayLike(Buffer, 'be', 32), + ]); + await gatewayProgram.methods.withdrawSplToken(usdcDecimals,amount, Array.from(signatureBuffer), Number(recoveryParam), Array.from(message_hash), nonce2 ) + .accounts({ + pdaAta: pda_ata.address, + mintAccount: mint_fake.publicKey, + to: wallet_ata, + }).rpc(); + throw new Error("Expected error not thrown"); // This line will make the test fail if no error is thrown + } catch (err) { + expect(err).to.be.instanceof(anchor.AnchorError); + console.log("Error message: ", err.message); + expect(err.message).to.include("SPLAtaAndMintAddressMismatch"); + const account4 = await spl.getAccount(conn, pda_ata.address); + console.log("After 2nd withdraw: Account balance:", account4.amount.toString()); + expect(account4.amount).to.be.eq(2_500_000n); + } + }); it("deposit and withdraw 0.5 SOL from Gateway with ECDSA signature", async () => {