From b2803a84c9e2d6f276821b641f8cb8a841a2b54f Mon Sep 17 00:00:00 2001 From: Patrick Aljord Date: Fri, 16 Aug 2024 01:43:40 +0200 Subject: [PATCH] feat: add views for fees and rewards --- src/FoldCaptiveStaking.sol | 35 ++++++--- test/UnitTests.t.sol | 152 ++++++++++++++++++++++++++----------- 2 files changed, 130 insertions(+), 57 deletions(-) diff --git a/src/FoldCaptiveStaking.sol b/src/FoldCaptiveStaking.sol index 6af2cfb..ba95082 100644 --- a/src/FoldCaptiveStaking.sol +++ b/src/FoldCaptiveStaking.sol @@ -214,10 +214,7 @@ contract FoldCaptiveStaking is Owned(msg.sender) { function compound() public isInitialized { collectPositionFees(); - uint256 fee0Owed = (token0FeesPerLiquidity - balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount - / liquidityUnderManagement; - uint256 fee1Owed = (token1FeesPerLiquidity - balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount - / liquidityUnderManagement; + (uint256 fee0Owed, uint256 fee1Owed) = owedFees(); INonfungiblePositionManager.IncreaseLiquidityParams memory params = INonfungiblePositionManager .IncreaseLiquidityParams({ @@ -243,14 +240,23 @@ contract FoldCaptiveStaking is Owned(msg.sender) { emit Compounded(msg.sender, liquidity, fee0Owed, fee1Owed); } + /// @notice User-specific function to view fees owed on the singular position + function owedFees() public view returns (uint256, uint256) { + uint256 fee0Owed = ((token0FeesPerLiquidity - + balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount) / + liquidityUnderManagement; + uint256 fee1Owed = ((token1FeesPerLiquidity - + balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount) / + liquidityUnderManagement; + + return (fee0Owed, fee1Owed); + } + /// @notice User-specific function to collect fees on the singular position function collectFees() public isInitialized { collectPositionFees(); - uint256 fee0Owed = (token0FeesPerLiquidity - balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount - / liquidityUnderManagement; - uint256 fee1Owed = (token1FeesPerLiquidity - balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount - / liquidityUnderManagement; + (uint256 fee0Owed, uint256 fee1Owed) = owedFees(); token0.transfer(msg.sender, fee0Owed); token1.transfer(msg.sender, fee1Owed); @@ -261,10 +267,16 @@ contract FoldCaptiveStaking is Owned(msg.sender) { emit FeesCollected(msg.sender, fee0Owed, fee1Owed); } + /// @notice User-specific function to view rewards owed on the singular position + function owedRewards() public view returns (uint256) { + return + ((rewardsPerLiquidity - balances[msg.sender].rewardDebt) * + balances[msg.sender].amount) / liquidityUnderManagement; + } + /// @notice User-specific Rewards for Protocol Rewards function collectRewards() public isInitialized { - uint256 rewardsOwed = (rewardsPerLiquidity - balances[msg.sender].rewardDebt) * balances[msg.sender].amount - / liquidityUnderManagement; + uint256 rewardsOwed = owedRewards(); WETH9.transfer(msg.sender, rewardsOwed); @@ -373,7 +385,8 @@ contract FoldCaptiveStaking is Owned(msg.sender) { amount1Max: uint128(amount1) }); - (uint256 amount0Collected, uint256 amount1Collected) = positionManager.collect(collectParams); + (uint256 amount0Collected, uint256 amount1Collected) = positionManager + .collect(collectParams); if (amount0Collected != amount0 || amount1Collected != amount1) { revert WithdrawFailed(); diff --git a/test/UnitTests.t.sol b/test/UnitTests.t.sol index 3d8c969..bf942cc 100644 --- a/test/UnitTests.t.sol +++ b/test/UnitTests.t.sol @@ -4,7 +4,8 @@ import "test/BaseCaptiveTest.sol"; import "test/interfaces/ISwapRouter.sol"; contract UnitTests is BaseCaptiveTest { - ISwapRouter public router = ISwapRouter(0xE592427A0AEce92De3Edee1F18E0157C05861564); + ISwapRouter public router = + ISwapRouter(0xE592427A0AEce92De3Edee1F18E0157C05861564); /// @dev Ensure that balances and state variables are updated correctly. function testAddLiquidity() public { @@ -18,8 +19,12 @@ contract UnitTests is BaseCaptiveTest { fold.approve(address(foldCaptiveStaking), type(uint256).max); foldCaptiveStaking.deposit(1_000 ether, 1_000 ether, 0); - (uint128 amount, uint128 rewardDebt, uint128 token0FeeDebt, uint128 token1FeeDebt) = - foldCaptiveStaking.balances(User01); + ( + uint128 amount, + uint128 rewardDebt, + uint128 token0FeeDebt, + uint128 token1FeeDebt + ) = foldCaptiveStaking.balances(User01); assertGt(amount, 0); assertEq(rewardDebt, 0); @@ -31,13 +36,18 @@ contract UnitTests is BaseCaptiveTest { function testRemoveLiquidity() public { testAddLiquidity(); - (uint128 amount, uint128 rewardDebt, uint128 token0FeeDebt, uint128 token1FeeDebt) = - foldCaptiveStaking.balances(User01); + ( + uint128 amount, + uint128 rewardDebt, + uint128 token0FeeDebt, + uint128 token1FeeDebt + ) = foldCaptiveStaking.balances(User01); - (uint128 liq,,,) = foldCaptiveStaking.balances(User01); + (uint128 liq, , , ) = foldCaptiveStaking.balances(User01); foldCaptiveStaking.withdraw(liq / 2); - (amount, rewardDebt, token0FeeDebt, token1FeeDebt) = foldCaptiveStaking.balances(User01); + (amount, rewardDebt, token0FeeDebt, token1FeeDebt) = foldCaptiveStaking + .balances(User01); assertEq(amount, liq / 2); assertEq(rewardDebt, 0); @@ -46,11 +56,47 @@ contract UnitTests is BaseCaptiveTest { foldCaptiveStaking.withdraw(liq / 4); - (amount,,,) = foldCaptiveStaking.balances(User01); + (amount, , , ) = foldCaptiveStaking.balances(User01); assertEq(amount, liq / 4); } + /// @dev Ensure that owed fees are returned correctly. + function testOwedFees() public { + testAddLiquidity(); + uint256 owedReards = foldCaptiveStaking.owedRewards(); + (uint256 amount, uint256 rewardDebt, , ) = foldCaptiveStaking.balances( + User01 + ); + uint256 rewardOwedCheck = ((foldCaptiveStaking.rewardsPerLiquidity() - + rewardDebt) * amount) / + foldCaptiveStaking.liquidityUnderManagement(); + + assertEq(rewardOwedCheck, owedReards); + } + + /// @dev Ensure that owed fees are returned correctly. + function testOwedRewards() public { + testAddLiquidity(); + (uint256 fee0Owed, uint256 fee1Owed) = foldCaptiveStaking.owedFees(); + ( + uint256 amount, + , + uint256 token0FeeDebt, + uint256 token1FeeDebt + ) = foldCaptiveStaking.balances(User01); + uint256 fee0OwedCheck = ((foldCaptiveStaking.token0FeesPerLiquidity() - + token0FeeDebt) * amount) / + foldCaptiveStaking.liquidityUnderManagement(); + + uint256 fee1OwedCheck = ((foldCaptiveStaking.token1FeesPerLiquidity() - + token1FeeDebt) * amount) / + foldCaptiveStaking.liquidityUnderManagement(); + + assertEq(fee0OwedCheck, fee0Owed); + assertEq(fee1OwedCheck, fee1Owed); + } + /// @dev Ensure fees are accrued correctly and distributed proportionately. function testFeesAccrue() public { testAddLiquidity(); @@ -59,16 +105,17 @@ contract UnitTests is BaseCaptiveTest { weth.deposit{value: 10 ether}(); weth.approve(address(router), type(uint256).max); - ISwapRouter.ExactInputSingleParams memory params = ISwapRouter.ExactInputSingleParams({ - tokenIn: address(weth), - tokenOut: address(fold), - fee: 10_000, - recipient: msg.sender, - deadline: block.timestamp, - amountIn: 10 ether, - amountOutMinimum: 0, - sqrtPriceLimitX96: 0 - }); + ISwapRouter.ExactInputSingleParams memory params = ISwapRouter + .ExactInputSingleParams({ + tokenIn: address(weth), + tokenOut: address(fold), + fee: 10_000, + recipient: msg.sender, + deadline: block.timestamp, + amountIn: 10 ether, + amountOutMinimum: 0, + sqrtPriceLimitX96: 0 + }); // The call to `exactInputSingle` executes the swap. uint256 amountOut = router.exactInputSingle(params); @@ -108,16 +155,17 @@ contract UnitTests is BaseCaptiveTest { weth.deposit{value: 10 ether}(); weth.approve(address(router), type(uint256).max); - ISwapRouter.ExactInputSingleParams memory params = ISwapRouter.ExactInputSingleParams({ - tokenIn: address(weth), - tokenOut: address(fold), - fee: 10_000, - recipient: msg.sender, - deadline: block.timestamp, - amountIn: 10 ether, - amountOutMinimum: 0, - sqrtPriceLimitX96: 0 - }); + ISwapRouter.ExactInputSingleParams memory params = ISwapRouter + .ExactInputSingleParams({ + tokenIn: address(weth), + tokenOut: address(fold), + fee: 10_000, + recipient: msg.sender, + deadline: block.timestamp, + amountIn: 10 ether, + amountOutMinimum: 0, + sqrtPriceLimitX96: 0 + }); // The call to `exactInputSingle` executes the swap. uint256 amountOut = router.exactInputSingle(params); @@ -139,11 +187,11 @@ contract UnitTests is BaseCaptiveTest { // The call to `exactInputSingle` executes the swap. amountOut = router.exactInputSingle(params); - (uint128 amount,,,) = foldCaptiveStaking.balances(User01); + (uint128 amount, , , ) = foldCaptiveStaking.balances(User01); foldCaptiveStaking.compound(); - (uint128 newAmount,,,) = foldCaptiveStaking.balances(User01); + (uint128 newAmount, , , ) = foldCaptiveStaking.balances(User01); assertGt(newAmount, amount); } @@ -167,15 +215,20 @@ contract UnitTests is BaseCaptiveTest { foldCaptiveStaking.deposit(10 ether, 10 ether, 0); - (,, uint128 token0FeeDebt, uint128 token1FeeDebt) = foldCaptiveStaking.balances(User02); + (, , uint128 token0FeeDebt, uint128 token1FeeDebt) = foldCaptiveStaking + .balances(User02); assertEq(token0FeeDebt, foldCaptiveStaking.token0FeesPerLiquidity()); assertEq(token1FeeDebt, foldCaptiveStaking.token1FeesPerLiquidity()); } function testCannotCallbeforeInit() public { - FoldCaptiveStaking stakingTwo = - new FoldCaptiveStaking(address(positionManager), address(pool), address(weth), address(fold)); + FoldCaptiveStaking stakingTwo = new FoldCaptiveStaking( + address(positionManager), + address(pool), + address(weth), + address(fold) + ); vm.expectRevert(NotInitialized.selector); stakingTwo.deposit(0, 0, 0); @@ -203,22 +256,25 @@ contract UnitTests is BaseCaptiveTest { vm.deal(User01, 1000 ether); uint256 initialGlobalRewards = foldCaptiveStaking.rewardsPerLiquidity(); - (, uint256 rewardDebt,,) = foldCaptiveStaking.balances(User01); + (, uint256 rewardDebt, , ) = foldCaptiveStaking.balances(User01); assertEq(rewardDebt, 0); foldCaptiveStaking.depositRewards{value: 1000 ether}(); assertEq(foldCaptiveStaking.rewardsPerLiquidity(), 1000 ether); - assertGt(foldCaptiveStaking.rewardsPerLiquidity(), initialGlobalRewards); + assertGt( + foldCaptiveStaking.rewardsPerLiquidity(), + initialGlobalRewards + ); uint256 initialBalance = weth.balanceOf(User01); foldCaptiveStaking.collectRewards(); - (, rewardDebt,,) = foldCaptiveStaking.balances(User01); + (, rewardDebt, , ) = foldCaptiveStaking.balances(User01); assertEq(rewardDebt, foldCaptiveStaking.rewardsPerLiquidity()); assertGt(weth.balanceOf(User01), initialBalance); - (uint128 liq,,,) = foldCaptiveStaking.balances(User01); + (uint128 liq, , , ) = foldCaptiveStaking.balances(User01); foldCaptiveStaking.withdraw(liq / 3); } @@ -227,7 +283,9 @@ contract UnitTests is BaseCaptiveTest { testAddLiquidity(); // Owner claims insurance - uint128 liquidityToClaim = uint128(foldCaptiveStaking.liquidityUnderManagement() / 4); + uint128 liquidityToClaim = uint128( + foldCaptiveStaking.liquidityUnderManagement() / 4 + ); address owner = foldCaptiveStaking.owner(); vm.startPrank(owner); @@ -246,7 +304,7 @@ contract UnitTests is BaseCaptiveTest { function testProRataWithdrawals() public { testAddLiquidity(); - (uint128 liq,,,) = foldCaptiveStaking.balances(User01); + (uint128 liq, , , ) = foldCaptiveStaking.balances(User01); // Attempt to withdraw more than allowed amount vm.expectRevert(WithdrawProRata.selector); @@ -254,7 +312,7 @@ contract UnitTests is BaseCaptiveTest { // Pro-rated withdrawal foldCaptiveStaking.withdraw(liq / 2); - (uint128 amount,,,) = foldCaptiveStaking.balances(User01); + (uint128 amount, , , ) = foldCaptiveStaking.balances(User01); assertEq(amount, liq / 2); } @@ -262,7 +320,7 @@ contract UnitTests is BaseCaptiveTest { function testZeroDeposit() public { vm.expectRevert(); foldCaptiveStaking.deposit(0, 0, 0); - (uint128 amount,,,) = foldCaptiveStaking.balances(User01); + (uint128 amount, , , ) = foldCaptiveStaking.balances(User01); assertEq(amount, 0); } @@ -271,7 +329,9 @@ contract UnitTests is BaseCaptiveTest { testAddLiquidity(); // Create a reentrancy attack contract and attempt to exploit the staking contract - ReentrancyAttack attack = new ReentrancyAttack(payable(address(foldCaptiveStaking))); + ReentrancyAttack attack = new ReentrancyAttack( + payable(address(foldCaptiveStaking)) + ); fold.transfer(address(attack), 1 ether); weth.transfer(address(attack), 1 ether); @@ -334,10 +394,10 @@ contract UnitTests is BaseCaptiveTest { // User 1 withdraws vm.startPrank(User01); - (uint128 liq,,,) = foldCaptiveStaking.balances(User01); + (uint128 liq, , , ) = foldCaptiveStaking.balances(User01); foldCaptiveStaking.withdraw(liq / 2); - (uint128 amount,,,) = foldCaptiveStaking.balances(User01); + (uint128 amount, , , ) = foldCaptiveStaking.balances(User01); assertEq(amount, liq / 2); vm.stopPrank(); @@ -345,10 +405,10 @@ contract UnitTests is BaseCaptiveTest { // User 2 withdraws vm.startPrank(User02); - (liq,,,) = foldCaptiveStaking.balances(User02); + (liq, , , ) = foldCaptiveStaking.balances(User02); foldCaptiveStaking.withdraw(liq / 2); - (amount,,,) = foldCaptiveStaking.balances(User02); + (amount, , , ) = foldCaptiveStaking.balances(User02); assertEq(amount, liq / 2); vm.stopPrank();