Skip to content

Commit

Permalink
Integrate message validation hook to multi onramp (#1060)
Browse files Browse the repository at this point in the history
## Motivation

Follow up of #916. 

## Solution

- Replaces the `EVM2EVMMultiOnRamp` `AggregateRateLimiter` inheritance
by a call to an `IMessageValidator` hook contract to validate messages.
- Implements the `MultiAggregateRateLimiter.onOutgoingMessage()`
function.
  • Loading branch information
RayXpub authored Jun 24, 2024
1 parent 1652778 commit 8d34b38
Show file tree
Hide file tree
Showing 16 changed files with 594 additions and 758 deletions.
211 changes: 109 additions & 102 deletions contracts/gas-snapshots/ccip.gas-snapshot

Large diffs are not rendered by default.

37 changes: 25 additions & 12 deletions contracts/src/v0.8/ccip/MultiAggregateRateLimiter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,41 @@ contract MultiAggregateRateLimiter is IMessageInterceptor, AuthorizedCallers {

/// @inheritdoc IMessageInterceptor
function onInboundMessage(Client.Any2EVMMessage memory message) external onlyAuthorizedCallers {
uint64 remoteChainSelector = message.sourceChainSelector;
RateLimiter.TokenBucket storage tokenBucket = _getTokenBucket(remoteChainSelector, false);
_applyRateLimit(message.sourceChainSelector, message.destTokenAmounts, false);
}

/// @inheritdoc IMessageInterceptor
function onOutboundMessage(
uint64 destChainSelector,
Client.EVM2AnyMessage calldata message
) external onlyAuthorizedCallers {
_applyRateLimit(destChainSelector, message.tokenAmounts, true);
}

/// @notice Applies the rate limit to the token bucket if enabled
/// @param remoteChainSelector The remote chain selector
/// @param tokenAmounts The tokens and amounts to rate limit
/// @param isOutgoingLane if set to true, fetches the bucket for the outgoing message lane (OnRamp).
function _applyRateLimit(
uint64 remoteChainSelector,
Client.EVMTokenAmount[] memory tokenAmounts,
bool isOutgoingLane
) private {
RateLimiter.TokenBucket storage tokenBucket = _getTokenBucket(remoteChainSelector, isOutgoingLane);

// Skip rate limiting if it is disabled
if (tokenBucket.isEnabled) {
uint256 value;
Client.EVMTokenAmount[] memory destTokenAmounts = message.destTokenAmounts;
for (uint256 i = 0; i < destTokenAmounts.length; ++i) {
if (s_rateLimitedTokensLocalToRemote[remoteChainSelector].contains(destTokenAmounts[i].token)) {
value += _getTokenValue(destTokenAmounts[i]);
for (uint256 i = 0; i < tokenAmounts.length; ++i) {
if (s_rateLimitedTokensLocalToRemote[remoteChainSelector].contains(tokenAmounts[i].token)) {
value += _getTokenValue(tokenAmounts[i]);
}
}

// Rate limit on aggregated token value
if (value > 0) tokenBucket._consume(value, address(0));
}
}

/// @inheritdoc IMessageInterceptor
function onOutboundMessage(Client.EVM2AnyMessage memory message, uint64 destChainSelector) external {
// TODO: to be implemented (assuming the same rate limiter states are shared for inbound and outbound messages)
}

/// @param remoteChainSelector chain selector to retrieve token bucket for
/// @param isOutboundLane if set to true, fetches the bucket for the outbound message lane (OnRamp).
/// Otherwise fetches for the inbound message lane (OffRamp).
Expand Down
4 changes: 2 additions & 2 deletions contracts/src/v0.8/ccip/interfaces/IMessageInterceptor.sol
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ interface IMessageInterceptor {
function onInboundMessage(Client.Any2EVMMessage memory message) external;

/// @notice Intercepts & validates the given OnRamp message. Reverts on validation failure
/// @param message to validate
/// @param destChainSelector remote destination chain selector where the message is being sent to
function onOutboundMessage(Client.EVM2AnyMessage memory message, uint64 destChainSelector) external;
/// @param message to validate
function onOutboundMessage(uint64 destChainSelector, Client.EVM2AnyMessage memory message) external;
}
48 changes: 18 additions & 30 deletions contracts/src/v0.8/ccip/onRamp/EVM2EVMMultiOnRamp.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@ pragma solidity 0.8.24;

import {ITypeAndVersion} from "../../shared/interfaces/ITypeAndVersion.sol";
import {IEVM2AnyOnRampClient} from "../interfaces/IEVM2AnyOnRampClient.sol";
import {IMessageInterceptor} from "../interfaces/IMessageInterceptor.sol";
import {INonceManager} from "../interfaces/INonceManager.sol";
import {IPoolV1} from "../interfaces/IPool.sol";
import {IPriceRegistry} from "../interfaces/IPriceRegistry.sol";
import {IRMN} from "../interfaces/IRMN.sol";
import {ITokenAdminRegistry} from "../interfaces/ITokenAdminRegistry.sol";

import {AggregateRateLimiter} from "../AggregateRateLimiter.sol";
import {OwnerIsCreator} from "../../shared/access/OwnerIsCreator.sol";
import {Client} from "../libraries/Client.sol";
import {Internal} from "../libraries/Internal.sol";
import {Pool} from "../libraries/Pool.sol";
import {RateLimiter} from "../libraries/RateLimiter.sol";
import {USDPriceWith18Decimals} from "../libraries/USDPriceWith18Decimals.sol";

import {IERC20} from "../../vendor/openzeppelin-solidity/v4.8.3/contracts/token/ERC20/IERC20.sol";
Expand All @@ -22,10 +22,11 @@ import {SafeERC20} from "../../vendor/openzeppelin-solidity/v4.8.3/contracts/tok
/// @notice The EVM2EVMMultiOnRamp is a contract that handles lane-specific fee logic
/// @dev The EVM2EVMMultiOnRamp, MultiCommitStore and EVM2EVMMultiOffRamp form an xchain upgradeable unit. Any change to one of them
/// results an onchain upgrade of all 3.
contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, ITypeAndVersion {
contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, ITypeAndVersion, OwnerIsCreator {
using SafeERC20 for IERC20;
using USDPriceWith18Decimals for uint224;

error CannotSendZeroTokens();
error InvalidExtraArgsTag();
error OnlyCallableByOwnerOrAdmin();
error MessageTooLarge(uint256 maxSize, uint256 actualSize);
Expand All @@ -38,13 +39,13 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType
error InvalidConfig();
error CursedByRMN(uint64 sourceChainSelector);
error NotAFeeToken(address token);
error CannotSendZeroTokens();
error SourceTokenDataTooLarge(address token);
error GetSupportedTokensFunctionalityRemovedCheckAdminRegistry();
error InvalidDestChainConfig(uint64 destChainSelector);
error DestinationChainNotEnabled(uint64 destChainSelector);
error InvalidDestBytesOverhead(address token, uint32 destBytesOverhead);

event AdminSet(address newAdmin);
event ConfigSet(StaticConfig staticConfig, DynamicConfig dynamicConfig);
event FeePaid(address indexed feeToken, uint256 feeValueJuels);
event FeeTokenWithdrawn(address indexed feeAggregator, address indexed feeToken, uint256 amount);
Expand Down Expand Up @@ -75,6 +76,7 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType
struct DynamicConfig {
address router; // Router address
address priceRegistry; // Price registry address
address messageValidator; // Optional message validator to validate outbound messages (zero address = no validator)
address feeAggregator; // Fee aggregator address
}

Expand All @@ -94,7 +96,6 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType
uint32 destGasOverhead; // │ Gas charged to execute the token transfer on the destination chain
// │ Extra data availability bytes that are returned from the source pool and sent
uint32 destBytesOverhead; // │ to the destination pool. Must be >= Pool.CCIP_LOCK_OR_BURN_V1_RET_BYTES
bool aggregateRateLimitEnabled; // │ Whether this transfer token is to be included in Aggregate Rate Limiting
bool isEnabled; // ─────────────────╯ Whether this token has custom transfer fees
}

Expand Down Expand Up @@ -204,10 +205,9 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType
StaticConfig memory staticConfig,
DynamicConfig memory dynamicConfig,
DestChainConfigArgs[] memory destChainConfigArgs,
RateLimiter.Config memory rateLimiterConfig,
PremiumMultiplierWeiPerEthArgs[] memory premiumMultiplierWeiPerEthArgs,
TokenTransferFeeConfigArgs[] memory tokenTransferFeeConfigArgs
) AggregateRateLimiter(rateLimiterConfig) {
) {
if (
staticConfig.linkToken == address(0) || staticConfig.chainSelector == 0 || staticConfig.rmnProxy == address(0)
|| staticConfig.nonceManager == address(0) || staticConfig.tokenAdminRegistry == address(0)
Expand Down Expand Up @@ -254,6 +254,9 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType
// There should be no state changes after external call to TokenPools.
for (uint256 i = 0; i < newMessage.tokenAmounts.length; ++i) {
Client.EVMTokenAmount memory tokenAndAmount = message.tokenAmounts[i];

if (tokenAndAmount.amount == 0) revert CannotSendZeroTokens();

IPoolV1 sourcePool = getPoolBySourceToken(destChainSelector, IERC20(tokenAndAmount.token));
// We don't have to check if it supports the pool version in a non-reverting way here because
// if we revert here, there is no effect on CCIP. Therefore we directly call the supportsInterface
Expand Down Expand Up @@ -333,15 +336,13 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType

// Only check token value if there are tokens
if (numberOfTokens > 0) {
uint256 value;
for (uint256 i = 0; i < numberOfTokens; ++i) {
if (message.tokenAmounts[i].amount == 0) revert CannotSendZeroTokens();
if (s_tokenTransferFeeConfig[destChainSelector][message.tokenAmounts[i].token].aggregateRateLimitEnabled) {
value += _getTokenValue(message.tokenAmounts[i], IPriceRegistry(s_dynamicConfig.priceRegistry));
address messageValidator = s_dynamicConfig.messageValidator;
if (messageValidator != address(0)) {
try IMessageInterceptor(messageValidator).onOutboundMessage(destChainSelector, message) {}
catch (bytes memory err) {
revert IMessageInterceptor.MessageValidationError(err);
}
}
// Rate limit on aggregated token value
if (value > 0) _rateLimitValue(value);
}

uint256 msgFeeJuels;
Expand Down Expand Up @@ -664,8 +665,7 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType

/// @notice Updates the destination chain specific config.
/// @param destChainConfigArgs Array of source chain specific configs.
function applyDestChainConfigUpdates(DestChainConfigArgs[] memory destChainConfigArgs) external {
_onlyOwnerOrAdmin();
function applyDestChainConfigUpdates(DestChainConfigArgs[] memory destChainConfigArgs) external onlyOwner {
_applyDestChainConfigUpdates(destChainConfigArgs);
}

Expand Down Expand Up @@ -727,8 +727,7 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType
/// @param premiumMultiplierWeiPerEthArgs Array of PremiumMultiplierWeiPerEthArgs structs.
function applyPremiumMultiplierWeiPerEthUpdates(
PremiumMultiplierWeiPerEthArgs[] memory premiumMultiplierWeiPerEthArgs
) external {
_onlyOwnerOrAdmin();
) external onlyOwner {
_applyPremiumMultiplierWeiPerEthUpdates(premiumMultiplierWeiPerEthArgs);
}

Expand Down Expand Up @@ -761,8 +760,7 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType
function applyTokenTransferFeeConfigUpdates(
TokenTransferFeeConfigArgs[] memory tokenTransferFeeConfigArgs,
TokenTransferFeeConfigRemoveArgs[] memory tokensToUseDefaultFeeConfigs
) external {
_onlyOwnerOrAdmin();
) external onlyOwner {
_applyTokenTransferFeeConfigUpdates(tokenTransferFeeConfigArgs, tokensToUseDefaultFeeConfigs);
}

Expand Down Expand Up @@ -816,14 +814,4 @@ contract EVM2EVMMultiOnRamp is IEVM2AnyOnRampClient, AggregateRateLimiter, IType
}
}
}

// ================================================================
// │ Access │
// ================================================================

/// @dev Require that the sender is the owner or the fee admin
/// Not a modifier to save on contract size
function _onlyOwnerOrAdmin() internal view {
if (msg.sender != owner() && msg.sender != s_admin) revert OnlyCallableByOwnerOrAdmin();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ contract EVM2EVMMultiOnRampHelper is EVM2EVMMultiOnRamp, IgnoreContractSize {
StaticConfig memory staticConfig,
DynamicConfig memory dynamicConfig,
DestChainConfigArgs[] memory destChainConfigs,
RateLimiter.Config memory rateLimiterConfig,
PremiumMultiplierWeiPerEthArgs[] memory premiumMultiplierWeiPerEthArgs,
TokenTransferFeeConfigArgs[] memory tokenTransferFeeConfigArgs
)
EVM2EVMMultiOnRamp(
staticConfig,
dynamicConfig,
destChainConfigs,
rateLimiterConfig,
premiumMultiplierWeiPerEthArgs,
tokenTransferFeeConfigArgs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import {IMessageInterceptor} from "../../interfaces/IMessageInterceptor.sol";
import {Client} from "../../libraries/Client.sol";

contract MessageInterceptorHelper is IMessageInterceptor {
error IncomingMessageValidationError(bytes errorReason);

mapping(bytes32 messageId => bool isInvalid) internal s_invalidMessageIds;

constructor() {}
Expand All @@ -18,13 +16,15 @@ contract MessageInterceptorHelper is IMessageInterceptor {
/// @inheritdoc IMessageInterceptor
function onInboundMessage(Client.Any2EVMMessage memory message) external view {
if (s_invalidMessageIds[message.messageId]) {
revert IncomingMessageValidationError(bytes("Invalid message"));
revert MessageValidationError(bytes("Invalid message"));
}
}

/// @inheritdoc IMessageInterceptor
function onOutboundMessage(Client.EVM2AnyMessage memory, uint64) external pure {
// TODO: to be implemented
function onOutboundMessage(uint64, Client.EVM2AnyMessage calldata message) external view {
if (s_invalidMessageIds[keccak256(abi.encode(message))]) {
revert MessageValidationError(bytes("Invalid message"));
}
return;
}
}
37 changes: 14 additions & 23 deletions contracts/src/v0.8/ccip/test/offRamp/EVM2EVMMultiOffRamp.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
pragma solidity 0.8.24;

import {ICommitStore} from "../../interfaces/ICommitStore.sol";

import {IMessageInterceptor} from "../../interfaces/IMessageInterceptor.sol";
import {IPriceRegistry} from "../../interfaces/IPriceRegistry.sol";
import {IRMN} from "../../interfaces/IRMN.sol";
import {ITokenAdminRegistry} from "../../interfaces/ITokenAdminRegistry.sol";
Expand All @@ -11,7 +11,6 @@ import {CallWithExactGas} from "../../../shared/call/CallWithExactGas.sol";
import {PriceRegistry} from "../../PriceRegistry.sol";
import {RMN} from "../../RMN.sol";
import {Router} from "../../Router.sol";
import {IMessageInterceptor} from "../../interfaces/IMessageInterceptor.sol";
import {Client} from "../../libraries/Client.sol";
import {Internal} from "../../libraries/Internal.sol";
import {MerkleMultiProof} from "../../libraries/MerkleMultiProof.sol";
Expand Down Expand Up @@ -269,7 +268,7 @@ contract EVM2EVMMultiOffRamp_setDynamicConfig is EVM2EVMMultiOffRampSetup {
function test_SetDynamicConfigWithValidator_Success() public {
EVM2EVMMultiOffRamp.DynamicConfig memory dynamicConfig =
_generateDynamicMultiOffRampConfig(USER_3, address(s_priceRegistry));
dynamicConfig.messageValidator = address(s_messageValidator);
dynamicConfig.messageValidator = address(s_inboundMessageValidator);

vm.expectEmit();
emit EVM2EVMMultiOffRamp.DynamicConfigSet(dynamicConfig);
Expand Down Expand Up @@ -1204,7 +1203,7 @@ contract EVM2EVMMultiOffRamp_executeSingleMessage is EVM2EVMMultiOffRampSetup {
function test_executeSingleMessage_WithValidation_Success() public {
vm.stopPrank();
vm.startPrank(OWNER);
_enableMessageValidator();
_enableInboundMessageValidator();
vm.startPrank(address(s_offRamp));
Internal.EVM2EVMMessage memory message =
_generateAny2EVMMessageNoTokens(SOURCE_CHAIN_SELECTOR_1, ON_RAMP_ADDRESS_1, 1);
Expand Down Expand Up @@ -1271,17 +1270,15 @@ contract EVM2EVMMultiOffRamp_executeSingleMessage is EVM2EVMMultiOffRampSetup {
function test_executeSingleMessage_WithFailingValidation_Revert() public {
vm.stopPrank();
vm.startPrank(OWNER);
_enableMessageValidator();
_enableInboundMessageValidator();
vm.startPrank(address(s_offRamp));
Internal.EVM2EVMMessage memory message =
_generateAny2EVMMessageNoTokens(SOURCE_CHAIN_SELECTOR_1, ON_RAMP_ADDRESS_1, 1);
s_messageValidator.setMessageIdValidationState(message.messageId, true);
s_inboundMessageValidator.setMessageIdValidationState(message.messageId, true);
vm.expectRevert(
abi.encodeWithSelector(
IMessageInterceptor.MessageValidationError.selector,
abi.encodeWithSelector(
MessageInterceptorHelper.IncomingMessageValidationError.selector, bytes("Invalid message")
)
abi.encodeWithSelector(IMessageInterceptor.MessageValidationError.selector, bytes("Invalid message"))
)
);
s_offRamp.executeSingleMessage(message, new bytes[](message.tokenAmounts.length));
Expand All @@ -1290,7 +1287,7 @@ contract EVM2EVMMultiOffRamp_executeSingleMessage is EVM2EVMMultiOffRampSetup {
function test_executeSingleMessage_WithFailingValidationNoRouterCall_Revert() public {
vm.stopPrank();
vm.startPrank(OWNER);
_enableMessageValidator();
_enableInboundMessageValidator();
vm.startPrank(address(s_offRamp));

Internal.EVM2EVMMessage memory message =
Expand All @@ -1301,13 +1298,11 @@ contract EVM2EVMMultiOffRamp_executeSingleMessage is EVM2EVMMultiOffRampSetup {
message.receiver = address(newReceiver);
message.messageId = Internal._hash(message, s_offRamp.metadataHash(SOURCE_CHAIN_SELECTOR_1, ON_RAMP_ADDRESS_1));

s_messageValidator.setMessageIdValidationState(message.messageId, true);
s_inboundMessageValidator.setMessageIdValidationState(message.messageId, true);
vm.expectRevert(
abi.encodeWithSelector(
IMessageInterceptor.MessageValidationError.selector,
abi.encodeWithSelector(
MessageInterceptorHelper.IncomingMessageValidationError.selector, bytes("Invalid message")
)
abi.encodeWithSelector(IMessageInterceptor.MessageValidationError.selector, bytes("Invalid message"))
)
);
s_offRamp.executeSingleMessage(message, new bytes[](message.tokenAmounts.length));
Expand Down Expand Up @@ -2047,7 +2042,7 @@ contract EVM2EVMMultiOffRamp_execute is EVM2EVMMultiOffRampSetup {
}

function test_MultipleReportsWithPartialValidationFailures_Success() public {
_enableMessageValidator();
_enableInboundMessageValidator();

Internal.EVM2EVMMessage[] memory messages1 = new Internal.EVM2EVMMessage[](2);
Internal.EVM2EVMMessage[] memory messages2 = new Internal.EVM2EVMMessage[](1);
Expand All @@ -2060,8 +2055,8 @@ contract EVM2EVMMultiOffRamp_execute is EVM2EVMMultiOffRampSetup {
reports[0] = _generateReportFromMessages(SOURCE_CHAIN_SELECTOR_1, messages1);
reports[1] = _generateReportFromMessages(SOURCE_CHAIN_SELECTOR_1, messages2);

s_messageValidator.setMessageIdValidationState(messages1[0].messageId, true);
s_messageValidator.setMessageIdValidationState(messages2[0].messageId, true);
s_inboundMessageValidator.setMessageIdValidationState(messages1[0].messageId, true);
s_inboundMessageValidator.setMessageIdValidationState(messages2[0].messageId, true);

vm.expectEmit();
emit EVM2EVMMultiOffRamp.ExecutionStateChanged(
Expand All @@ -2071,9 +2066,7 @@ contract EVM2EVMMultiOffRamp_execute is EVM2EVMMultiOffRampSetup {
Internal.MessageExecutionState.FAILURE,
abi.encodeWithSelector(
IMessageInterceptor.MessageValidationError.selector,
abi.encodeWithSelector(
MessageInterceptorHelper.IncomingMessageValidationError.selector, bytes("Invalid message")
)
abi.encodeWithSelector(IMessageInterceptor.MessageValidationError.selector, bytes("Invalid message"))
)
);

Expand All @@ -2094,9 +2087,7 @@ contract EVM2EVMMultiOffRamp_execute is EVM2EVMMultiOffRampSetup {
Internal.MessageExecutionState.FAILURE,
abi.encodeWithSelector(
IMessageInterceptor.MessageValidationError.selector,
abi.encodeWithSelector(
MessageInterceptorHelper.IncomingMessageValidationError.selector, bytes("Invalid message")
)
abi.encodeWithSelector(IMessageInterceptor.MessageValidationError.selector, bytes("Invalid message"))
)
);

Expand Down
Loading

0 comments on commit 8d34b38

Please sign in to comment.