Skip to content

Commit

Permalink
remove oracle withdraw and allow contract owner to withdraw (#11551)
Browse files Browse the repository at this point in the history
* remove oracle withdraw and allow contract owner to withdraw

* offchain changes

* fix integration tests

* fix integration tests

* Remove amount arg in withdraw and withdraw all the withdrawable amount

* Off-chain changes to remove amount arg in withdraw and withdraw all the withdrawable amount

* address comments

* fix lint issue and small refactor in vrf integration tests

* address comment

---------

Co-authored-by: Sri Kidambi <[email protected]>
  • Loading branch information
jinhoonbang and kidambisrinivas authored Jan 2, 2024
1 parent 6b740c5 commit c8eaac7
Show file tree
Hide file tree
Showing 22 changed files with 268 additions and 321 deletions.
22 changes: 12 additions & 10 deletions contracts/src/v0.8/vrf/dev/SubscriptionAPI.sol
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ abstract contract SubscriptionAPI is ConfirmedOwner, IERC677Receiver, IVRFSubscr
// A discrepancy with this contract's native balance indicates someone
// sent native using transfer and so we may need to use recoverNativeFunds.
uint96 public s_totalNativeBalance;
mapping(address => uint96) /* oracle */ /* LINK balance */ internal s_withdrawableTokens;
mapping(address => uint96) /* oracle */ /* native balance */ internal s_withdrawableNative;
uint96 internal s_withdrawableTokens;
uint96 internal s_withdrawableNative;

event SubscriptionCreated(uint256 indexed subId, address owner);
event SubscriptionFunded(uint256 indexed subId, uint256 oldBalance, uint256 newBalance);
Expand Down Expand Up @@ -204,35 +204,37 @@ abstract contract SubscriptionAPI is ConfirmedOwner, IERC677Receiver, IVRFSubscr
}

