Skip to content

Commit

Permalink
patch touch_account in fast mode
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper committed Sep 12, 2023
1 parent 86d7017 commit 11c11c3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
19 changes: 14 additions & 5 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from boa.util.eip1167 import extract_eip1167_address, is_eip1167_contract
from boa.util.lrudict import lrudict
from boa.vm.fork import AccountDBFork
from boa.vm.fast_accountdb import FastAccountDB
from boa.vm.fast_accountdb import FastAccountDB, patch_pyevm_state_object, unpatch_pyevm_state_object
from boa.vm.gas_meters import GasMeter, NoGasMeter, ProfilingGasMeter
from boa.vm.utils import to_bytes, to_int

Expand Down Expand Up @@ -336,7 +336,7 @@ def apply_computation(cls, state, msg, tx_ctx):
addr = msg.code_address
contract = cls.env._lookup_contract_fast(addr) if addr else None
#print("ENTER", Address(msg.code_address or bytes([0]*20)), contract)
if contract is None or not cls.env._enable_fast_mode:
if contract is None or not cls.env._fast_mode_enabled:
#print("SLOW MODE")
return super().apply_computation(state, msg, tx_ctx)

Expand Down Expand Up @@ -382,7 +382,7 @@ class Env:
_singleton = None
_initial_address_counter = 100
_coverage_enabled = False
_enable_fast_mode = False
_fast_mode_enabled = False

def __init__(self):
self.chain = _make_chain()
Expand Down Expand Up @@ -414,6 +414,7 @@ def get_gas_price(self):

def _init_vm(self, reset_traces=True):
self.vm = self.chain.get_vm()

self.vm.patch = VMPatcher(self.vm)

c = type(
Expand All @@ -422,9 +423,10 @@ def _init_vm(self, reset_traces=True):
{"env": self},
)

if self._fast_mode_enabled:
self.vm._state_class.account_db_class = FastAccountDB

self.vm.state.computation_class = c
# TODO: enable this with fast mode
# self.vm.state.account_db_class = FastAccountDB

# we usually want to reset the trace data structures
# but sometimes don't, give caller the option.
Expand All @@ -446,6 +448,13 @@ def _trace_sstore(self, account, slot):
# zero entries.
self.sstore_trace[account].add(slot)

def enable_fast_mode(self, flag: bool = True):
self._fast_mode_enabled = flag
if flag:
patch_pyevm_state_object(self.vm.state)
else:
unpatch_pyevm_state_object(self.vm.state)

def fork(self, url, reset_traces=True, **kwargs):
kwargs["url"] = url
AccountDBFork._rpc_init_kwargs = kwargs
Expand Down
24 changes: 24 additions & 0 deletions boa/vm/fast_accountdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
from eth.db.account import AccountDB

class FastAccountDB(AccountDB):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# this is a hotspot in super().
def touch_account(self, address):
self._accessed_accounts.add(address)


def _touch_account_patcher(self, address):
self._accessed_accounts.add(address)

_BOA_PATCHED = object()

def patch_pyevm_state_object(state_object):
if getattr(state_object, "__boa_patched__", None) == _BOA_PATCHED:
return
accountdb = state_object._account_db
accountdb._restore_touch_account = accountdb.touch_account
accountdb.touch_account = _touch_account_patcher.__get__(accountdb, AccountDB)
state_object.__boa_patched__ = True

def unpatch_pyevm_state_object(state_object):
if not getattr(state_object, "__boa_patched__", None) == _BOA_PATCHED:
return
accountdb = state_object._account_db
accountdb.touch_account = accountdb._restore_touch_account
delattr(accountdb, "_restore_touch_account")
delattr(state_object, "__boa_patched__")

0 comments on commit 11c11c3

Please sign in to comment.