Skip to content

Commit

Permalink
avoid O(n) search in lookup_selector
Browse files Browse the repository at this point in the history
by making ContractMappingInfo.nodes a dict (indexed by selector) instead of a list
  • Loading branch information
karmacoma-eth committed Aug 1, 2024
1 parent 3be86f6 commit 5328902
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 49 deletions.
43 changes: 26 additions & 17 deletions src/halmos/mapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type

SELECTOR_FIELDS = {
Expand Down Expand Up @@ -31,8 +31,19 @@ def from_dict(node: Dict) -> Optional["AstNode"]:
@dataclass
class ContractMappingInfo:
contract_name: str
bytecode: str
nodes: List[AstNode]
bytecode: str | None = None

# indexed by selector
nodes: Dict[str, AstNode] = field(default_factory=dict)

def with_nodes(self, nodes: List[AstNode]) -> "ContractMappingInfo":
for node in nodes:
self.add_node(node)
return self

def add_node(self, node: AstNode) -> None:
# don't overwrite if a node with the same selector already exists
self.nodes.setdefault(node.selector, node)


@dataclass
Expand Down Expand Up @@ -81,16 +92,16 @@ class Mapper(metaclass=SingletonMeta):
def __init__(self):
self._contracts: Dict[str, ContractMappingInfo] = {}

def add(self, contract_name: str, bytecode: str, nodes: List[AstNode]):
def add_mapping(self, mapping: ContractMappingInfo) -> None:
contract_name = mapping.contract_name
if contract_name in self._contracts:
raise ValueError(f"Contract {contract_name} already exists")

value = ContractMappingInfo(contract_name, bytecode, nodes)
self._contracts[contract_name] = value
self._contracts[contract_name] = mapping

def get_or_create(self, contract_name: str) -> ContractMappingInfo:
if contract_name not in self._contracts:
self.add(contract_name, "", [])
self.add_mapping(ContractMappingInfo(contract_name))

return self._contracts[contract_name]

Expand All @@ -113,9 +124,9 @@ def get_by_bytecode(self, bytecode: str) -> Optional[ContractMappingInfo]:

return None

def append_node(self, contract_name: str, node: AstNode):
def add_node(self, contract_name: str | None, node: AstNode):
contract_mapping_info = self.get_or_create(contract_name)
contract_mapping_info.nodes.append(node)
contract_mapping_info.add_node(node)

def parse_ast(self, node: Dict, explain=False):
# top-level public API meant to be called externally, passing the full AST
Expand Down Expand Up @@ -150,7 +161,7 @@ def _parse_ast(
if ast_node and ast_node.selector != "0x":
print(f"adding {ast_node}")

self.append_node(contract_name, ast_node)
self.add_node(contract_name, ast_node)
expl.add(f" (added node with {ast_node.selector=}")

# go one level deeper
Expand All @@ -168,16 +179,14 @@ def lookup_selector(self, selector: str, contract_name: str | None = None) -> st
if contract_name:
contract_mapping_info = self.get_by_name(contract_name)
if contract_mapping_info:
for node in contract_mapping_info.nodes:
if node.selector == selector:
return node.name
if node := contract_mapping_info.nodes.get(selector, None):
return node.name

# otherwise, search for the selector in all known contracts and return the name of the first match.
# note: ambiguity may occur if multiple compilation units exist.
for contract_info in self._contracts.values():
for node in contract_info.nodes:
if node.selector == selector:
return node.name
for contract_mapping_info in self._contracts.values():
if node := contract_mapping_info.nodes.get(selector, None):
return node.name

return selector

Expand Down
77 changes: 45 additions & 32 deletions tests/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import pytest

from halmos.mapper import AstNode, Mapper, SingletonMeta
from halmos.mapper import AstNode, ContractMappingInfo, Mapper, SingletonMeta


@pytest.fixture
Expand Down Expand Up @@ -37,6 +37,11 @@ def ast_nodes() -> List[AstNode]:
]


@pytest.fixture
def mapping(ast_nodes) -> ContractMappingInfo:
return ContractMappingInfo("ContractA", "bytecodeA").with_nodes(ast_nodes)


@pytest.fixture
def mapper() -> Mapper:
return Mapper()
Expand All @@ -53,24 +58,23 @@ def test_singleton():
assert mapper1 is mapper2


def test_add(mapper, ast_nodes):
mapper.add("ContractA", "bytecodeA", ast_nodes)
def test_add_mapping(mapper, mapping):
mapper.add_mapping(mapping)
contract_info = mapper.get_by_name("ContractA")
assert contract_info is not None
assert contract_info.contract_name == "ContractA"
assert contract_info.bytecode == "bytecodeA"
assert len(contract_info.nodes) == 2


def test_add_already_existence(mapper, ast_nodes):
mapper.add("ContractA", "bytecodeA", ast_nodes)

def test_add_mapping_already_exists(mapper, mapping):
mapper.add_mapping(mapping)
with pytest.raises(ValueError, match=r"Contract ContractA already exists"):
mapper.add("ContractA", "bytecodeA", ast_nodes)
mapper.add_mapping(mapping)


def test_get_by_name(mapper, ast_nodes):
mapper.add("ContractA", "bytecodeA", ast_nodes)
def test_get_by_name(mapper):
mapper.get_or_create("ContractA")
contract_info = mapper.get_by_name("ContractA")
assert contract_info is not None
assert contract_info.contract_name == "ContractA"
Expand All @@ -81,8 +85,8 @@ def test_get_by_name_nonexistent(mapper):
assert contract_info is None


def test_get_by_bytecode(mapper, ast_nodes):
mapper.add("ContractA", "bytecodeA", ast_nodes)
def test_get_by_bytecode(mapper, mapping):
mapper.add_mapping(mapping)
contract_info = mapper.get_by_bytecode("bytecodeA")
assert contract_info is not None
assert contract_info.bytecode == "bytecodeA"
Expand All @@ -93,21 +97,25 @@ def test_get_by_bytecode_nonexistent(mapper):
assert contract_info is None


def test_append_node(mapper, ast_nodes):
mapper.add("ContractA", "bytecodeA", ast_nodes)
def test_add_node(mapper, mapping):
mapper.add_mapping(mapping)

# when we add a new node to a contract scope
new_node = AstNode(node_type="type3", name="Node3", selector="0x789")
mapper.append_node("ContractA", new_node)
contract_info = mapper.get_by_name("ContractA")
mapper.add_node(mapping.contract_name, new_node)

# then we can retrieve it from the contract scope
contract_info = mapper.get_by_name(mapping.contract_name)
assert contract_info is not None
assert len(contract_info.nodes) == 3
assert contract_info.nodes[-1].name == "Node3"
assert contract_info.nodes[new_node.selector].name == "Node3"


def test_append_node_to_never_seen_before_contract(mapper):
def test_add_node_to_never_seen_before_contract(mapper):
new_node = AstNode(node_type="type3", name="Node3", selector="0x789")

mapper.append_node("NeverSeenBefore", new_node)
assert mapper.get_by_name("NeverSeenBefore").nodes == [new_node]
mapper.add_node("NeverSeenBefore", new_node)
assert mapper.get_by_name("NeverSeenBefore").nodes == {new_node.selector: new_node}


def test_lookup_selector(mapper, ast_nodes):
Expand All @@ -117,7 +125,7 @@ def test_lookup_selector(mapper, ast_nodes):
# when we add a new node to a contract scope
node1 = ast_nodes[0]
selector = node1.selector
mapper.append_node("ContractA", node1)
mapper.add_node("ContractA", node1)

# then we can look up the selector by specifying the contract name
assert mapper.lookup_selector(selector, contract_name="ContractA") == node1.name
Expand All @@ -130,7 +138,7 @@ def test_lookup_selector(mapper, ast_nodes):
node_type=node1.node_type, name="ConflictingNode", selector=selector
)