/*
* @notice Oracle withdraw LINK earned through fulfilling requests
* @notice withdraw LINK earned through fulfilling requests
* @param recipient where to send the funds
* @param amount amount to withdraw
*/
function oracleWithdraw(address recipient, uint96 amount) external nonReentrant {
function withdraw(address recipient) external nonReentrant onlyOwner {
if (address(LINK) == address(0)) {
revert LinkNotSet();
}
if (s_withdrawableTokens[msg.sender] < amount) {
if (s_withdrawableTokens == 0) {
revert InsufficientBalance();
}
s_withdrawableTokens[msg.sender] -= amount;
uint96 amount = s_withdrawableTokens;
s_withdrawableTokens -= amount;
s_totalBalance -= amount;
if (!LINK.transfer(recipient, amount)) {
revert InsufficientBalance();
}
}

/*
* @notice Oracle withdraw native earned through fulfilling requests
* @notice withdraw native earned through fulfilling requests
* @param recipient where to send the funds
* @param amount amount to withdraw
*/
function oracleWithdrawNative(address payable recipient, uint96 amount) external nonReentrant {
if (s_withdrawableNative[msg.sender] < amount) {
function withdrawNative(address payable recipient) external nonReentrant onlyOwner {
if (s_withdrawableNative == 0) {
revert InsufficientBalance();
}
// Prevent re-entrancy by updating state before transfer.
s_withdrawableNative[msg.sender] -= amount;
uint96 amount = s_withdrawableNative;
s_withdrawableNative -= amount;
s_totalNativeBalance -= amount;
(bool sent, ) = recipient.call{value: amount}("");
if (!sent) {
Expand Down
33 changes: 15 additions & 18 deletions contracts/src/v0.8/vrf/dev/VRFCoordinatorV2_5.sol
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
address sender;
bytes extraArgs;
}
mapping(bytes32 => address) /* keyHash */ /* oracle */ public s_provingKeys;
mapping(bytes32 => bool) /* keyHash */ /* exists */ public s_provingKeys;
bytes32[] public s_provingKeyHashes;
mapping(uint256 => bytes32) /* requestID */ /* commitment */ public s_requestCommitments;
event ProvingKeyRegistered(bytes32 keyHash, address indexed oracle);
event ProvingKeyDeregistered(bytes32 keyHash, address indexed oracle);
event ProvingKeyRegistered(bytes32 keyHash);
event ProvingKeyDeregistered(bytes32 keyHash);
event RandomWordsRequested(
bytes32 indexed keyHash,
uint256 requestId,
Expand Down Expand Up @@ -94,28 +94,26 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
}

/**
* @notice Registers a proving key to an oracle.
* @param oracle address of the oracle
* @notice Registers a proving key to.
* @param publicProvingKey key that oracle can use to submit vrf fulfillments
*/
function registerProvingKey(address oracle, uint256[2] calldata publicProvingKey) external onlyOwner {
function registerProvingKey(uint256[2] calldata publicProvingKey) external onlyOwner {
bytes32 kh = hashOfKey(publicProvingKey);
if (s_provingKeys[kh] != address(0)) {
if (s_provingKeys[kh]) {
revert ProvingKeyAlreadyRegistered(kh);
}
s_provingKeys[kh] = oracle;
s_provingKeys[kh] = true;
s_provingKeyHashes.push(kh);
emit ProvingKeyRegistered(kh, oracle);
emit ProvingKeyRegistered(kh);
}

/**
* @notice Deregisters a proving key to an oracle.
* @notice Deregisters a proving key.
* @param publicProvingKey key that oracle can use to submit vrf fulfillments
*/
function deregisterProvingKey(uint256[2] calldata publicProvingKey) external onlyOwner {
bytes32 kh = hashOfKey(publicProvingKey);
address oracle = s_provingKeys[kh];
if (oracle == address(0)) {
if (!s_provingKeys[kh]) {
revert NoSuchProvingKey(kh);
}
delete s_provingKeys[kh];
Expand All @@ -127,7 +125,7 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
s_provingKeyHashes.pop();
}
}
emit ProvingKeyDeregistered(kh, oracle);
emit ProvingKeyDeregistered(kh);
}

/**
Expand Down Expand Up @@ -355,8 +353,7 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
) internal view returns (Output memory) {
bytes32 keyHash = hashOfKey(proof.pk);
// Only registered proving keys are permitted.
address oracle = s_provingKeys[keyHash];
if (oracle == address(0)) {
if (!s_provingKeys[keyHash]) {
revert NoSuchProvingKey(keyHash);
}
uint256 requestId = uint256(keccak256(abi.encode(keyHash, proof.seed)));
Expand Down Expand Up @@ -423,7 +420,7 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
bool nativePayment = uint8(rc.extraArgs[rc.extraArgs.length - 1]) == 1;
// We want to charge users exactly for how much gas they use in their callback.
// The gasAfterPaymentCalculation is meant to cover these additional operations where we
// decrement the subscription balance and increment the oracles withdrawable balance.
// decrement the subscription balance and increment the withdrawable balance.
uint96 payment = _calculatePaymentAmount(
startGas,
s_config.gasAfterPaymentCalculation,
Expand All @@ -435,13 +432,13 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
revert InsufficientBalance();
}
s_subscriptions[rc.subId].nativeBalance -= payment;
s_withdrawableNative[s_provingKeys[output.keyHash]] += payment;
s_withdrawableNative += payment;
} else {
if (s_subscriptions[rc.subId].balance < payment) {
revert InsufficientBalance();
}
s_subscriptions[rc.subId].balance -= payment;
s_withdrawableTokens[s_provingKeys[output.keyHash]] += payment;
s_withdrawableTokens += payment;
}

// Include payment in the event for tracking costs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ contract ExposedVRFCoordinatorV2_5 is VRFCoordinatorV2_5 {
s_totalNativeBalance = newBalance;
}

function setWithdrawableTokensTestingOnlyXXX(address oracle, uint96 newBalance) external {
s_withdrawableTokens[oracle] = newBalance;
function setWithdrawableTokensTestingOnlyXXX(uint96 newBalance) external {
s_withdrawableTokens = newBalance;
}

function getWithdrawableTokensTestingOnlyXXX(address oracle) external view returns (uint96) {
return s_withdrawableTokens[oracle];
function getWithdrawableTokensTestingOnlyXXX() external view returns (uint96) {
return s_withdrawableTokens;
}

function setWithdrawableNativeTestingOnlyXXX(address oracle, uint96 newBalance) external {
s_withdrawableNative[oracle] = newBalance;
function setWithdrawableNativeTestingOnlyXXX(uint96 newBalance) external {
s_withdrawableNative = newBalance;
}

function getWithdrawableNativeTestingOnlyXXX(address oracle) external view returns (uint96) {
return s_withdrawableNative[oracle];
function getWithdrawableNativeTestingOnlyXXX() external view returns (uint96) {
return s_withdrawableNative;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ contract VRFCoordinatorV2PlusUpgradedVersion is
bytes extraArgs;
}

mapping(bytes32 => address) /* keyHash */ /* oracle */ internal s_provingKeys;
mapping(bytes32 => bool) /* keyHash */ /* exists */ internal s_provingKeys;
bytes32[] public s_provingKeyHashes;
mapping(uint256 => bytes32) /* requestID */ /* commitment */ public s_requestCommitments;

event ProvingKeyRegistered(bytes32 keyHash, address indexed oracle);
event ProvingKeyRegistered(bytes32 keyHash);
event RandomWordsRequested(
bytes32 indexed keyHash,
uint256 requestId,
Expand Down Expand Up @@ -108,17 +108,16 @@ contract VRFCoordinatorV2PlusUpgradedVersion is

/**
* @notice Registers a proving key to an oracle.
* @param oracle address of the oracle
* @param publicProvingKey key that oracle can use to submit vrf fulfillments
*/
function registerProvingKey(address oracle, uint256[2] calldata publicProvingKey) external onlyOwner {
function registerProvingKey(uint256[2] calldata publicProvingKey) external onlyOwner {
bytes32 kh = hashOfKey(publicProvingKey);
if (s_provingKeys[kh] != address(0)) {
if (s_provingKeys[kh]) {
revert ProvingKeyAlreadyRegistered(kh);
}
s_provingKeys[kh] = oracle;
s_provingKeys[kh] = true;
s_provingKeyHashes.push(kh);
emit ProvingKeyRegistered(kh, oracle);
emit ProvingKeyRegistered(kh);
}

/**
Expand Down Expand Up @@ -346,8 +345,7 @@ contract VRFCoordinatorV2PlusUpgradedVersion is
) internal view returns (Output memory) {
bytes32 keyHash = hashOfKey(proof.pk);
// Only registered proving keys are permitted.
address oracle = s_provingKeys[keyHash];
if (oracle == address(0)) {
if (!s_provingKeys[keyHash]) {
revert NoSuchProvingKey(keyHash);
}
uint256 requestId = uint256(keccak256(abi.encode(keyHash, proof.seed)));
Expand Down Expand Up @@ -426,13 +424,13 @@ contract VRFCoordinatorV2PlusUpgradedVersion is
revert InsufficientBalance();
}
s_subscriptions[rc.subId].nativeBalance -= payment;
s_withdrawableNative[s_provingKeys[output.keyHash]] += payment;
s_withdrawableNative += payment;
} else {
if (s_subscriptions[rc.subId].balance < payment) {
revert InsufficientBalance();
}
s_subscriptions[rc.subId].balance -= payment;
s_withdrawableTokens[s_provingKeys[output.keyHash]] += payment;
s_withdrawableTokens += payment;
}

// Include payment in the event for tracking costs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ contract VRFCoordinatorV2Plus_Migration is BaseTest {

function registerProvingKey() public {
uint256[2] memory uncompressedKeyParts = this.getProvingKeyParts(UNCOMPRESSED_PUBLIC_KEY);
v1Coordinator.registerProvingKey(OWNER, uncompressedKeyParts);
v1Coordinator_noLink.registerProvingKey(OWNER, uncompressedKeyParts);
v1Coordinator.registerProvingKey(uncompressedKeyParts);
v1Coordinator_noLink.registerProvingKey(uncompressedKeyParts);
}

// note: Call this function via this.getProvingKeyParts to be able to pass memory as calldata and
Expand Down
4 changes: 2 additions & 2 deletions contracts/test/v0.8/foundry/vrf/VRFV2Plus.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ contract VRFV2Plus is BaseTest {
// Should revert when already registered.
uint256[2] memory uncompressedKeyParts = this.getProvingKeyParts(vrfUncompressedPublicKey);
vm.expectRevert(abi.encodeWithSelector(VRFCoordinatorV2_5.ProvingKeyAlreadyRegistered.selector, vrfKeyHash));
s_testCoordinator.registerProvingKey(LINK_WHALE, uncompressedKeyParts);
s_testCoordinator.registerProvingKey(uncompressedKeyParts);
}

function registerProvingKey() public {
uint256[2] memory uncompressedKeyParts = this.getProvingKeyParts(vrfUncompressedPublicKey);
s_testCoordinator.registerProvingKey(LINK_WHALE, uncompressedKeyParts);
s_testCoordinator.registerProvingKey(uncompressedKeyParts);
}

// note: Call this function via this.getProvingKeyParts to be able to pass memory as calldata and
Expand Down
74 changes: 44 additions & 30 deletions contracts/test/v0.8/foundry/vrf/VRFV2PlusSubscriptionAPI.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -315,25 +315,25 @@ contract VRFV2PlusSubscriptionAPITest is BaseTest {
assertEq(address(s_subscriptionAPI).balance, s_subscriptionAPI.s_totalNativeBalance());
}

function testOracleWithdrawNoLink() public {
function testWithdrawNoLink() public {
// CASE: no link token set
vm.expectRevert(SubscriptionAPI.LinkNotSet.selector);
s_subscriptionAPI.oracleWithdraw(OWNER, 1 ether);
s_subscriptionAPI.withdraw(OWNER);
}

function testOracleWithdrawInsufficientBalance() public {
function testWithdrawInsufficientBalance() public {
// CASE: link token set, trying to withdraw
// more than balance
MockLinkToken linkToken = new MockLinkToken();
s_subscriptionAPI.setLINKAndLINKNativeFeed(address(linkToken), address(0));
assertEq(address(s_subscriptionAPI.LINK()), address(linkToken));

// call oracleWithdraw
// call withdraw
vm.expectRevert(SubscriptionAPI.InsufficientBalance.selector);
s_subscriptionAPI.oracleWithdraw(OWNER, 1 ether);
s_subscriptionAPI.withdraw(OWNER);
}

function testOracleWithdrawSufficientBalanceLinkSet() public {
function testWithdrawSufficientBalanceLinkSet() public {
// CASE: link token set, trying to withdraw
// less than balance
MockLinkToken linkToken = new MockLinkToken();
Expand All @@ -344,58 +344,72 @@ contract VRFV2PlusSubscriptionAPITest is BaseTest {
bool success = linkToken.transfer(address(s_subscriptionAPI), 10 ether);
assertTrue(success, "failed link transfer");

// set the withdrawable tokens of the oracle to be 1 ether
address oracle = makeAddr("oracle");
s_subscriptionAPI.setWithdrawableTokensTestingOnlyXXX(oracle, 1 ether);
assertEq(s_subscriptionAPI.getWithdrawableTokensTestingOnlyXXX(oracle), 1 ether);
// set the withdrawable tokens of the contract to be 1 ether
s_subscriptionAPI.setWithdrawableTokensTestingOnlyXXX(1 ether);
assertEq(s_subscriptionAPI.getWithdrawableTokensTestingOnlyXXX(), 1 ether);

// set the total balance to be the same as the link balance for consistency
// (this is not necessary for the test, but just to be sane)
s_subscriptionAPI.setTotalBalanceTestingOnlyXXX(10 ether);

// call oracleWithdraw from oracle address
changePrank(oracle);
s_subscriptionAPI.oracleWithdraw(oracle, 1 ether);
// assert link balance of oracle
assertEq(linkToken.balanceOf(oracle), 1 ether, "oracle link balance incorrect");
// call Withdraw from owner address
uint256 ownerBalance = linkToken.balanceOf(OWNER);
changePrank(OWNER);
s_subscriptionAPI.withdraw(OWNER);
// assert link balance of owner
assertEq(linkToken.balanceOf(OWNER) - ownerBalance, 1 ether, "owner link balance incorrect");
// assert state of subscription api
assertEq(s_subscriptionAPI.getWithdrawableTokensTestingOnlyXXX(oracle), 0, "oracle withdrawable tokens incorrect");
assertEq(s_subscriptionAPI.getWithdrawableTokensTestingOnlyXXX(), 0, "owner withdrawable tokens incorrect");
// assert that total balance is changed by the withdrawn amount
assertEq(s_subscriptionAPI.s_totalBalance(), 9 ether, "total balance incorrect");
}

function testOracleWithdrawNativeInsufficientBalance() public {
function testWithdrawNativeInsufficientBalance() public {
// CASE: trying to withdraw more than balance
// should revert with InsufficientBalance

// call oracleWithdrawNative
// call WithdrawNative
changePrank(OWNER);
vm.expectRevert(SubscriptionAPI.InsufficientBalance.selector);
s_subscriptionAPI.oracleWithdrawNative(payable(OWNER), 1 ether);
s_subscriptionAPI.withdrawNative(payable(OWNER));
}

function testWithdrawLinkInvalidOwner() public {
address invalidAddress = makeAddr("invalidAddress");
changePrank(invalidAddress);
vm.expectRevert("Only callable by owner");
s_subscriptionAPI.withdraw(payable(OWNER));
}

function testOracleWithdrawNativeSufficientBalance() public {
function testWithdrawNativeInvalidOwner() public {
address invalidAddress = makeAddr("invalidAddress");
changePrank(invalidAddress);
vm.expectRevert("Only callable by owner");
s_subscriptionAPI.withdrawNative(payable(OWNER));
}

function testWithdrawNativeSufficientBalance() public {
// CASE: trying to withdraw less than balance
// should withdraw successfully

// transfer 10 ether to the contract to withdraw
vm.deal(address(s_subscriptionAPI), 10 ether);

// set the withdrawable eth of the oracle to be 1 ether
address oracle = makeAddr("oracle");
s_subscriptionAPI.setWithdrawableNativeTestingOnlyXXX(oracle, 1 ether);
assertEq(s_subscriptionAPI.getWithdrawableNativeTestingOnlyXXX(oracle), 1 ether);
// set the withdrawable eth of the contract to be 1 ether
s_subscriptionAPI.setWithdrawableNativeTestingOnlyXXX(1 ether);
assertEq(s_subscriptionAPI.getWithdrawableNativeTestingOnlyXXX(), 1 ether);

// set the total balance to be the same as the eth balance for consistency
// (this is not necessary for the test, but just to be sane)
s_subscriptionAPI.setTotalNativeBalanceTestingOnlyXXX(10 ether);

// call oracleWithdrawNative from oracle address
changePrank(oracle);
s_subscriptionAPI.oracleWithdrawNative(payable(oracle), 1 ether);
// assert native balance of oracle
assertEq(address(oracle).balance, 1 ether, "oracle native balance incorrect");
// call WithdrawNative from owner address
changePrank(OWNER);
s_subscriptionAPI.withdrawNative(payable(OWNER));
// assert native balance
assertEq(address(OWNER).balance, 1 ether, "owner native balance incorrect");
// assert state of subscription api
assertEq(s_subscriptionAPI.getWithdrawableNativeTestingOnlyXXX(oracle), 0, "oracle withdrawable native incorrect");
assertEq(s_subscriptionAPI.getWithdrawableNativeTestingOnlyXXX(), 0, "owner withdrawable native incorrect");
// assert that total balance is changed by the withdrawn amount
assertEq(s_subscriptionAPI.s_totalNativeBalance(), 9 ether, "total native balance incorrect");
}
Expand Down
Loading

0 comments on commit c8eaac7

Please sign in to comment.