Skip to content

Commit

Permalink
Merge pull request #547 from lidofinance/develop
Browse files Browse the repository at this point in the history
Master merge
  • Loading branch information
F4ever authored Oct 30, 2024
2 parents 9d33bd7 + a63f37d commit c12a0a2
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 28 deletions.
31 changes: 28 additions & 3 deletions src/modules/accounting/accounting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from collections import defaultdict
from time import sleep

from hexbytes import HexBytes
from web3.exceptions import ContractCustomError
from web3.types import Wei

from src import variables
Expand All @@ -19,6 +21,7 @@
FinalizationShareRate,
ValidatorsCount,
ValidatorsBalance,
AccountingProcessingState,
)
from src.metrics.prometheus.accounting import (
ACCOUNTING_IS_BUNKER,
Expand All @@ -27,9 +30,10 @@
ACCOUNTING_WITHDRAWAL_VAULT_BALANCE_WEI
)
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.submodules.types import ZERO_HASH
from src.providers.execution.contracts.accounting_oracle import AccountingOracleContract
from src.services.validator_state import LidoValidatorStateService
from src.modules.submodules.consensus import ConsensusModule
from src.modules.submodules.consensus import ConsensusModule, InitialEpochIsYetToArriveRevert
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.services.withdrawal import Withdrawal
from src.services.bunker import BunkerService
Expand Down Expand Up @@ -116,13 +120,13 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:

def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
# Consensus module: if contract got report data (second phase)
processing_state = self.report_contract.get_processing_state(blockstamp.block_hash)
processing_state = self._get_processing_state(blockstamp)
logger.debug({'msg': 'Check if main data was submitted.', 'value': processing_state.main_data_submitted})
return processing_state.main_data_submitted

def can_submit_extra_data(self, blockstamp: BlockStamp) -> bool:
"""Check if Oracle can submit extra data. Can only be submitted after second phase."""
processing_state = self.report_contract.get_processing_state(blockstamp.block_hash)
processing_state = self._get_processing_state(blockstamp)
return processing_state.main_data_submitted and not processing_state.extra_data_submitted

def is_contract_reportable(self, blockstamp: BlockStamp) -> bool:
Expand All @@ -140,6 +144,27 @@ def is_reporting_allowed(self, blockstamp: ReferenceBlockStamp) -> bool:
logger.warning({'msg': '!' * 50})
return ALLOW_REPORTING_IN_BUNKER_MODE

def _get_processing_state(self, blockstamp: BlockStamp) -> AccountingProcessingState:
try:
return self.report_contract.get_processing_state(blockstamp.block_hash)
except ContractCustomError as revert:
if revert.data != InitialEpochIsYetToArriveRevert:
raise revert

frame = self.get_initial_or_current_frame(blockstamp)

return AccountingProcessingState(
current_frame_ref_slot=frame.ref_slot,
processing_deadline_time=frame.report_processing_deadline_slot,
main_data_hash=HexBytes(ZERO_HASH),
main_data_submitted=False,
extra_data_hash=HexBytes(ZERO_HASH),
extra_data_format=0,
extra_data_submitted=False,
extra_data_items_count=0,
extra_data_items_submitted=0,
)

# ---------------------------------------- Build report ----------------------------------------
def _calculate_report(self, blockstamp: ReferenceBlockStamp):
consensus_version = self.report_contract.get_consensus_version(blockstamp.block_hash)
Expand Down
4 changes: 2 additions & 2 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:

def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
last_ref_slot = self.w3.csm.get_csm_last_processing_ref_slot(blockstamp)
ref_slot = self.get_current_frame(blockstamp).ref_slot
ref_slot = self.get_initial_or_current_frame(blockstamp).ref_slot
return last_ref_slot == ref_slot