mapper.append_node("ContractB", node2)
mapper.add_node("ContractB", node2)

# then we can look up the selector by specifying the contract scope
assert mapper.lookup_selector(selector, contract_name="ContractA") == node1.name
Expand All @@ -144,7 +152,7 @@ def test_lookup_selector_unscoped(mapper, ast_nodes):
# when we add a new node with no contract scope (e.g. global errors or events)
node1 = ast_nodes[0]
selector = node1.selector
mapper.append_node(None, node1)
mapper.add_node(None, node1)

# then we can look up the selector even if we specify a contract scope
assert mapper.lookup_selector(selector, contract_name="ContractA") == node1.name
Expand Down Expand Up @@ -176,7 +184,7 @@ def test_parse_simple_ast(mapper):
assert contract_info is not None
assert contract_info.contract_name == "ExampleContract"
assert len(contract_info.nodes) == 1
assert contract_info.nodes[0].name == "exampleFunction"
assert next(iter(contract_info.nodes.values())).name == "exampleFunction"


def test_parse_complex_ast(mapper):
Expand Down Expand Up @@ -231,7 +239,7 @@ def test_parse_complex_ast(mapper):

assert len(contract_info.nodes) == 3

node_names = [node.name for node in contract_info.nodes]
node_names = [node.name for node in contract_info.nodes.values()]
assert "var1" not in node_names # var1 is not added, it has no selector
assert "func1" in node_names
assert "event1" in node_names
Expand All @@ -253,13 +261,16 @@ def test_parse_multicontract_ast(read_json_file, mapper):
# but the free function is not added to the mapper (it has no selector)
not_in_contract = mapper.get_by_name(None)
assert len(not_in_contract.nodes) == 2
assert not_in_contract.nodes[0].name == "Log"
nodes = iter(not_in_contract.nodes.values())
first_node = next(nodes)
assert first_node.name == "Log"
assert (
not_in_contract.nodes[0].selector
first_node.selector
== "0x909c57d5c6ac08245cf2a6de3900e2b868513fa59099b92b27d8db823d92df9c"
)
assert not_in_contract.nodes[1].name == "Unauthorized"
assert not_in_contract.nodes[1].selector == "0x82b42900"
second_node = next(nodes)
assert second_node.name == "Unauthorized"
assert second_node.selector == "0x82b42900"
assert mapper.lookup_selector("0x82b42900") == "Unauthorized"

# there are 2 contracts in the AST
Expand All @@ -279,8 +290,9 @@ def test_parse_multicontract_ast(read_json_file, mapper):
contract_TestA = mapper.get_by_name("TestA")
assert contract_TestA is not None
assert len(contract_TestA.nodes) == 1
assert contract_TestA.nodes[0].name == "test_foo"
assert contract_TestA.nodes[0].selector == "0xdc24e7f1"
node = next(iter(contract_TestA.nodes.values()))
assert node.name == "test_foo"
assert node.selector == "0xdc24e7f1"
assert mapper.lookup_selector("0xdc24e7f1", contract_name="TestA") == "test_foo"

# contract C {
Expand All @@ -289,5 +301,6 @@ def test_parse_multicontract_ast(read_json_file, mapper):
contract_C = mapper.get_by_name("C")
assert contract_C is not None
assert len(contract_C.nodes) == 1
assert contract_C.nodes[0].name == "foo"
assert contract_C.nodes[0].selector == "0xc2985578"
node = next(iter(contract_C.nodes.values()))
assert node.name == "foo"
assert node.selector == "0xc2985578"

0 comments on commit 5328902

Please sign in to comment.