Skip to content

Commit

Permalink
fix: simplify symbol parsing (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
karmacoma-eth authored Aug 1, 2024
1 parent 31bb794 commit 954ed59
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 153 deletions.
49 changes: 21 additions & 28 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def rendered_trace(context: CallContext) -> str:
return output.getvalue()


def rendered_calldata(calldata: ByteVec, contract_name: str = None) -> str:
def rendered_calldata(calldata: ByteVec, contract_name: str | None = None) -> str:
return hexify(calldata.unwrap(), contract_name) if calldata else "0x"


Expand All @@ -301,11 +301,7 @@ def render_trace(context: CallContext, file=sys.stdout) -> None:
if context.output.error is None:
target = hex(int(str(message.target)))
bytecode = context.output.data.unwrap().hex()
contract_name = (
Mapper()
.get_contract_mapping_info_by_bytecode(bytecode)
.contract_name
)
contract_name = Mapper().get_by_bytecode(bytecode).contract_name

DeployAddressMapper().add_deployed_contract(target, contract_name)
addr_str = contract_name
Expand Down Expand Up @@ -1369,29 +1365,8 @@ def parse_build_out(args: HalmosConfig) -> Dict:
)

contract_map[contract_name] = (json_out, contract_type, natspec)
parse_symbols(args, contract_map, contract_name)

try:
bytecode = contract_map[contract_name][0]["bytecode"]["object"]
contract_mapping_info = Mapper().get_contract_mapping_info_by_name(
contract_name
)

if contract_mapping_info is None:
Mapper().add_contract_mapping_info(
contract_name=contract_name,
bytecode=bytecode,
nodes=[],
)
else:
contract_mapping_info.bytecode = bytecode

contract_mapping_info = Mapper().get_contract_mapping_info_by_name(
contract_name
)
Mapper().parse_ast(contract_map[contract_name][0]["ast"])