def is_contract_reportable(self, blockstamp: BlockStamp) -> bool:
Expand Down Expand Up @@ -370,7 +370,7 @@ def current_frame_range(self, blockstamp: BlockStamp) -> tuple[EpochNumber, Epoc

# NOTE: before the initial slot the contract can't return current frame
if blockstamp.slot_number > initial_ref_slot:
r_ref_slot = self.get_current_frame(blockstamp).ref_slot
r_ref_slot = self.get_initial_or_current_frame(blockstamp).ref_slot

# We are between reports, next report slot didn't happen yet. Predicting the next ref slot for the report
# to calculate epochs range to collect the data.
Expand Down
27 changes: 24 additions & 3 deletions src/modules/ejector/ejector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import reduce

from more_itertools import ilen
from web3.exceptions import ContractCustomError
from web3.types import Wei

from src.constants import (
Expand All @@ -18,9 +19,10 @@
)
from src.metrics.prometheus.duration_meter import duration_meter
from src.modules.ejector.data_encode import encode_data
from src.modules.ejector.types import ReportData
from src.modules.submodules.consensus import ConsensusModule
from src.modules.ejector.types import ReportData, EjectorProcessingState
from src.modules.submodules.consensus import ConsensusModule, InitialEpochIsYetToArriveRevert
from src.modules.submodules.oracle_module import BaseModule, ModuleExecuteDelay
from src.modules.submodules.types import ZERO_HASH
from src.providers.consensus.types import Validator
from src.providers.execution.contracts.exit_bus_oracle import ExitBusOracleContract
from src.services.exit_order.iterator import ExitOrderIterator
Expand Down Expand Up @@ -315,8 +317,27 @@ def _get_total_active_validators(self, blockstamp: ReferenceBlockStamp) -> int:
return total_active_validators

def is_main_data_submitted(self, blockstamp: BlockStamp) -> bool:
processing_state = self.report_contract.get_processing_state(blockstamp.block_hash)
processing_state = self._get_processing_state(blockstamp)
return processing_state.data_submitted

def is_contract_reportable(self, blockstamp: BlockStamp) -> bool:
return not self.is_main_data_submitted(blockstamp)

def _get_processing_state(self, blockstamp: BlockStamp) -> EjectorProcessingState:
try:
return self.report_contract.get_processing_state(blockstamp.block_hash)
except ContractCustomError as revert:
if revert.data != InitialEpochIsYetToArriveRevert:
raise revert

frame = self.get_initial_or_current_frame(blockstamp)

return EjectorProcessingState(
current_frame_ref_slot=frame.ref_slot,
processing_deadline_time=frame.report_processing_deadline_slot,
data_hash=ZERO_HASH,
data_submitted=False,
data_format=0,
requests_count=0,
requests_submitted=0,
)
42 changes: 30 additions & 12 deletions src/modules/submodules/consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from eth_abi import encode
from eth_typing import BlockIdentifier
from hexbytes import HexBytes
from web3.exceptions import ContractCustomError

from src import variables
from src.metrics.prometheus.basic import ORACLE_SLOT_NUMBER, ORACLE_BLOCK_NUMBER, GENESIS_TIME, ACCOUNT_BALANCE
from src.providers.execution.contracts.base_oracle import BaseOracleContract
from src.providers.execution.contracts.hash_consensus import HashConsensusContract
from src.types import BlockStamp, ReferenceBlockStamp, SlotNumber
from src.types import BlockStamp, ReferenceBlockStamp, SlotNumber, FrameNumber
from src.metrics.prometheus.business import (
ORACLE_MEMBER_LAST_REPORT_REF_SLOT,
FRAME_CURRENT_REF_SLOT,
Expand All @@ -30,6 +31,10 @@
logger = logging.getLogger(__name__)


# Initial epoch is in the future. Revert signature: '0xcd0883ea'
InitialEpochIsYetToArriveRevert = Web3.keccak(text="InitialEpochIsYetToArrive()")[:4].hex()


class ConsensusModule(ABC):
"""
Module that works with Hash Consensus Contract.
Expand Down Expand Up @@ -90,9 +95,23 @@ def get_chain_config(self, blockstamp: BlockStamp) -> ChainConfig:
return consensus_contract.get_chain_config(blockstamp.block_hash)

@lru_cache(maxsize=1)
def get_current_frame(self, blockstamp: BlockStamp) -> CurrentFrame:
def get_initial_or_current_frame(self, blockstamp: BlockStamp) -> CurrentFrame:
consensus_contract = self._get_consensus_contract(blockstamp)
return consensus_contract.get_current_frame(blockstamp.block_hash)

try:
return consensus_contract.get_current_frame(blockstamp.block_hash)
except ContractCustomError as revert:
if revert.data != InitialEpochIsYetToArriveRevert:
raise revert

converter = self._get_web3_converter(blockstamp)

# If initial epoch is not yet arrived then current frame is the first frame
# ref_slot is last slot of previous frame
return CurrentFrame(
ref_slot=converter.get_frame_last_slot(FrameNumber(0 - 1)),
report_processing_deadline_slot=converter.get_frame_last_slot(FrameNumber(0)),
)

@lru_cache(maxsize=1)
def get_initial_ref_slot(self, blockstamp: BlockStamp) -> SlotNumber:
Expand All @@ -109,7 +128,7 @@ def get_member_info(self, blockstamp: BlockStamp) -> MemberInfo:
consensus_contract = self._get_consensus_contract(blockstamp)

# Defaults for dry mode
current_frame = self.get_current_frame(blockstamp)
current_frame = self.get_initial_or_current_frame(blockstamp)
frame_config = self.get_frame_config(blockstamp)
is_member = is_submit_member = is_fast_lane = True
last_member_report_ref_slot = SlotNumber(0)
Expand Down Expand Up @@ -196,10 +215,7 @@ def get_blockstamp_for_report(self, last_finalized_blockstamp: BlockStamp) -> Re
logger.info({'msg': 'Deadline missed.'})
return None

chain_config = self.get_chain_config(last_finalized_blockstamp)
frame_config = self.get_frame_config(last_finalized_blockstamp)

converter = Web3Converter(chain_config, frame_config)
converter = self._get_web3_converter(last_finalized_blockstamp)

bs = get_reference_blockstamp(
cc=self.w3.cc,
Expand Down Expand Up @@ -411,10 +427,7 @@ def _get_slot_delay_before_data_submit(self, blockstamp: BlockStamp) -> int:

mem_position = members.index(variables.ACCOUNT.address)

frame_config = self.get_frame_config(blockstamp)
chain_config = self.get_chain_config(blockstamp)

converter = Web3Converter(chain_config, frame_config)
converter = self._get_web3_converter(blockstamp)

current_frame_number = converter.get_frame_by_slot(blockstamp.slot_number)
current_position = current_frame_number % len(members)
Expand All @@ -429,6 +442,11 @@ def _get_slot_delay_before_data_submit(self, blockstamp: BlockStamp) -> int:
logger.info({'msg': 'Calculate slots delay.', 'value': total_delay})
return total_delay

def _get_web3_converter(self, blockstamp: BlockStamp) -> Web3Converter:
chain_config = self.get_chain_config(blockstamp)
frame_config = self.get_frame_config(blockstamp)
return Web3Converter(chain_config, frame_config)

@abstractmethod
@lru_cache(maxsize=1)
def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple:
Expand Down
11 changes: 11 additions & 0 deletions tests/factory/base_oracle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from src.modules.accounting.types import AccountingProcessingState
from src.modules.ejector.types import EjectorProcessingState
from tests.factory.web3_factory import Web3Factory


class AccountingProcessingStateFactory(Web3Factory):
__model__ = AccountingProcessingState


class EjectorProcessingStateFactory(Web3Factory):
__model__ = EjectorProcessingState
1 change: 0 additions & 1 deletion tests/factory/configs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import overload
from src.modules.accounting.types import OracleReportLimits
from src.modules.submodules.types import ChainConfig, FrameConfig
from src.providers.consensus.types import (
Expand Down
32 changes: 30 additions & 2 deletions tests/modules/accounting/test_accounting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@
from unittest.mock import Mock, patch

import pytest
from web3.exceptions import ContractCustomError
from web3.types import Wei

from src import variables
from src.modules.accounting import accounting as accounting_module
from src.modules.accounting.accounting import Accounting
from src.modules.accounting.accounting import logger as accounting_logger
from src.modules.accounting.third_phase.types import FormatList
from src.modules.accounting.types import LidoReportRebase
from src.modules.accounting.types import LidoReportRebase, AccountingProcessingState
from src.modules.submodules.oracle_module import ModuleExecuteDelay
from src.modules.submodules.types import ChainConfig, FrameConfig
from src.modules.submodules.types import ChainConfig, FrameConfig, CurrentFrame, ZERO_HASH
from src.services.withdrawal import Withdrawal
from src.types import BlockStamp, ReferenceBlockStamp
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule
from tests.factory.base_oracle import AccountingProcessingStateFactory
from tests.factory.blockstamp import BlockStampFactory, ReferenceBlockStampFactory
from tests.factory.configs import ChainConfigFactory, FrameConfigFactory
from tests.factory.contract_responses import LidoReportRebaseFactory
from tests.factory.no_registry import LidoValidatorFactory, StakingModuleFactory
from tests.web3_extentions.test_lido_validators import blockstamp


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -473,3 +476,28 @@ def test_is_bunker(
accounting.bunker_service.is_bunker_mode.reset_mock()
accounting._is_bunker(ref_bs)
accounting.bunker_service.is_bunker_mode.assert_not_called()


def test_accounting_get_processing_state_no_yet_init_epoch(accounting: Accounting):
bs = ReferenceBlockStampFactory.build()

accounting.report_contract.get_processing_state = Mock(side_effect=ContractCustomError('0xcd0883ea', '0xcd0883ea'))
accounting.get_initial_or_current_frame = Mock(
return_value=CurrentFrame(ref_slot=100, report_processing_deadline_slot=200)
)
processing_state = accounting._get_processing_state(bs)

assert isinstance(processing_state, AccountingProcessingState)
assert processing_state.current_frame_ref_slot == 100
assert processing_state.processing_deadline_time == 200
assert processing_state.main_data_submitted == False
assert processing_state.main_data_hash == ZERO_HASH


def test_accounting_get_processing_state(accounting: Accounting):
bs = ReferenceBlockStampFactory.build()
accounting_processing_state = AccountingProcessingStateFactory.build()
accounting.report_contract.get_processing_state = Mock(return_value=accounting_processing_state)
result = accounting._get_processing_state(bs)

assert accounting_processing_state == result
2 changes: 1 addition & 1 deletion tests/modules/csm/test_csm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_current_frame_range(module: CSOracle, csm: CSM, mock_chain_config: NoRe
)

csm.get_csm_last_processing_ref_slot = Mock(return_value=param.last_processing_ref_slot)
module.get_current_frame = Mock(
module.get_initial_or_current_frame = Mock(
return_value=CurrentFrame(
ref_slot=SlotNumber(param.current_ref_slot),
report_processing_deadline_slot=SlotNumber(0),
Expand Down
28 changes: 27 additions & 1 deletion tests/modules/ejector/test_ejector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import Mock

import pytest
from web3.exceptions import ContractCustomError

from src import constants
from src.constants import MAX_EFFECTIVE_BALANCE
Expand All @@ -10,12 +11,13 @@
from src.modules.ejector.ejector import logger as ejector_logger
from src.modules.ejector.types import EjectorProcessingState
from src.modules.submodules.oracle_module import ModuleExecuteDelay
from src.modules.submodules.types import ChainConfig
from src.modules.submodules.types import ChainConfig, CurrentFrame
from src.types import BlockStamp, ReferenceBlockStamp
from src.utils import validator_state
from src.web3py.extensions.contracts import LidoContracts
from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModuleId
from src.web3py.types import Web3
from tests.factory.base_oracle import EjectorProcessingStateFactory
from tests.factory.blockstamp import BlockStampFactory, ReferenceBlockStampFactory
from tests.factory.configs import ChainConfigFactory
from tests.factory.no_registry import LidoValidatorFactory
Expand Down Expand Up @@ -426,3 +428,27 @@ def test_get_latest_exit_epoch(ejector: Ejector, blockstamp: BlockStamp) -> None
(max_epoch, count) = ejector._get_latest_exit_epoch(blockstamp)
assert count == 2, "Unexpected count of exiting validators"
assert max_epoch == 42, "Unexpected max epoch"


def test_ejector_get_processing_state_no_yet_init_epoch(ejector: Ejector):
bs = ReferenceBlockStampFactory.build()

ejector.report_contract.get_processing_state = Mock(side_effect=ContractCustomError('0xcd0883ea', '0xcd0883ea'))
ejector.get_initial_or_current_frame = Mock(
return_value=CurrentFrame(ref_slot=100, report_processing_deadline_slot=200)
)
processing_state = ejector._get_processing_state(bs)

assert isinstance(processing_state, EjectorProcessingState)
assert processing_state.current_frame_ref_slot == 100
assert processing_state.processing_deadline_time == 200
assert processing_state.data_submitted == False


def test_ejector_get_processing_state(ejector: Ejector):
bs = ReferenceBlockStampFactory.build()
accounting_processing_state = EjectorProcessingStateFactory.build()
ejector.report_contract.get_processing_state = Mock(return_value=accounting_processing_state)
result = ejector._get_processing_state(bs)

assert accounting_processing_state == result
Loading

0 comments on commit c12a0a2

Please sign in to comment.