Skip to content

Commit

Permalink
feat: add views for fees and rewards
Browse files Browse the repository at this point in the history
  • Loading branch information
patcito committed Aug 15, 2024
1 parent 76e7fb4 commit b2803a8
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 57 deletions.
35 changes: 24 additions & 11 deletions src/FoldCaptiveStaking.sol
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,7 @@ contract FoldCaptiveStaking is Owned(msg.sender) {
function compound() public isInitialized {
collectPositionFees();

uint256 fee0Owed = (token0FeesPerLiquidity - balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount
/ liquidityUnderManagement;
uint256 fee1Owed = (token1FeesPerLiquidity - balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount
/ liquidityUnderManagement;
(uint256 fee0Owed, uint256 fee1Owed) = owedFees();

INonfungiblePositionManager.IncreaseLiquidityParams memory params = INonfungiblePositionManager
.IncreaseLiquidityParams({
Expand All @@ -243,14 +240,23 @@ contract FoldCaptiveStaking is Owned(msg.sender) {
emit Compounded(msg.sender, liquidity, fee0Owed, fee1Owed);
}

/// @notice User-specific function to view fees owed on the singular position
function owedFees() public view returns (uint256, uint256) {
uint256 fee0Owed = ((token0FeesPerLiquidity -
balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount) /
liquidityUnderManagement;
uint256 fee1Owed = ((token1FeesPerLiquidity -
balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount) /
liquidityUnderManagement;

return (fee0Owed, fee1Owed);
}

/// @notice User-specific function to collect fees on the singular position
function collectFees() public isInitialized {
collectPositionFees();

uint256 fee0Owed = (token0FeesPerLiquidity - balances[msg.sender].token0FeeDebt) * balances[msg.sender].amount
/ liquidityUnderManagement;
uint256 fee1Owed = (token1FeesPerLiquidity - balances[msg.sender].token1FeeDebt) * balances[msg.sender].amount
/ liquidityUnderManagement;
(uint256 fee0Owed, uint256 fee1Owed) = owedFees();

token0.transfer(msg.sender, fee0Owed);
token1.transfer(msg.sender, fee1Owed);
Expand All @@ -261,10 +267,16 @@ contract FoldCaptiveStaking is Owned(msg.sender) {
emit FeesCollected(msg.sender, fee0Owed, fee1Owed);
}

/// @notice User-specific function to view rewards owed on the singular position
function owedRewards() public view returns (uint256) {
return
((rewardsPerLiquidity - balances[msg.sender].rewardDebt) *
balances[msg.sender].amount) / liquidityUnderManagement;
}

/// @notice User-specific Rewards for Protocol Rewards
function collectRewards() public isInitialized {
uint256 rewardsOwed = (rewardsPerLiquidity - balances[msg.sender].rewardDebt) * balances[msg.sender].amount
/ liquidityUnderManagement;
uint256 rewardsOwed = owedRewards();

WETH9.transfer(msg.sender, rewardsOwed);

Expand Down Expand Up @@ -373,7 +385,8 @@ contract FoldCaptiveStaking is Owned(msg.sender) {
amount1Max: uint128(amount1)
});

(uint256 amount0Collected, uint256 amount1Collected) = positionManager.collect(collectParams);
(uint256 amount0Collected, uint256 amount1Collected) = positionManager
.collect(collectParams);

if (amount0Collected != amount0 || amount1Collected != amount1) {
revert WithdrawFailed();
Expand Down
152 changes: 106 additions & 46 deletions test/UnitTests.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import "test/BaseCaptiveTest.sol";
import "test/interfaces/ISwapRouter.sol";

contract UnitTests is BaseCaptiveTest {
ISwapRouter public router = ISwapRouter(0xE592427A0AEce92De3Edee1F18E0157C05861564);
ISwapRouter public router =
ISwapRouter(0xE592427A0AEce92De3Edee1F18E0157C05861564);

/// @dev Ensure that balances and state variables are updated correctly.
function testAddLiquidity() public {
Expand All @@ -18,8 +19,12 @@ contract UnitTests is BaseCaptiveTest {
fold.approve(address(foldCaptiveStaking), type(uint256).max);

foldCaptiveStaking.deposit(1_000 ether, 1_000 ether, 0);
(uint128 amount, uint128 rewardDebt, uint128 token0FeeDebt, uint128 token1FeeDebt) =
foldCaptiveStaking.balances(User01);
(
uint128 amount,
uint128 rewardDebt,
uint128 token0FeeDebt,
uint128 token1FeeDebt
) = foldCaptiveStaking.balances(User01);

assertGt(amount, 0);
assertEq(rewardDebt, 0);
Expand All @@ -31,13 +36,18 @@ contract UnitTests is BaseCaptiveTest {
function testRemoveLiquidity() public {
testAddLiquidity();

(uint128 amount, uint128 rewardDebt, uint128 token0FeeDebt, uint128 token1FeeDebt) =
foldCaptiveStaking.balances(User01);
(
uint128 amount,
uint128 rewardDebt,
uint128 token0FeeDebt,
uint128 token1FeeDebt
) = foldCaptiveStaking.balances(User01);

(uint128 liq,,,) = foldCaptiveStaking.balances(User01);
(uint128 liq, , , ) = foldCaptiveStaking.balances(User01);
foldCaptiveStaking.withdraw(liq / 2);

(amount, rewardDebt, token0FeeDebt, token1FeeDebt) = foldCaptiveStaking.balances(User01);
(amount, rewardDebt, token0FeeDebt, token1FeeDebt) = foldCaptiveStaking
.balances(User01);

assertEq(amount, liq / 2);
assertEq(rewardDebt, 0);
Expand All @@ -46,11 +56,47 @@ contract UnitTests is BaseCaptiveTest {

foldCaptiveStaking.withdraw(liq / 4);

(amount,,,) = foldCaptiveStaking.balances(User01);
(amount, , , ) = foldCaptiveStaking.balances(User01);

assertEq(amount, liq / 4);
}

/// @dev Ensure that owed fees are returned correctly.
function testOwedFees() public {
testAddLiquidity();
uint256 owedReards = foldCaptiveStaking.owedRewards();
(uint256 amount, uint256 rewardDebt, , ) = foldCaptiveStaking.balances(
User01
);
uint256 rewardOwedCheck = ((foldCaptiveStaking.rewardsPerLiquidity() -
rewardDebt) * amount) /
foldCaptiveStaking.liquidityUnderManagement();

assertEq(rewardOwedCheck, owedReards);
}

/// @dev Ensure that owed fees are returned correctly.
function testOwedRewards() public {
testAddLiquidity();
(uint256 fee0Owed, uint256 fee1Owed) = foldCaptiveStaking.owedFees();
(
uint256 amount,
,
uint256 token0FeeDebt,
uint256 token1FeeDebt
) = foldCaptiveStaking.balances(User01);
uint256 fee0OwedCheck = ((foldCaptiveStaking.token0FeesPerLiquidity() -
token0FeeDebt) * amount) /
foldCaptiveStaking.liquidityUnderManagement();

uint256 fee1OwedCheck = ((foldCaptiveStaking.token1FeesPerLiquidity() -
token1FeeDebt) * amount) /
foldCaptiveStaking.liquidityUnderManagement();

assertEq(fee0OwedCheck, fee0Owed);
assertEq(fee1OwedCheck, fee1Owed);
}

/// @dev Ensure fees are accrued correctly and distributed proportionately.
function testFeesAccrue() public {
testAddLiquidity();
Expand All @@ -59,16 +105,17 @@ contract UnitTests is BaseCaptiveTest {
weth.deposit{value: 10 ether}();
weth.approve(address(router), type(uint256).max);

ISwapRouter.ExactInputSingleParams memory params = ISwapRouter.ExactInputSingleParams({
tokenIn: address(weth),
tokenOut: address(fold),
fee: 10_000,
recipient: msg.sender,
deadline: block.timestamp,
amountIn: 10 ether,
amountOutMinimum: 0,
sqrtPriceLimitX96: 0
});
ISwapRouter.ExactInputSingleParams memory params = ISwapRouter
.ExactInputSingleParams({
tokenIn: address(weth),
tokenOut: address(fold),
fee: 10_000,
recipient: msg.sender,
deadline: block.timestamp,
amountIn: 10 ether,
amountOutMinimum: 0,
sqrtPriceLimitX96: 0
});

// The call to `exactInputSingle` executes the swap.
uint256 amountOut = router.exactInputSingle(params);
Expand Down Expand Up @@ -108,16 +155,17 @@ contract UnitTests is BaseCaptiveTest {
weth.deposit{value: 10 ether}();
weth.approve(address(router), type(uint256).max);

ISwapRouter.ExactInputSingleParams memory params = ISwapRouter.ExactInputSingleParams({
tokenIn: address(weth),
tokenOut: address(fold),
fee: 10_000,
recipient: msg.sender,
deadline: block.timestamp,
amountIn: 10 ether,
amountOutMinimum: 0,
sqrtPriceLimitX96: 0
});
ISwapRouter.ExactInputSingleParams memory params = ISwapRouter
.ExactInputSingleParams({
tokenIn: address(weth),
tokenOut: address(fold),
fee: 10_000,
recipient: msg.sender,
deadline: block.timestamp,
amountIn: 10 ether,
amountOutMinimum: 0,
sqrtPriceLimitX96: 0
});

// The call to `exactInputSingle` executes the swap.
uint256 amountOut = router.exactInputSingle(params);
Expand All @@ -139,11 +187,11 @@ contract UnitTests is BaseCaptiveTest {
// The call to `exactInputSingle` executes the swap.
amountOut = router.exactInputSingle(params);

(uint128 amount,,,) = foldCaptiveStaking.balances(User01);
(uint128 amount, , , ) = foldCaptiveStaking.balances(User01);

foldCaptiveStaking.compound();

(uint128 newAmount,,,) = foldCaptiveStaking.balances(User01);
(uint128 newAmount, , , ) = foldCaptiveStaking.balances(User01);

assertGt(newAmount, amount);
}
Expand All @@ -167,15 +215,20 @@ contract UnitTests is BaseCaptiveTest {

foldCaptiveStaking.deposit(10 ether, 10 ether, 0);

(,, uint128 token0FeeDebt, uint128 token1FeeDebt) = foldCaptiveStaking.balances(User02);
(, , uint128 token0FeeDebt, uint128 token1FeeDebt) = foldCaptiveStaking
.balances(User02);

assertEq(token0FeeDebt, foldCaptiveStaking.token0FeesPerLiquidity());
assertEq(token1FeeDebt, foldCaptiveStaking.token1FeesPerLiquidity());
}

function testCannotCallbeforeInit() public {
FoldCaptiveStaking stakingTwo =
new FoldCaptiveStaking(address(positionManager), address(pool), address(weth), address(fold));
FoldCaptiveStaking stakingTwo = new FoldCaptiveStaking(
address(positionManager),
address(pool),
address(weth),
address(fold)
);

vm.expectRevert(NotInitialized.selector);
stakingTwo.deposit(0, 0, 0);
Expand Down Expand Up @@ -203,22 +256,25 @@ contract UnitTests is BaseCaptiveTest {
vm.deal(User01, 1000 ether);

uint256 initialGlobalRewards = foldCaptiveStaking.rewardsPerLiquidity();
(, uint256 rewardDebt,,) = foldCaptiveStaking.balances(User01);
(, uint256 rewardDebt, , ) = foldCaptiveStaking.balances(User01);
assertEq(rewardDebt, 0);

foldCaptiveStaking.depositRewards{value: 1000 ether}();

assertEq(foldCaptiveStaking.rewardsPerLiquidity(), 1000 ether);
assertGt(foldCaptiveStaking.rewardsPerLiquidity(), initialGlobalRewards);
assertGt(
foldCaptiveStaking.rewardsPerLiquidity(),
initialGlobalRewards
);

uint256 initialBalance = weth.balanceOf(User01);
foldCaptiveStaking.collectRewards();

(, rewardDebt,,) = foldCaptiveStaking.balances(User01);
(, rewardDebt, , ) = foldCaptiveStaking.balances(User01);
assertEq(rewardDebt, foldCaptiveStaking.rewardsPerLiquidity());
assertGt(weth.balanceOf(User01), initialBalance);

(uint128 liq,,,) = foldCaptiveStaking.balances(User01);
(uint128 liq, , , ) = foldCaptiveStaking.balances(User01);
foldCaptiveStaking.withdraw(liq / 3);
}

Expand All @@ -227,7 +283,9 @@ contract UnitTests is BaseCaptiveTest {
testAddLiquidity();

// Owner claims insurance
uint128 liquidityToClaim = uint128(foldCaptiveStaking.liquidityUnderManagement() / 4);
uint128 liquidityToClaim = uint128(
foldCaptiveStaking.liquidityUnderManagement() / 4
);

address owner = foldCaptiveStaking.owner();
vm.startPrank(owner);
Expand All @@ -246,23 +304,23 @@ contract UnitTests is BaseCaptiveTest {
function testProRataWithdrawals() public {
testAddLiquidity();

(uint128 liq,,,) = foldCaptiveStaking.balances(User01);
(uint128 liq, , , ) = foldCaptiveStaking.balances(User01);

// Attempt to withdraw more than allowed amount
vm.expectRevert(WithdrawProRata.selector);
foldCaptiveStaking.withdraw(liq);

// Pro-rated withdrawal
foldCaptiveStaking.withdraw(liq / 2);
(uint128 amount,,,) = foldCaptiveStaking.balances(User01);
(uint128 amount, , , ) = foldCaptiveStaking.balances(User01);
assertEq(amount, liq / 2);
}

/// @dev Ensure zero deposits are handled correctly and revert as expected.
function testZeroDeposit() public {
vm.expectRevert();
foldCaptiveStaking.deposit(0, 0, 0);
(uint128 amount,,,) = foldCaptiveStaking.balances(User01);
(uint128 amount, , , ) = foldCaptiveStaking.balances(User01);
assertEq(amount, 0);
}

Expand All @@ -271,7 +329,9 @@ contract UnitTests is BaseCaptiveTest {
testAddLiquidity();

// Create a reentrancy attack contract and attempt to exploit the staking contract
ReentrancyAttack attack = new ReentrancyAttack(payable(address(foldCaptiveStaking)));
ReentrancyAttack attack = new ReentrancyAttack(
payable(address(foldCaptiveStaking))
);
fold.transfer(address(attack), 1 ether);
weth.transfer(address(attack), 1 ether);

Expand Down Expand Up @@ -334,21 +394,21 @@ contract UnitTests is BaseCaptiveTest {
// User 1 withdraws
vm.startPrank(User01);

(uint128 liq,,,) = foldCaptiveStaking.balances(User01);
(uint128 liq, , , ) = foldCaptiveStaking.balances(User01);
foldCaptiveStaking.withdraw(liq / 2);

(uint128 amount,,,) = foldCaptiveStaking.balances(User01);
(uint128 amount, , , ) = foldCaptiveStaking.balances(User01);
assertEq(amount, liq / 2);

vm.stopPrank();

// User 2 withdraws
vm.startPrank(User02);

(liq,,,) = foldCaptiveStaking.balances(User02);
(liq, , , ) = foldCaptiveStaking.balances(User02);
foldCaptiveStaking.withdraw(liq / 2);

(amount,,,) = foldCaptiveStaking.balances(User02);
(amount, , , ) = foldCaptiveStaking.balances(User02);
assertEq(amount, liq / 2);

vm.stopPrank();
Expand Down

0 comments on commit b2803a8

Please sign in to comment.