Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove oracle withdraw and allow contract owner to withdraw #11551

22 changes: 12 additions & 10 deletions contracts/src/v0.8/vrf/dev/SubscriptionAPI.sol
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ abstract contract SubscriptionAPI is ConfirmedOwner, IERC677Receiver, IVRFSubscr
// A discrepancy with this contract's native balance indicates someone
// sent native using transfer and so we may need to use recoverNativeFunds.
uint96 public s_totalNativeBalance;
mapping(address => uint96) /* oracle */ /* LINK balance */ internal s_withdrawableTokens;
mapping(address => uint96) /* oracle */ /* native balance */ internal s_withdrawableNative;
uint96 internal s_withdrawableTokens;
uint96 internal s_withdrawableNative;

event SubscriptionCreated(uint256 indexed subId, address owner);
event SubscriptionFunded(uint256 indexed subId, uint256 oldBalance, uint256 newBalance);
Expand Down Expand Up @@ -204,35 +204,37 @@ abstract contract SubscriptionAPI is ConfirmedOwner, IERC677Receiver, IVRFSubscr
}

/*
* @notice Oracle withdraw LINK earned through fulfilling requests
* @notice withdraw LINK earned through fulfilling requests
* @param recipient where to send the funds
* @param amount amount to withdraw
*/
function oracleWithdraw(address recipient, uint96 amount) external nonReentrant {
function withdraw(address recipient) external nonReentrant onlyOwner {
if (address(LINK) == address(0)) {
revert LinkNotSet();
}
if (s_withdrawableTokens[msg.sender] < amount) {
if (s_withdrawableTokens == 0) {
revert InsufficientBalance();
}
s_withdrawableTokens[msg.sender] -= amount;
uint96 amount = s_withdrawableTokens;
s_withdrawableTokens -= amount;
s_totalBalance -= amount;
if (!LINK.transfer(recipient, amount)) {
revert InsufficientBalance();
}
}

/*
* @notice Oracle withdraw native earned through fulfilling requests
* @notice withdraw native earned through fulfilling requests
* @param recipient where to send the funds
* @param amount amount to withdraw
*/
function oracleWithdrawNative(address payable recipient, uint96 amount) external nonReentrant {
if (s_withdrawableNative[msg.sender] < amount) {
function withdrawNative(address payable recipient) external nonReentrant onlyOwner {
if (s_withdrawableNative == 0) {
revert InsufficientBalance();
}
// Prevent re-entrancy by updating state before transfer.
s_withdrawableNative[msg.sender] -= amount;
uint96 amount = s_withdrawableNative;
s_withdrawableNative -= amount;
s_totalNativeBalance -= amount;
(bool sent, ) = recipient.call{value: amount}("");
if (!sent) {
Expand Down
33 changes: 15 additions & 18 deletions contracts/src/v0.8/vrf/dev/VRFCoordinatorV2_5.sol
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
address sender;
bytes extraArgs;
}
mapping(bytes32 => address) /* keyHash */ /* oracle */ public s_provingKeys;
mapping(bytes32 => bool) /* keyHash */ /* exists */ public s_provingKeys;
bytes32[] public s_provingKeyHashes;
mapping(uint256 => bytes32) /* requestID */ /* commitment */ public s_requestCommitments;
event ProvingKeyRegistered(bytes32 keyHash, address indexed oracle);
event ProvingKeyDeregistered(bytes32 keyHash, address indexed oracle);
event ProvingKeyRegistered(bytes32 keyHash);
event ProvingKeyDeregistered(bytes32 keyHash);
kidambisrinivas marked this conversation as resolved.
Show resolved Hide resolved
event RandomWordsRequested(
bytes32 indexed keyHash,
uint256 requestId,
Expand Down Expand Up @@ -94,28 +94,26 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
}

/**
* @notice Registers a proving key to an oracle.
* @param oracle address of the oracle
* @notice Registers a proving key to.
* @param publicProvingKey key that oracle can use to submit vrf fulfillments
*/
function registerProvingKey(address oracle, uint256[2] calldata publicProvingKey) external onlyOwner {
function registerProvingKey(uint256[2] calldata publicProvingKey) external onlyOwner {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jinhoonbang @kidambisrinivas FYI this is a breaking change for Gauntlet commands (both for register and deregister proving key). When I use this branch, unit tests on the Gauntlet repo for proving key registration and deregistration are failing. I will fix those commands in the same PR where I would be introducing new withdrawal functions. Then, we would merge it on the Gauntlet develop branch only when this PR is merged (and after I double-check everything once again).

bytes32 kh = hashOfKey(publicProvingKey);
if (s_provingKeys[kh] != address(0)) {
if (s_provingKeys[kh]) {
revert ProvingKeyAlreadyRegistered(kh);
}
s_provingKeys[kh] = oracle;
s_provingKeys[kh] = true;
s_provingKeyHashes.push(kh);
emit ProvingKeyRegistered(kh, oracle);
emit ProvingKeyRegistered(kh);
}

/**
* @notice Deregisters a proving key to an oracle.
* @notice Deregisters a proving key.
* @param publicProvingKey key that oracle can use to submit vrf fulfillments
*/
function deregisterProvingKey(uint256[2] calldata publicProvingKey) external onlyOwner {
bytes32 kh = hashOfKey(publicProvingKey);
address oracle = s_provingKeys[kh];
if (oracle == address(0)) {
if (!s_provingKeys[kh]) {
revert NoSuchProvingKey(kh);
}
delete s_provingKeys[kh];
Expand All @@ -127,7 +125,7 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
s_provingKeyHashes.pop();
}
}
emit ProvingKeyDeregistered(kh, oracle);
emit ProvingKeyDeregistered(kh);
}

/**
Expand Down Expand Up @@ -355,8 +353,7 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
) internal view returns (Output memory) {
bytes32 keyHash = hashOfKey(proof.pk);
// Only registered proving keys are permitted.
address oracle = s_provingKeys[keyHash];
if (oracle == address(0)) {
if (!s_provingKeys[keyHash]) {
revert NoSuchProvingKey(keyHash);
}
uint256 requestId = uint256(keccak256(abi.encode(keyHash, proof.seed)));
Expand Down Expand Up @@ -423,7 +420,7 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
bool nativePayment = uint8(rc.extraArgs[rc.extraArgs.length - 1]) == 1;
// We want to charge users exactly for how much gas they use in their callback.
// The gasAfterPaymentCalculation is meant to cover these additional operations where we
// decrement the subscription balance and increment the oracles withdrawable balance.
// decrement the subscription balance and increment the withdrawable balance.
uint96 payment = _calculatePaymentAmount(
startGas,
s_config.gasAfterPaymentCalculation,
Expand All @@ -435,13 +432,13 @@ contract VRFCoordinatorV2_5 is VRF, SubscriptionAPI, IVRFCoordinatorV2Plus {
revert InsufficientBalance();
}
s_subscriptions[rc.subId].nativeBalance -= payment;
s_withdrawableNative[s_provingKeys[output.keyHash]] += payment;
s_withdrawableNative += payment;
} else {
if (s_subscriptions[rc.subId].balance < payment) {
revert InsufficientBalance();
}
s_subscriptions[rc.subId].balance -= payment;
s_withdrawableTokens[s_provingKeys[output.keyHash]] += payment;
s_withdrawableTokens += payment;
}

// Include payment in the event for tracking costs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ contract ExposedVRFCoordinatorV2_5 is VRFCoordinatorV2_5 {
s_totalNativeBalance = newBalance;
}

function setWithdrawableTokensTestingOnlyXXX(address oracle, uint96 newBalance) external {
s_withdrawableTokens[oracle] = newBalance;
function setWithdrawableTokensTestingOnlyXXX(uint96 newBalance) external {
s_withdrawableTokens = newBalance;
}

function getWithdrawableTokensTestingOnlyXXX(address oracle) external view returns (uint96) {
return s_withdrawableTokens[oracle];
function getWithdrawableTokensTestingOnlyXXX() external view returns (uint96) {
return s_withdrawableTokens;
}

function setWithdrawableNativeTestingOnlyXXX(address oracle, uint96 newBalance) external {
s_withdrawableNative[oracle] = newBalance;
function setWithdrawableNativeTestingOnlyXXX(uint96 newBalance) external {
s_withdrawableNative = newBalance;
}

function getWithdrawableNativeTestingOnlyXXX(address oracle) external view returns (uint96) {
return s_withdrawableNative[oracle];
function getWithdrawableNativeTestingOnlyXXX() external view returns (uint96) {
return s_withdrawableNative;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ contract VRFCoordinatorV2PlusUpgradedVersion is
bytes extraArgs;
}

mapping(bytes32 => address) /* keyHash */ /* oracle */ internal s_provingKeys;
mapping(bytes32 => bool) /* keyHash */ /* exists */ internal s_provingKeys;
bytes32[] public s_provingKeyHashes;
mapping(uint256 => bytes32) /* requestID */ /* commitment */ public s_requestCommitments;

event ProvingKeyRegistered(bytes32 keyHash, address indexed oracle);
event ProvingKeyRegistered(bytes32 keyHash);
event RandomWordsRequested(
bytes32 indexed keyHash,
uint256 requestId,
Expand Down Expand Up @@ -108,17 +108,16 @@ contract VRFCoordinatorV2PlusUpgradedVersion is

/**
* @notice Registers a proving key to an oracle.
* @param oracle address of the oracle
* @param publicProvingKey key that oracle can use to submit vrf fulfillments
*/
function registerProvingKey(address oracle, uint256[2] calldata publicProvingKey) external onlyOwner {
function registerProvingKey(uint256[2] calldata publicProvingKey) external onlyOwner {
bytes32 kh = hashOfKey(publicProvingKey);
if (s_provingKeys[kh] != address(0)) {
if (s_provingKeys[kh]) {
revert ProvingKeyAlreadyRegistered(kh);
}
s_provingKeys[kh] = oracle;
s_provingKeys[kh] = true;
s_provingKeyHashes.push(kh);
emit ProvingKeyRegistered(kh, oracle);
emit ProvingKeyRegistered(kh);
}

/**
Expand Down Expand Up @@ -346,8 +345,7 @@ contract VRFCoordinatorV2PlusUpgradedVersion is
) internal view returns (Output memory) {
bytes32 keyHash = hashOfKey(proof.pk);
// Only registered proving keys are permitted.
address oracle = s_provingKeys[keyHash];
if (oracle == address(0)) {
if (!s_provingKeys[keyHash]) {
revert NoSuchProvingKey(keyHash);
}
uint256 requestId = uint256(keccak256(abi.encode(keyHash, proof.seed)));
Expand Down Expand Up @@ -426,13 +424,13 @@ contract VRFCoordinatorV2PlusUpgradedVersion is
revert InsufficientBalance();
}
s_subscriptions[rc.subId].nativeBalance -= payment;
s_withdrawableNative[s_provingKeys[output.keyHash]] += payment;
s_withdrawableNative += payment;
} else {
if (s_subscriptions[rc.subId].balance < payment) {
revert InsufficientBalance();
}
s_subscriptions[rc.subId].balance -= payment;
s_withdrawableTokens[s_provingKeys[output.keyHash]] += payment;
s_withdrawableTokens += payment;
}

// Include payment in the event for tracking costs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ contract VRFCoordinatorV2Plus_Migration is BaseTest {

function registerProvingKey() public {
uint256[2] memory uncompressedKeyParts = this.getProvingKeyParts(UNCOMPRESSED_PUBLIC_KEY);
v1Coordinator.registerProvingKey(OWNER, uncompressedKeyParts);
v1Coordinator_noLink.registerProvingKey(OWNER, uncompressedKeyParts);
v1Coordinator.registerProvingKey(uncompressedKeyParts);
v1Coordinator_noLink.registerProvingKey(uncompressedKeyParts);
}

// note: Call this function via this.getProvingKeyParts to be able to pass memory as calldata and
Expand Down
4 changes: 2 additions & 2 deletions contracts/test/v0.8/foundry/vrf/VRFV2Plus.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ contract VRFV2Plus is BaseTest {
// Should revert when already registered.
uint256[2] memory uncompressedKeyParts = this.getProvingKeyParts(vrfUncompressedPublicKey);
vm.expectRevert(abi.encodeWithSelector(VRFCoordinatorV2_5.ProvingKeyAlreadyRegistered.selector, vrfKeyHash));
s_testCoordinator.registerProvingKey(LINK_WHALE, uncompressedKeyParts);
s_testCoordinator.registerProvingKey(uncompressedKeyParts);
}

function registerProvingKey() public {
uint256[2] memory uncompressedKeyParts = this.getProvingKeyParts(vrfUncompressedPublicKey);
s_testCoordinator.registerProvingKey(LINK_WHALE, uncompressedKeyParts);
s_testCoordinator.registerProvingKey(uncompressedKeyParts);
}

// note: Call this function via this.getProvingKeyParts to be able to pass memory as calldata and
Expand Down
74 changes: 44 additions & 30 deletions contracts/test/v0.8/foundry/vrf/VRFV2PlusSubscriptionAPI.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -315,25 +315,25 @@ contract VRFV2PlusSubscriptionAPITest is BaseTest {
assertEq(address(s_subscriptionAPI).balance, s_subscriptionAPI.s_totalNativeBalance());
}

function testOracleWithdrawNoLink() public {
function testWithdrawNoLink() public {
// CASE: no link token set
vm.expectRevert(SubscriptionAPI.LinkNotSet.selector);
s_subscriptionAPI.oracleWithdraw(OWNER, 1 ether);
s_subscriptionAPI.withdraw(OWNER);
}

function testOracleWithdrawInsufficientBalance() public {
function testWithdrawInsufficientBalance() public {
// CASE: link token set, trying to withdraw
// more than balance
MockLinkToken linkToken = new MockLinkToken();
s_subscriptionAPI.setLINKAndLINKNativeFeed(address(linkToken), address(0));
assertEq(address(s_subscriptionAPI.LINK()), address(linkToken));

// call oracleWithdraw
// call withdraw
vm.expectRevert(SubscriptionAPI.InsufficientBalance.selector);
s_subscriptionAPI.oracleWithdraw(OWNER, 1 ether);
s_subscriptionAPI.withdraw(OWNER);
}

function testOracleWithdrawSufficientBalanceLinkSet() public {
function testWithdrawSufficientBalanceLinkSet() public {
// CASE: link token set, trying to withdraw
// less than balance
MockLinkToken linkToken = new MockLinkToken();
Expand All @@ -344,58 +344,72 @@ contract VRFV2PlusSubscriptionAPITest is BaseTest {
bool success = linkToken.transfer(address(s_subscriptionAPI), 10 ether);
assertTrue(success, "failed link transfer");

// set the withdrawable tokens of the oracle to be 1 ether
address oracle = makeAddr("oracle");
s_subscriptionAPI.setWithdrawableTokensTestingOnlyXXX(oracle, 1 ether);
assertEq(s_subscriptionAPI.getWithdrawableTokensTestingOnlyXXX(oracle), 1 ether);
// set the withdrawable tokens of the contract to be 1 ether
s_subscriptionAPI.setWithdrawableTokensTestingOnlyXXX(1 ether);
assertEq(s_subscriptionAPI.getWithdrawableTokensTestingOnlyXXX(), 1 ether);

// set the total balance to be the same as the link balance for consistency
// (this is not necessary for the test, but just to be sane)
s_subscriptionAPI.setTotalBalanceTestingOnlyXXX(10 ether);

// call oracleWithdraw from oracle address
changePrank(oracle);
s_subscriptionAPI.oracleWithdraw(oracle, 1 ether);
// assert link balance of oracle
assertEq(linkToken.balanceOf(oracle), 1 ether, "oracle link balance incorrect");
// call Withdraw from owner address
uint256 ownerBalance = linkToken.balanceOf(OWNER);
changePrank(OWNER);
s_subscriptionAPI.withdraw(OWNER);
// assert link balance of owner
assertEq(linkToken.balanceOf(OWNER) - ownerBalance, 1 ether, "owner link balance incorrect");
// assert state of subscription api
assertEq(s_subscriptionAPI.getWithdrawableTokensTestingOnlyXXX(oracle), 0, "oracle withdrawable tokens incorrect");
assertEq(s_subscriptionAPI.getWithdrawableTokensTestingOnlyXXX(), 0, "owner withdrawable tokens incorrect");
// assert that total balance is changed by the withdrawn amount
assertEq(s_subscriptionAPI.s_totalBalance(), 9 ether, "total balance incorrect");
}

function testOracleWithdrawNativeInsufficientBalance() public {
function testWithdrawNativeInsufficientBalance() public {
// CASE: trying to withdraw more than balance
// should revert with InsufficientBalance

// call oracleWithdrawNative
// call WithdrawNative
changePrank(OWNER);
vm.expectRevert(SubscriptionAPI.InsufficientBalance.selector);
s_subscriptionAPI.oracleWithdrawNative(payable(OWNER), 1 ether);
s_subscriptionAPI.withdrawNative(payable(OWNER));
}

function testWithdrawLinkInvalidOwner() public {
address invalidAddress = makeAddr("invalidAddress");
changePrank(invalidAddress);
vm.expectRevert("Only callable by owner");
s_subscriptionAPI.withdraw(payable(OWNER));
}

function testOracleWithdrawNativeSufficientBalance() public {
function testWithdrawNativeInvalidOwner() public {
address invalidAddress = makeAddr("invalidAddress");
changePrank(invalidAddress);
vm.expectRevert("Only callable by owner");
s_subscriptionAPI.withdrawNative(payable(OWNER));
}

function testWithdrawNativeSufficientBalance() public {
// CASE: trying to withdraw less than balance
// should withdraw successfully

// transfer 10 ether to the contract to withdraw
vm.deal(address(s_subscriptionAPI), 10 ether);

// set the withdrawable eth of the oracle to be 1 ether
address oracle = makeAddr("oracle");
s_subscriptionAPI.setWithdrawableNativeTestingOnlyXXX(oracle, 1 ether);
assertEq(s_subscriptionAPI.getWithdrawableNativeTestingOnlyXXX(oracle), 1 ether);
// set the withdrawable eth of the contract to be 1 ether
s_subscriptionAPI.setWithdrawableNativeTestingOnlyXXX(1 ether);
assertEq(s_subscriptionAPI.getWithdrawableNativeTestingOnlyXXX(), 1 ether);

// set the total balance to be the same as the eth balance for consistency
// (this is not necessary for the test, but just to be sane)
s_subscriptionAPI.setTotalNativeBalanceTestingOnlyXXX(10 ether);

// call oracleWithdrawNative from oracle address
changePrank(oracle);
s_subscriptionAPI.oracleWithdrawNative(payable(oracle), 1 ether);
// assert native balance of oracle
assertEq(address(oracle).balance, 1 ether, "oracle native balance incorrect");
// call WithdrawNative from owner address
changePrank(OWNER);
s_subscriptionAPI.withdrawNative(payable(OWNER));
// assert native balance
assertEq(address(OWNER).balance, 1 ether, "owner native balance incorrect");
// assert state of subscription api
assertEq(s_subscriptionAPI.getWithdrawableNativeTestingOnlyXXX(oracle), 0, "oracle withdrawable native incorrect");
assertEq(s_subscriptionAPI.getWithdrawableNativeTestingOnlyXXX(), 0, "owner withdrawable native incorrect");
// assert that total balance is changed by the withdrawn amount
assertEq(s_subscriptionAPI.s_totalNativeBalance(), 9 ether, "total native balance incorrect");
}
Expand Down
Loading
Loading