diff --git a/boa/contracts/vyper/ast_utils.py b/boa/contracts/vyper/ast_utils.py index 51963306..1a4b09f0 100644 --- a/boa/contracts/vyper/ast_utils.py +++ b/boa/contracts/vyper/ast_utils.py @@ -1,7 +1,7 @@ import io import re import tokenize -from typing import Any, Optional, Tuple +from typing import Any, Optional import vyper.ast as vy_ast from vyper.codegen.core import getpos @@ -36,7 +36,7 @@ def _extract_reason(comment: str) -> Any: # somewhat heuristic. def reason_at( source_code: str, lineno: int, end_lineno: int -) -> Optional[Tuple[str, str]]: +) -> Optional[tuple[str, str]]: block = get_block(source_code, lineno, end_lineno) c = _get_comment(block) if c is not None: diff --git a/boa/contracts/vyper/event.py b/boa/contracts/vyper/event.py index fa852d30..43f142fe 100644 --- a/boa/contracts/vyper/event.py +++ b/boa/contracts/vyper/event.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, List +from typing import Any @dataclass @@ -7,8 +7,8 @@ class Event: log_id: int # internal py-evm log id, for ordering purposes address: str # checksum address event_type: Any # vyper.semantics.types.user.EventT - topics: List[Any] # list of decoded topics - args: List[Any] # list of decoded args + topics: list[Any] # list of decoded topics + args: list[Any] # list of decoded args def __repr__(self): t_i = 0 diff --git a/boa/contracts/vyper/ir_executor.py b/boa/contracts/vyper/ir_executor.py index 30f1657d..1f21e6d6 100644 --- a/boa/contracts/vyper/ir_executor.py +++ b/boa/contracts/vyper/ir_executor.py @@ -365,8 +365,8 @@ def from_mnemonic(cls, mnemonic): class OpcodeIRExecutor(IRExecutor): _type: type = StackItem # type: ignore - def __init__(self, name, opcode_info, *args): - self.opcode_info: OpcodeInfo = opcode_info + def __init__(self, name: str, opcode_info: OpcodeInfo, *args): + self.opcode_info = opcode_info # to differentiate from implemented codes self._name = "__" + name + "__" diff --git a/boa/environment.py b/boa/environment.py index 08a5881b..b1822884 100644 --- a/boa/environment.py +++ b/boa/environment.py @@ -6,12 +6,12 @@ import contextlib import random -from typing import Any, Optional, Tuple, TypeAlias +from typing import Any, Optional, TypeAlias import eth.constants as constants from eth_typing import Address as PYEVM_Address # it's just bytes. -from boa.rpc import EthereumRPC +from boa.rpc import RPC, EthereumRPC from boa.util.abi import Address from boa.vm.gas_meters import GasMeter, NoGasMeter, ProfilingGasMeter from boa.vm.py_evm import PyEVM @@ -58,10 +58,10 @@ def enable_fast_mode(self, flag: bool = True): self._fast_mode_enabled = flag self.evm.enable_fast_mode(flag) - def fork(self, url=None, reset_traces=True, block_identifier="safe", **kwargs): + def fork(self, url: str, reset_traces=True, block_identifier="safe", **kwargs): return self.fork_rpc(EthereumRPC(url), reset_traces, block_identifier, **kwargs) - def fork_rpc(self, rpc=None, reset_traces=True, block_identifier="safe", **kwargs): + def fork_rpc(self, rpc: RPC, reset_traces=True, block_identifier="safe", **kwargs): """ Fork the environment to a local chain. :param rpc: RPC to fork from @@ -207,7 +207,7 @@ def deploy_code( start_pc: int = 0, # TODO: This isn't used # override the target address: override_address: Optional[_AddressType] = None, - ) -> Tuple[Address, bytes]: + ) -> tuple[Address, bytes]: sender = self._get_sender(sender) target_address = ( diff --git a/boa/test/strategies.py b/boa/test/strategies.py index 87f3a309..c075663c 100644 --- a/boa/test/strategies.py +++ b/boa/test/strategies.py @@ -1,6 +1,6 @@ import random import string -from typing import Any, Callable, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Iterable, Optional, Union from eth_abi.grammar import BasicType, TupleType, parse from eth_utils import to_checksum_address @@ -23,7 +23,7 @@ # note: there are also utils in the vyper codebase we could use for # this. also, in the future we may want to replace these with strategies # that use vyper types instead of abi types. -def get_int_bounds(type_str: str) -> Tuple[int, int]: +def get_int_bounds(type_str: str) -> tuple[int, int]: """Returns the lower and upper bound for an integer type.""" size = int(type_str.strip("uint") or 256) if size < 8 or size > 256 or size % 8: @@ -43,7 +43,7 @@ def __repr__(self): def _exclude_filter(fn: Callable) -> Callable: - def wrapper(*args: Tuple, exclude: Any = None, **kwargs: int) -> SearchStrategy: + def wrapper(*args: tuple, exclude: Any = None, **kwargs: int) -> SearchStrategy: strat = fn(*args, **kwargs) if exclude is None: return strat @@ -62,7 +62,7 @@ def wrapper(*args: Tuple, exclude: Any = None, **kwargs: int) -> SearchStrategy: def _check_numeric_bounds( type_str: str, min_value: NumberType, max_value: NumberType -) -> Tuple: +) -> tuple[NumberType, NumberType]: lower, upper = get_int_bounds(type_str) min_final = lower if min_value is None else min_value max_final = upper if max_value is None else max_value @@ -75,8 +75,8 @@ def _check_numeric_bounds( def _integer_strategy( type_str: str, min_value: Optional[int] = None, max_value: Optional[int] = None ) -> SearchStrategy: - min_value, max_value = _check_numeric_bounds(type_str, min_value, max_value) - return st.integers(min_value=min_value, max_value=max_value) + min_val, max_val = _check_numeric_bounds(type_str, min_value, max_value) + return st.integers(min_val, max_val) @_exclude_filter diff --git a/boa/vm/fork.py b/boa/vm/fork.py index 088779b2..5832668e 100644 --- a/boa/vm/fork.py +++ b/boa/vm/fork.py @@ -1,5 +1,5 @@ import os -from typing import Any +from typing import Any, Type from requests import HTTPError @@ -120,16 +120,21 @@ def fetch_multi(self, payload): # AccountDB which dispatches to an RPC when we don't have the # data locally class AccountDBFork(AccountDB): - _rpc: RPC = None # type: ignore - _rpc_init_kwargs: dict[str, Any] = {} - - def __init__(self, *args, **kwargs): + @classmethod + def class_from_rpc( + cls, rpc: RPC, block_identifier: str, **kwargs + ) -> Type["AccountDBFork"]: + class _ConfiguredAccountDB(AccountDBFork): + def __init__(self, *args, **kwargs2): + caching_rpc = CachingRPC(rpc, **kwargs) + super().__init__(caching_rpc, block_identifier, *args, **kwargs2) + + return _ConfiguredAccountDB + + def __init__(self, rpc: CachingRPC, block_identifier: str, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - rpc_kwargs = self._rpc_init_kwargs.copy() - - block_identifier = rpc_kwargs.pop("block_identifier", "safe") - self._rpc: CachingRPC = CachingRPC(self._rpc, **rpc_kwargs) + self._rpc = rpc if block_identifier not in _PREDEFINED_BLOCKS: block_identifier = to_hex(block_identifier) diff --git a/boa/vm/py_evm.py b/boa/vm/py_evm.py index a7b770de..409a4beb 100644 --- a/boa/vm/py_evm.py +++ b/boa/vm/py_evm.py @@ -375,7 +375,7 @@ def _init_vm( fast_mode_enabled: bool, ): self.vm = self.chain.get_vm() - self._set_account_db_class(account_db_class) + self.vm.__class__._state_class.account_db_class = account_db_class self.vm.patch = VMPatcher(self.vm) @@ -416,10 +416,16 @@ def enable_fast_mode(self, flag: bool = True): else: unpatch_pyevm_state_object(self.vm.state) - def fork_rpc(self, rpc: RPC, reset_traces: bool, fast_mode_enabled: bool, **kwargs): - AccountDBFork._rpc = rpc - AccountDBFork._rpc_init_kwargs = kwargs - self._init_vm(self.env, AccountDBFork, reset_traces, fast_mode_enabled) + def fork_rpc( + self, + rpc: RPC, + reset_traces: bool, + fast_mode_enabled: bool, + block_identifier: str, + **kwargs, + ): + account_db_class = AccountDBFork.class_from_rpc(rpc, block_identifier, **kwargs) + self._init_vm(self.env, account_db_class, reset_traces, fast_mode_enabled) block_info = self.vm.state._account_db._block_info self.vm.patch.timestamp = int(block_info["timestamp"], 16)