diff --git a/VERSION b/VERSION index 227cea21..4a36342f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.0.0 +3.0.0 diff --git a/contracts/Marionette.sol b/contracts/Marionette.sol index eef15ad5..f21fd368 100644 --- a/contracts/Marionette.sol +++ b/contracts/Marionette.sol @@ -32,12 +32,6 @@ contract Marionette is IMarionette, AccessControlEnumerableUpgradeable { using AddressUpgradeable for address; using AddressUpgradeable for address payable; - struct FunctionCall { - address receiver; - uint value; - bytes data; - } - bytes32 public constant IMA_ROLE = keccak256("IMA_ROLE"); bytes32 public constant PUPPETEER_ROLE = keccak256("PUPPETEER_ROLE"); string public constant ACCESS_VIOLATION = "Access violation"; @@ -83,16 +77,22 @@ contract Marionette is IMarionette, AccessControlEnumerableUpgradeable { address sender, bytes calldata data ) - external - override + external + override { require(hasRole(IMA_ROLE, msg.sender), "Sender is not IMA"); require(hasRole(PUPPETEER_ROLE, sender), ACCESS_VIOLATION); - FunctionCall memory functionCall = _parseFunctionCall(data); + FunctionCall[] memory functionCalls = _parseFunctionCalls(data); - bytes memory output = _doCall(payable(functionCall.receiver), functionCall.value, functionCall.data); - emit FunctionCallResult(output); + for (uint i = 0; i < functionCalls.length; ++i) { + bytes memory output = _doCall( + payable(functionCalls[i].receiver), + functionCalls[i].value, + functionCalls[i].data + ); + emit FunctionCallResult(output); + } } function execute( @@ -107,12 +107,34 @@ contract Marionette is IMarionette, AccessControlEnumerableUpgradeable { { require(hasRole(PUPPETEER_ROLE, msg.sender), ACCESS_VIOLATION); + if (msg.value > 0) { + emit EtherReceived(msg.sender, msg.value); + } + return _doCall(target, value, data); } + function executeMultiple(FunctionCall[] calldata functionCalls) external payable override returns (bytes[] memory) { + require(hasRole(PUPPETEER_ROLE, msg.sender), ACCESS_VIOLATION); + + if (msg.value > 0) { + emit EtherReceived(msg.sender, msg.value); + } + + bytes[] memory results = new bytes[](functionCalls.length); + for (uint i = 0; i < functionCalls.length; ++i) { + results[i] = _doCall(payable(functionCalls[i].receiver), functionCalls[i].value, functionCalls[i].data); + } + return results; + } + function sendSFuel(address payable target, uint value) external payable override { require(hasRole(PUPPETEER_ROLE, msg.sender), ACCESS_VIOLATION); + if (msg.value > 0) { + emit EtherReceived(msg.sender, msg.value); + } + _doCall(target, value, "0x"); } @@ -133,17 +155,25 @@ contract Marionette is IMarionette, AccessControlEnumerableUpgradeable { override returns (bytes memory) { - return abi.encode(receiver, value, data); + FunctionCall[] memory functionCalls = new FunctionCall[](1); + functionCalls[0] = FunctionCall({receiver: receiver, value: value, data: data}); + return encodeFunctionCalls(FunctionCall[](functionCalls)); + } + + function encodeFunctionCalls( + FunctionCall[] memory functionCalls + ) + public + pure + override + returns (bytes memory) + { + return abi.encode(functionCalls); } // private function _doCall(address payable target, uint value, bytes memory data) private returns (bytes memory) { - - if (msg.value > 0) { - emit EtherReceived(msg.sender, msg.value); - } - if (value > 0) { emit EtherSent(target, value); } @@ -161,7 +191,7 @@ contract Marionette is IMarionette, AccessControlEnumerableUpgradeable { } } - function _parseFunctionCall(bytes calldata data) private pure returns (FunctionCall memory functionCall) { - (functionCall.receiver, functionCall.value, functionCall.data) = abi.decode(data, (address, uint, bytes)); + function _parseFunctionCalls(bytes calldata data) private pure returns (FunctionCall[] memory functionCalls) { + return abi.decode(data, (FunctionCall[])); } } \ No newline at end of file diff --git a/contracts/interfaces/IMarionette.sol b/contracts/interfaces/IMarionette.sol index 6947024d..95733be8 100644 --- a/contracts/interfaces/IMarionette.sol +++ b/contracts/interfaces/IMarionette.sol @@ -25,15 +25,28 @@ import "@skalenetwork/ima-interfaces/IMessageReceiver.sol"; interface IMarionette is IMessageReceiver { + struct FunctionCall { + address receiver; + uint value; + bytes data; + } + receive() external payable; function initialize(address owner, address ima) external; function execute(address payable target, uint value, bytes calldata data) external payable returns (bytes memory); + function executeMultiple(FunctionCall[] calldata functionCalls) external payable returns (bytes[] memory); function setVersion(string calldata newVersion) external; function sendSFuel(address payable target, uint value) external payable; function encodeFunctionCall( address receiver, uint value, bytes calldata data + ) + external + pure + returns (bytes memory); + function encodeFunctionCalls( + FunctionCall[] calldata functionCalls ) external pure diff --git a/test/Marionette.ts b/test/Marionette.ts index c031a5fc..f8092fb7 100644 --- a/test/Marionette.ts +++ b/test/Marionette.ts @@ -86,7 +86,7 @@ describe("Marionette", () => { const uintValue = 5; const stringValue = "Hello from D2"; - it ("should allow owner to call contract", async () => { + it ("should allow owner to call a contract", async () => { const transaction = await marionette.execute( target.address, amount, @@ -104,7 +104,33 @@ describe("Marionette", () => { await target.sendSFuel(owner.address, amount); }); - it ("should not allow everyone to call contract", async () => { + it ("should allow owner to do multiple calls to a contract", async () => { + const call1 = { + receiver: target.address, + value: 1, + data: target.interface.encodeFunctionData( + "targetFunction", + [1, "call1"] + ) + }; + const call2 = { + receiver: target.address, + value: 2, + data: target.interface.encodeFunctionData( + "targetFunction", + [2, "call2"] + ) + }; + + const transaction = await marionette.executeMultiple([call1, call2], {value: 3}); + await transaction.should.emit(marionette, "EtherReceived").withArgs(owner.address, 3); + await transaction.should.emit(marionette, "EtherSent").withArgs(target.address, 1); + await transaction.should.emit(marionette, "EtherSent").withArgs(target.address, 2); + await transaction.should.emit(target, "ExecutionResult").withArgs(1, "call1"); + await transaction.should.emit(target, "ExecutionResult").withArgs(2, "call2"); + }); + + it ("should not allow everyone to call a contract", async () => { await marionette.connect(hacker).execute(target.address, 0, "0x") .should.be.eventually.rejectedWith("Access violation"); }); @@ -126,6 +152,7 @@ describe("Marionette", () => { (await marionette.version()).should.be.equal("nice"); }); + describe("Calls from IMA", () => { it ("should allow IMA to trigger function call", async () => {