From bdfc9761cf7839efafdeae3d9c0bda00925f8152 Mon Sep 17 00:00:00 2001 From: brewmaster012 <88689859+brewmaster012@users.noreply.github.com> Date: Tue, 3 Dec 2024 15:22:29 -0600 Subject: [PATCH] fix: use SPL ZRC20 withdraw fee to cover the ATA account creation rent fee (#67) --- programs/protocol-contracts-solana/src/lib.rs | 47 +++++-------- tests/protocol-contracts-solana.ts | 69 +++++++------------ 2 files changed, 39 insertions(+), 77 deletions(-) diff --git a/programs/protocol-contracts-solana/src/lib.rs b/programs/protocol-contracts-solana/src/lib.rs index 6551fab..1c5cc6b 100644 --- a/programs/protocol-contracts-solana/src/lib.rs +++ b/programs/protocol-contracts-solana/src/lib.rs @@ -167,10 +167,6 @@ pub mod gateway { Ok(()) } - pub fn initialize_rent_payer(_ctx: Context) -> Result<()> { - Ok(()) - } - // deposit SOL into this program and the `receiver` on ZetaChain zEVM // will get corresponding ZRC20 credit. // amount: amount of lamports (10^-9 SOL) to deposit @@ -386,6 +382,9 @@ pub mod gateway { Errors::SPLAtaAndMintAddressMismatch, ); + let cost_gas = 5000; // default gas cost in lamports + let cost_ata_create = &mut 0; // will be updated if ATA creation is needed + // test whether the recipient_ata is created or not; if not, create it let recipient_ata_account = ctx.accounts.recipient_ata.to_account_info(); if recipient_ata_account.lamports() == 0 @@ -397,8 +396,7 @@ pub mod gateway { recipient_ata_account.key(), ctx.accounts.recipient.key(), ); - let signer_info = &ctx.accounts.signer.to_account_info(); - let bal_before = signer_info.lamports(); + let bal_before = ctx.accounts.signer.lamports(); invoke( &create_associated_token_account( ctx.accounts.signer.to_account_info().key, @@ -419,22 +417,15 @@ pub mod gateway { .clone(), ], )?; - let bal_after = signer_info.lamports(); + let bal_after = ctx.accounts.signer.lamports(); + *cost_ata_create = bal_before - bal_after; msg!("Associated token account for recipient created!"); msg!( - "Refunding the rent paid by the signer {:?}", + "Refunding the rent ({:?} lamports) paid by the signer {:?}", + cost_ata_create, ctx.accounts.signer.to_account_info().key ); - - let rent_payer_info = ctx.accounts.rent_payer_pda.to_account_info(); - let cost = bal_before - bal_after; - rent_payer_info.sub_lamports(cost)?; - signer_info.add_lamports(cost)?; - msg!( - "Signer refunded the ATA account creation rent amount {:?} lamports", - cost, - ); } let xfer_ctx = CpiContext::new_with_signer( @@ -452,6 +443,14 @@ pub mod gateway { transfer_checked(xfer_ctx, amount, decimals)?; msg!("withdraw spl token successfully"); + // Note: this pda.sub_lamports() must be done here due to this issue https://github.com/solana-labs/solana/issues/9711 + // otherwise the previous CPI calls might fail with error: + // "sum of account balances before and after instruction do not match" + // Note2: to keep PDA from deficit, all SPL ZRC20 contracts needs to charge withdraw fee of + // at least 5000(gas)+2039280(rent) lamports. + let reimbursement = cost_gas + *cost_ata_create; + pda.sub_lamports(reimbursement)?; + ctx.accounts.signer.add_lamports(reimbursement)?; Ok(()) } @@ -589,9 +588,6 @@ pub struct WithdrawSPLToken<'info> { #[account(mut)] pub recipient_ata: AccountInfo<'info>, - #[account(mut, seeds = [b"rent-payer"], bump)] - pub rent_payer_pda: Account<'info, RentPayerPda>, - pub token_program: Program<'info, Token>, pub associated_token_program: Program<'info, AssociatedToken>, @@ -675,17 +671,6 @@ pub struct Unwhitelist<'info> { pub system_program: Program<'info, System>, } -#[derive(Accounts)] -pub struct InitializeRentPayer<'info> { - #[account(init, payer = authority, space = 8, seeds = [b"rent-payer"], bump)] - pub rent_payer_pda: Account<'info, RentPayerPda>, - - #[account(mut)] - pub authority: Signer<'info>, - - pub system_program: Program<'info, System>, -} - #[account] pub struct Pda { nonce: u64, // ensure that each signature can only be used once diff --git a/tests/protocol-contracts-solana.ts b/tests/protocol-contracts-solana.ts index 251bbb9..183e6da 100644 --- a/tests/protocol-contracts-solana.ts +++ b/tests/protocol-contracts-solana.ts @@ -126,12 +126,6 @@ describe("Gateway", () => { gatewayProgram.programId, ); - let rentPayerSeeds = [Buffer.from("rent-payer", "utf-8")]; - let [rentPayerPdaAccount] = anchor.web3.PublicKey.findProgramAddressSync( - rentPayerSeeds, - gatewayProgram.programId, - ); - it("Initializes the program", async () => { await gatewayProgram.methods.initialize(tssAddress, chain_id_bn).rpc(); @@ -143,18 +137,6 @@ describe("Gateway", () => { expect(err).to.be.not.null; } }); - it("initialize the rent payer PDA",async() => { - await gatewayProgram.methods.initializeRentPayer().rpc(); - let instr = web3.SystemProgram.transfer({ - fromPubkey: wallet.publicKey, - toPubkey: rentPayerPdaAccount, - lamports: 100000000, - }); - let tx = new web3.Transaction(); - tx.add(instr); - await web3.sendAndConfirmTransaction(conn,tx,[wallet]); - }); - it("Mint a SPL USDC token", async () => { // now deploying a fake USDC SPL Token @@ -343,33 +325,6 @@ describe("Gateway", () => { }); - it("withdraw SPL token to a non-existent account should succeed by creating it", async () => { - let seeds = [Buffer.from("rent-payer", "utf-8")]; - const [rentPayerPda] = anchor.web3.PublicKey.findProgramAddressSync( - seeds, - gatewayProgram.programId, - ); - let rentPayerPdaBal0 = await conn.getBalance(rentPayerPda); - let pda_ata = await spl.getAssociatedTokenAddress(mint.publicKey, pdaAccount, true); - const pdaAccountData = await gatewayProgram.account.pda.fetch(pdaAccount); - const hexAddr = bufferToHex(Buffer.from(pdaAccountData.tssAddress)); - const amount = new anchor.BN(500_000); - const nonce = pdaAccountData.nonce; - const wallet2 = anchor.web3.Keypair.generate(); - - const to = await spl.getAssociatedTokenAddress(mint.publicKey, wallet2.publicKey); - - let to_ata_bal = await conn.getBalance(to); - expect(to_ata_bal).to.be.eq(0); // the new ata account (owned by wallet2) should be non-existent; - const txsig = await withdrawSplToken(mint, usdcDecimals, amount, nonce, pda_ata, to, wallet2.publicKey, keyPair, gatewayProgram); - to_ata_bal = await conn.getBalance(to); - expect(to_ata_bal).to.be.gt(2_000_000); // the new ata account (owned by wallet2) should be created - - // rent_payer_pda should have reduced balance - - let rentPayerPdaBal1 = await conn.getBalance(rentPayerPda); - expect(rentPayerPdaBal0-rentPayerPdaBal1).to.be.eq(to_ata_bal); // rentPayer pays rent - }); it("fails to deposit if receiver is empty address", async() => { try { @@ -415,7 +370,29 @@ describe("Gateway", () => { let bal3 = await conn.getBalance(to); expect(bal3).to.be.gte(500_000_000); }) - + + it("withdraw SPL token to a non-existent account should succeed by creating it", async () => { + let rentPayerPdaBal0 = await conn.getBalance(pdaAccount); + let pda_ata = await spl.getAssociatedTokenAddress(mint.publicKey, pdaAccount, true); + const pdaAccountData = await gatewayProgram.account.pda.fetch(pdaAccount); + const amount = new anchor.BN(500_000); + const nonce = pdaAccountData.nonce; + const wallet2 = anchor.web3.Keypair.generate(); + const to = await spl.getAssociatedTokenAddress(mint.publicKey, wallet2.publicKey); + + let to_ata_bal = await conn.getBalance(to); + expect(to_ata_bal).to.be.eq(0); // the new ata account (owned by wallet2) should be non-existent; + const txsig = await withdrawSplToken(mint, usdcDecimals, amount, nonce, pda_ata, to, wallet2.publicKey, keyPair, gatewayProgram); + to_ata_bal = await conn.getBalance(to); + expect(to_ata_bal).to.be.gt(2_000_000); // the new ata account (owned by wallet2) should be created + + // pda should have reduced balance + let rentPayerPdaBal1 = await conn.getBalance(pdaAccount); + // expected reimbursement to be gas fee (5000 lamports) + ATA creation cost 2039280 lamports + expect(rentPayerPdaBal0-rentPayerPdaBal1).to.be.eq(to_ata_bal + 5000); // rentPayer pays rent + }); + + it("fails to deposit and call if receiver is empty address", async() => { try { await gatewayProgram.methods.depositAndCall(new anchor.BN(1_000_000_000), Array(20).fill(0), Buffer.from("hello", "utf-8")).accounts({}).rpc();