except Exception:
pass
except Exception as err:
warn_code(
PARSING_ERROR,
Expand All @@ -1404,6 +1379,24 @@ def parse_build_out(args: HalmosConfig) -> Dict:
return result


def parse_symbols(args: HalmosConfig, contract_map: Dict, contract_name: str) -> None:
try:
json_out = contract_map[contract_name][0]
bytecode = json_out["bytecode"]["object"]
contract_mapping_info = Mapper().get_or_create(contract_name)
contract_mapping_info.bytecode = bytecode

Mapper().parse_ast(json_out["ast"])

except Exception:
if args.debug:
debug(f"error parsing symbols for contract {contract_name}")
debug(traceback.format_exc())
else:
# we parse symbols as best effort, don't propagate exceptions
pass


def parse_devdoc(funsig: str, contract_json: Dict) -> str:
try:
return contract_json["metadata"]["output"]["devdoc"]["methods"][funsig][
Expand Down
218 changes: 133 additions & 85 deletions src/halmos/mapper.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,69 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type

SELECTOR_FIELDS = {
"VariableDeclaration": "functionSelector",
"FunctionDefinition": "functionSelector",
"EventDefinition": "eventSelector",
"ErrorDefinition": "errorSelector",
}


@dataclass
class AstNode:
node_type: str
id: int
name: str
address: str # TODO: rename it to `selector` or `signature` to better reflect the meaning
visibility: str
selector: str

@staticmethod
def from_dict(node: Dict) -> Optional["AstNode"]:
node_type = node["nodeType"]
selector_field = SELECTOR_FIELDS.get(node_type, None)
if selector_field is None:
return None

selector = "0x" + node.get(selector_field, "")
return AstNode(
node_type=node_type, name=node.get("name", ""), selector=selector
)


@dataclass
class ContractMappingInfo:
contract_name: str
bytecode: str
nodes: List[AstNode]
bytecode: str | None = None

# indexed by selector
nodes: Dict[str, AstNode] = field(default_factory=dict)

def with_nodes(self, nodes: List[AstNode]) -> "ContractMappingInfo":
for node in nodes:
self.add_node(node)
return self

def add_node(self, node: AstNode) -> None:
# don't overwrite if a node with the same selector already exists
self.nodes.setdefault(node.selector, node)


@dataclass
class Explanation:
enabled: bool = False
content: str = ""

def add(self, text: str):
if self.enabled:
self.content += text

def print(self):
if self.enabled:
print(self.content)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.print()


class SingletonMeta(type):
Expand Down Expand Up @@ -44,24 +92,23 @@ class Mapper(metaclass=SingletonMeta):
def __init__(self):
self._contracts: Dict[str, ContractMappingInfo] = {}

def add_contract_mapping_info(
self, contract_name: str, bytecode: str, nodes: List[AstNode]
):
def add_mapping(self, mapping: ContractMappingInfo) -> None:
contract_name = mapping.contract_name
if contract_name in self._contracts:
raise ValueError(f"Contract {contract_name} already exists")

self._contracts[contract_name] = ContractMappingInfo(
contract_name, bytecode, nodes
)
self._contracts[contract_name] = mapping

def get_or_create(self, contract_name: str) -> ContractMappingInfo:
if contract_name not in self._contracts:
self.add_mapping(ContractMappingInfo(contract_name))

def get_contract_mapping_info_by_name(
self, contract_name: str
) -> Optional[ContractMappingInfo]:
return self._contracts[contract_name]

def get_by_name(self, contract_name: str) -> Optional[ContractMappingInfo]:
return self._contracts.get(contract_name, None)

def get_contract_mapping_info_by_bytecode(
self, bytecode: str
) -> Optional[ContractMappingInfo]:
def get_by_bytecode(self, bytecode: str) -> Optional[ContractMappingInfo]:
# TODO: Handle cases for contracts with immutable variables
# Current implementation might not work correctly if the following code is added the test solidity file
#
Expand All @@ -77,93 +124,69 @@ def get_contract_mapping_info_by_bytecode(

return None

def append_node(self, contract_name: str, node: AstNode):
contract_mapping_info = self.get_contract_mapping_info_by_name(contract_name)
def add_node(self, contract_name: str | None, node: AstNode):
contract_mapping_info = self.get_or_create(contract_name)
contract_mapping_info.add_node(node)

if contract_mapping_info is None:
raise ValueError(f"Contract {contract_name} not found")
def parse_ast(self, node: Dict, explain=False):
# top-level public API meant to be called externally, passing the full AST
self._parse_ast(node, contract_name=None, explain=explain, _depth=0)

contract_mapping_info.nodes.append(node)
### internal methods

def parse_ast(self, node: Dict, contract_name: str = ""):
def _parse_ast(
self, node: Dict, contract_name: str | None = None, explain=False, _depth=0
):
node_type = node["nodeType"]
node_name = node.get("name", None)
node_name_str = f": {node_name}" if node_name else ""

if node_type in self._PARSING_IGNORED_NODE_TYPES:
return
with Explanation(enabled=explain) as expl:
expl.add(f"{' ' * _depth}{node_type}{node_name_str}")

current_contract = self._get_current_contract(node, contract_name)
if node_type in self._PARSING_IGNORED_NODE_TYPES:
expl.add(" (ignored node type)")
return

if node_type == "ContractDefinition":
if current_contract not in self._contracts:
self.add_contract_mapping_info(
contract_name=current_contract, bytecode="", nodes=[]
)
if node_type == "ContractDefinition":
if contract_name is not None:
raise ValueError(f"parsing {contract_name} but found {node}")

if self.get_contract_mapping_info_by_name(current_contract).nodes:
return
elif node_type != "SourceUnit":
id, name, address, visibility = self._get_node_info(node, node_type)
contract_name = node["name"]
if self.get_or_create(contract_name).nodes:
expl.add(" (skipped, already parsed)")
return

self.append_node(
current_contract,
AstNode(node_type, id, name, address, visibility),
)
ast_node = AstNode.from_dict(node)
if ast_node and ast_node.selector != "0x":
self.add_node(contract_name, ast_node)
expl.add(f" (added node with {ast_node.selector=}")

# go one level deeper
for child_node in node.get("nodes", []):
self.parse_ast(child_node, current_contract)
self._parse_ast(child_node, contract_name, explain, _depth + 1)

if "body" in node:
self.parse_ast(node["body"], current_contract)
if body := node.get("body", None):
self._parse_ast(body, contract_name, explain, _depth + 1)

def _get_node_info(self, node: Dict, node_type: str) -> Dict:
return (
node.get("id", ""),
node.get("name", ""),
"0x" + self._get_node_address(node, node_type),
node.get("visibility", ""),
)
def lookup_selector(self, selector: str, contract_name: str | None = None) -> str:
if selector == "0x":
return selector

def _get_node_address(self, node: Dict, node_type: str) -> str:
address_fields = {
"VariableDeclaration": "functionSelector",
"FunctionDefinition": "functionSelector",
"EventDefinition": "eventSelector",
"ErrorDefinition": "errorSelector",
}

return node.get(address_fields.get(node_type, ""), "")

def _get_current_contract(self, node: Dict, contract_name: str) -> str:
return (
node.get("name", "")
if node["nodeType"] == "ContractDefinition"
else contract_name
)

def find_nodes_by_address(self, address: str, contract_name: str = None):
# if the given signature is declared in the given contract, return its name.
if contract_name:
contract_mapping_info = self.get_contract_mapping_info_by_name(
contract_name
)

contract_mapping_info = self.get_by_name(contract_name)
if contract_mapping_info:
for node in contract_mapping_info.nodes:
if node.address == address:
return node.name
if node := contract_mapping_info.nodes.get(selector, None):
return node.name

# otherwise, search for the signature in other contracts, and return all the contracts that declare it.
# otherwise, search for the signature in other contracts and return the first match.
# note: ambiguity may occur if multiple compilation units exist.
result = ""
for key, contract_info in self._contracts.items():
matching_nodes = [
node for node in contract_info.nodes if node.address == address
]

for node in matching_nodes:
result += f"{key}.{node.name} "
for contract_mapping_info in self._contracts.values():
if node := contract_mapping_info.nodes.get(selector, None):
return node.name

return result.strip() if result != "" and address != "0x" else address
return selector


# TODO: create a new instance or reset for each test
Expand Down Expand Up @@ -192,3 +215,28 @@ def add_deployed_contract(

def get_deployed_contract(self, address: str) -> Optional[str]:
return self._deployed_contracts.get(address, address)


def main():
import sys
import json
from .utils import cyan

def read_json_file(file_path: str) -> Dict:
with open(file_path) as f:
return json.load(f)

mapper = Mapper()
json_out = read_json_file(sys.argv[1])
mapper.parse_ast(json_out["ast"], explain=True)

print(cyan("\n### Results ###\n"))
for contract_name in mapper._contracts.keys():
print(f"Contract: {contract_name}")
ast_nodes = mapper.get_by_name(contract_name).nodes
for selector, node in ast_nodes.items():
print(f" {selector}: {node.name}")


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions src/halmos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,13 @@ def hexify(x, contract_name: str = None):
elif isinstance(x, int):
return f"0x{x:02x}"
elif isinstance(x, bytes):
return Mapper().find_nodes_by_address("0x" + x.hex(), contract_name)
return Mapper().lookup_selector("0x" + x.hex(), contract_name)
elif hasattr(x, "unwrap"):
return hexify(x.unwrap(), contract_name)
elif is_bv_value(x):
# maintain the byte size of x
num_bytes = byte_length(x, strict=False)
return Mapper().find_nodes_by_address(
return Mapper().lookup_selector(
f"0x{x.as_long():0{num_bytes * 2}x}", contract_name
)
elif is_app(x):
Expand Down
Loading

0 comments on commit 954ed59

Please sign in to comment.