diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml index b8aa05a5..9c5db85d 100644 --- a/.github/workflows/test-external.yml +++ b/.github/workflows/test-external.yml @@ -5,9 +5,13 @@ on: branches: [main] paths: - .github/workflows/test-external.yml -# pull_request: -# branches: [main] workflow_dispatch: + inputs: + halmos-options: + description: "additional halmos options" + required: false + type: string + default: "" jobs: test: @@ -16,27 +20,38 @@ jobs: strategy: fail-fast: false matrix: - include: + cache-solver: ["", "--cache-solver"] + project: - repo: "morpho-org/morpho-data-structures" dir: "morpho-data-structures" cmd: "halmos --function testProve --loop 4 --symbolic-storage" branch: "" + profile: "" - repo: "a16z/cicada" dir: "cicada" cmd: "halmos --contract LibUint1024Test --function testProve --loop 256" branch: "" + profile: "" - repo: "a16z/cicada" dir: "cicada" cmd: "halmos --contract LibPrimeTest --function testProve --loop 256" branch: "" + profile: "" - repo: "farcasterxyz/contracts" dir: "farcaster-contracts" cmd: "halmos" branch: "" + profile: "" - repo: "zobront/halmos-solady" dir: "halmos-solady" cmd: "halmos --function testCheck" branch: "" + profile: "" + - repo: "pcaversaccio/snekmate" + dir: "snekmate" + cmd: "halmos --config test/halmos.toml" + branch: "" + profile: "halmos" steps: - name: Checkout @@ -49,11 +64,40 @@ jobs: - name: Checkout external repo uses: actions/checkout@v4 with: - repository: ${{ matrix.repo }} - path: ${{ matrix.dir }} - ref: ${{ matrix.branch }} + repository: ${{ matrix.project.repo }} + path: ${{ matrix.project.dir }} + ref: ${{ matrix.project.branch }} submodules: recursive + - name: Install Foundry + uses: foundry-rs/foundry-toolchain@v1 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies + run: python -m pip install --upgrade pip + + - name: Install Halmos + run: python -m pip install -e ./halmos + + - name: Install Vyper + if: ${{ matrix.project.dir == 'snekmate' }} + run: python -m pip install vyper + + - name: Install Yices 2 SMT solver + run: | + wget https://github.com/SRI-CSL/yices2/releases/download/Yices-2.6.4/yices-2.6.4-x86_64-pc-linux-gnu.tar.gz + tar -xzvf yices-2.6.4-x86_64-pc-linux-gnu.tar.gz + sudo mv yices-2.6.4/bin/* /usr/local/bin/ + sudo mv yices-2.6.4/lib/* /usr/local/lib/ + sudo mv yices-2.6.4/include/* /usr/local/include/ + rm -rf yices-2.6.4 + - name: Test external repo - run: ${{ matrix.cmd }} --statistics --debug --solver-timeout-assertion 0 --solver-threads 4 --solver-command=yices-smt2 - working-directory: ${{ matrix.dir }} + run: ${{ matrix.project.cmd }} --statistics --solver-timeout-assertion 0 --solver-threads 4 --solver-command yices-smt2 ${{ matrix.cache-solver }} ${{ inputs.halmos-options }} + working-directory: ${{ matrix.project.dir }} + env: + FOUNDRY_PROFILE: ${{ matrix.project.profile }} diff --git a/.github/workflows/test-long.yml b/.github/workflows/test-long.yml index fc1ad249..faa0e826 100644 --- a/.github/workflows/test-long.yml +++ b/.github/workflows/test-long.yml @@ -6,6 +6,12 @@ on: paths: - .github/workflows/test-long.yml workflow_dispatch: + inputs: + halmos-options: + description: "additional halmos options" + required: false + type: string + default: "" jobs: test: @@ -14,11 +20,12 @@ jobs: strategy: fail-fast: false matrix: - include: - - testname: "tests/solver" - - testname: "examples/simple" - - testname: "examples/tokens/ERC20" - - testname: "examples/tokens/ERC721" + cache-solver: ["", "--cache-solver"] + testname: + - "tests/solver" + - "examples/simple" + - "examples/tokens/ERC20" + - "examples/tokens/ERC721" steps: - name: Login to GitHub Container Registry @@ -39,4 +46,4 @@ jobs: - name: Run pytest run: | - docker run -v .:/workspace --entrypoint pytest halmos -x -v tests/test_halmos.py -k ${{ matrix.testname }} --halmos-options='-v -st --solver-timeout-assertion 0 --solver-threads 6 --solver-command=yices-smt2' -s --log-cli-level= + docker run -v .:/workspace --entrypoint pytest halmos -x -v tests/test_halmos.py -k ${{ matrix.testname }} --halmos-options='-st --solver-timeout-assertion 0 --solver-threads 4 --solver-command yices-smt2 ${{ matrix.cache-solver }} ${{ inputs.halmos-options }}' -s --log-cli-level= diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 1b6c0dd1..259c2357 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -8,6 +8,17 @@ on: pull_request: branches: [main] workflow_dispatch: + inputs: + halmos-options: + description: "additional halmos options" + required: false + type: string + default: "" + pytest-options: + description: "additional pytest options" + required: false + type: string + default: "" jobs: test: @@ -44,4 +55,4 @@ jobs: run: python -m pip install -e . - name: Run pytest - run: pytest -n 4 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="-v -st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0" + run: pytest -n 4 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="-st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }} diff --git a/examples/simple/foundry.toml b/examples/simple/foundry.toml index e32f5716..f9e9259f 100644 --- a/examples/simple/foundry.toml +++ b/examples/simple/foundry.toml @@ -3,7 +3,8 @@ src = 'src' out = 'out' libs = ["../../tests/lib", 'lib'] -# See more config options https://github.com/foundry-rs/foundry/tree/master/config +evm_version = 'cancun' +force = false # compile options used by halmos (to prevent unnecessary recompilation when running forge test and halmos together) extra_output = ["storageLayout", "metadata"] diff --git a/examples/tokens/ERC20/foundry.toml b/examples/tokens/ERC20/foundry.toml index d85bbc96..26c57074 100644 --- a/examples/tokens/ERC20/foundry.toml +++ b/examples/tokens/ERC20/foundry.toml @@ -2,3 +2,5 @@ src = "src" out = "out" libs = ["../../../tests/lib", "lib"] + +evm_version = 'cancun' diff --git a/examples/tokens/ERC721/foundry.toml b/examples/tokens/ERC721/foundry.toml index d85bbc96..26c57074 100644 --- a/examples/tokens/ERC721/foundry.toml +++ b/examples/tokens/ERC721/foundry.toml @@ -2,3 +2,5 @@ src = "src" out = "out" libs = ["../../../tests/lib", "lib"] + +evm_version = 'cancun' diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index 6421eaa3..c5b13a24 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -9,7 +9,6 @@ import time import traceback import uuid - from collections import Counter from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import asdict, dataclass @@ -18,13 +17,9 @@ from .bytevec import ByteVec from .calldata import Calldata -from .config import ( - arg_parser, - default_config, - resolve_config_files, - toml_parser, - Config as HalmosConfig, -) +from .config import Config as HalmosConfig +from .config import arg_parser, default_config, resolve_config_files, toml_parser +from .mapper import DeployAddressMapper, Mapper from .sevm import * from .utils import ( NamedTimer, @@ -281,18 +276,16 @@ def rendered_trace(context: CallContext) -> str: return output.getvalue() -def rendered_calldata(calldata: ByteVec) -> str: - return hexify(calldata.unwrap()) if calldata else "0x" +def rendered_calldata(calldata: ByteVec, contract_name: str = None) -> str: + return hexify(calldata.unwrap(), contract_name) if calldata else "0x" def render_trace(context: CallContext, file=sys.stdout) -> None: - # TODO: label for known addresses - # TODO: decode calldata - # TODO: decode logs - message = context.message addr = unbox_int(message.target) addr_str = str(addr) if is_bv(addr) else hex(addr) + # check if we have a contract name for this address in our deployment mapper + addr_str = DeployAddressMapper().get_deployed_contract(addr_str) value = unbox_int(message.value) value_str = f" (value: {value})" if is_bv(value) or value > 0 else "" @@ -303,13 +296,29 @@ def render_trace(context: CallContext, file=sys.stdout) -> None: if message.is_create(): # TODO: select verbosity level to render full initcode # initcode_str = rendered_initcode(context) + + try: + if context.output.error is None: + target = hex(int(str(message.target))) + bytecode = context.output.data.unwrap().hex() + contract_name = ( + Mapper() + .get_contract_mapping_info_by_bytecode(bytecode) + .contract_name + ) + + DeployAddressMapper().add_deployed_contract(target, contract_name) + addr_str = contract_name + except: + pass + initcode_str = f"<{byte_length(message.data)} bytes of initcode>" print( f"{indent}{call_scheme_str}{addr_str}::{initcode_str}{value_str}", file=file ) else: - calldata = rendered_calldata(message.data) + calldata = rendered_calldata(message.data, addr_str) call_str = f"{addr_str}::{calldata}" static_str = yellow(" [static]") if message.is_static else "" print(f"{indent}{call_scheme_str}{call_str}{static_str}{value_str}", file=file) @@ -376,7 +385,7 @@ def run_bytecode(hexcode: str, args: HalmosConfig) -> List[Exec]: print(f"Return data: {returndata}") dump_dirname = f"/tmp/halmos-{uuid.uuid4().hex}" model_with_context = gen_model_from_sexpr( - GenModelArgs(args, idx, ex.path.solver.to_smt2(), dump_dirname) + GenModelArgs(args, idx, ex.path.to_smt2(args), {}, dump_dirname) ) print(f"Input example: {model_with_context.model}") @@ -510,7 +519,7 @@ def setup( error = setup_ex.context.output.error if error is None: - setup_exs_no_error.append((setup_ex, setup_ex.path.solver.to_smt2())) + setup_exs_no_error.append((setup_ex, setup_ex.path.to_smt2(args))) else: if opcode not in [EVM.REVERT, EVM.INVALID]: @@ -531,7 +540,7 @@ def setup( if len(setup_exs_no_error) > 1: for setup_ex, query in setup_exs_no_error: - res, _ = solve(query, args) + res, _, _ = solve(query, args) if res != unsat: setup_exs.append(setup_ex) if len(setup_exs) > 1: @@ -584,6 +593,7 @@ class ModelWithContext: is_valid: Optional[bool] index: int result: CheckSatResult + unsat_core: Optional[List] @dataclass(frozen=True) @@ -695,6 +705,7 @@ def run( result_exs = [] future_models = [] counterexamples = [] + unsat_cores = [] traces = {} def future_callback(future_model): @@ -703,6 +714,8 @@ def future_callback(future_model): model, is_valid, index, result = m.model, m.is_valid, m.index, m.result if result == unsat: + if m.unsat_core: + unsat_cores.append(m.unsat_core) return # model could be an empty dict here @@ -736,7 +749,7 @@ def future_callback(future_model): if args.verbose >= VERBOSITY_TRACE_PATHS: print(f"Path #{idx+1}:") - print(indent_text(str(ex.path))) + print(indent_text(hexify(ex.path))) print("\nTrace:") render_trace(ex.context) @@ -753,10 +766,11 @@ def future_callback(future_model): if args.verbose >= VERBOSITY_TRACE_COUNTEREXAMPLE: traces[idx] = rendered_trace(ex.context) - query = ex.path.solver.to_smt2() + query = ex.path.to_smt2(args) future_model = thread_pool.submit( - gen_model_from_sexpr, GenModelArgs(args, idx, query, dump_dirname) + gen_model_from_sexpr, + GenModelArgs(args, idx, query, unsat_cores, dump_dirname), ) future_model.add_done_callback(future_callback) future_models.append(future_model) @@ -764,7 +778,7 @@ def future_callback(future_model): elif ex.context.is_stuck(): stuck.append((idx, ex, ex.context.get_stuck_reason())) if args.print_blocked_states: - traces[idx] = rendered_trace(ex.context) + traces[idx] = f"{hexify(ex.path)}\n{rendered_trace(ex.context)}" elif not error: normal += 1 @@ -1035,7 +1049,8 @@ def run_sequential(run_args: RunArgs) -> List[TestResult]: class GenModelArgs: args: HalmosConfig idx: int - sexpr: str + sexpr: SMTQuery + known_unsat_cores: List[List] dump_dirname: Optional[str] = None @@ -1043,19 +1058,53 @@ def copy_model(model: Model) -> Dict: return {decl: model[decl] for decl in model} +def parse_unsat_core(output) -> Optional[List]: + # parsing example: + # unsat + # (error "the context is unsatisfiable") + # (<41702> <37030> <36248> <47880>) + # result: + # [41702, 37030, 36248, 47880] + match = re.search(r"unsat\s*\(\s*error\s+[^)]*\)\s*\(\s*((<[0-9]+>\s*)*)\)", output) + if match: + result = [re.sub(r"<([0-9]+)>", r"\1", name) for name in match.group(1).split()] + return result + else: + warn(f"error in parsing unsat core: {output}") + return None + + def solve( - query: str, args: HalmosConfig, dump_filename: Optional[str] = None -) -> Tuple[CheckSatResult, Model]: + query: SMTQuery, args: HalmosConfig, dump_filename: Optional[str] = None +) -> Tuple[CheckSatResult, Model, Optional[List]]: if args.dump_smt_queries or args.solver_command: if not dump_filename: dump_filename = f"/tmp/{uuid.uuid4().hex}.smt2" + # for each implication assertion, `(assert (=> |id| c))`, in query.smtlib, + # generate a corresponding named assertion, `(assert (! |id| :named ))`. + # see `svem.Path.to_smt2()` for more details. + if args.cache_solver: + named_assertions = "".join( + [ + f"(assert (! |{assert_id}| :named <{assert_id}>))\n" + for assert_id in query.assertions + ] + ) + with open(dump_filename, "w") as f: if args.verbose >= 1: print(f"Writing SMT query to {dump_filename}") + if args.cache_solver: + f.write("(set-option :produce-unsat-cores true)\n") f.write("(set-logic QF_AUFBV)\n") - f.write(query) + f.write(query.smtlib) + if args.cache_solver: + f.write(named_assertions) + f.write("(check-sat)\n") f.write("(get-model)\n") + if args.cache_solver: + f.write("(get-unsat-core)\n") if args.solver_command: if args.verbose >= 1: @@ -1082,20 +1131,40 @@ def solve( print(f" {res_str_head}") if res_str_head == "unsat": - return unsat, None + unsat_core = parse_unsat_core(res_str) if args.cache_solver else None + return unsat, None, unsat_core elif res_str_head == "sat": - return sat, f"{dump_filename}.out" + return sat, f"{dump_filename}.out", None else: - return unknown, None + return unknown, None, None except subprocess.TimeoutExpired: - return unknown, None + return unknown, None, None else: - solver = mk_solver(args, ctx=Context(), assertion=True) - solver.from_string(query) - result = solver.check() + ctx = Context() + solver = mk_solver(args, ctx=ctx, assertion=True) + solver.from_string(query.smtlib) + if args.cache_solver: + solver.set(unsat_core=True) + ids = [Bool(f"{x}", ctx) for x in query.assertions] + result = solver.check(*ids) + else: + result = solver.check() model = copy_model(solver.model()) if result == sat else None - return result, model + unsat_core = ( + [str(core) for core in solver.unsat_core()] + if args.cache_solver and result == unsat + else None + ) + return result, model, unsat_core + + +def check_unsat_cores(query, unsat_cores) -> bool: + # return true if the given query contains any given unsat core + for unsat_core in unsat_cores: + if all(core in query.assertions for core in unsat_core): + return True + return False def gen_model_from_sexpr(fn_args: GenModelArgs) -> ModelWithContext: @@ -1111,46 +1180,55 @@ def gen_model_from_sexpr(fn_args: GenModelArgs) -> ModelWithContext: if args.verbose >= 1: print(f"Checking path condition (path id: {idx+1})") - res, model = solve(sexpr, args, dump_filename) + if check_unsat_cores(sexpr, fn_args.known_unsat_cores): + # if the given query contains an unsat-core, it is unsat; no need to run the solver. + if args.verbose >= 1: + print(" Already proven unsat") + return package_result(None, idx, unsat, None, args) + + res, model, unsat_core = solve(sexpr, args, dump_filename) if res == sat and not is_model_valid(model): if args.verbose >= 1: print(f" Checking again with refinement") refined_filename = dump_filename.replace(".smt2", ".refined.smt2") - res, model = solve(refine(sexpr), args, refined_filename) + res, model, unsat_core = solve(refine(sexpr), args, refined_filename) - return package_result(model, idx, res, args) + return package_result(model, idx, res, unsat_core, args) def is_unknown(result: CheckSatResult, model: Model) -> bool: return result == unknown or (result == sat and not is_model_valid(model)) -def refine(query: str) -> str: +def refine(query: SMTQuery) -> SMTQuery: + smtlib = query.smtlib # replace uninterpreted abstraction with actual symbols for assertion solving - # TODO: replace `(evm_bvudiv x y)` with `(ite (= y (_ bv0 256)) (_ bv0 256) (bvudiv x y))` - # as bvudiv is undefined when y = 0; also similarly for evm_bvurem - query = re.sub(r"(\(\s*)evm_(bv[a-z]+)(_[0-9]+)?\b", r"\1\2", query) + # TODO: replace `(f_evm_bvudiv x y)` with `(ite (= y (_ bv0 256)) (_ bv0 256) (bvudiv x y))` + # as bvudiv is undefined when y = 0; also similarly for f_evm_bvurem + smtlib = re.sub(r"(\(\s*)f_evm_(bv[a-z]+)(_[0-9]+)?\b", r"\1\2", smtlib) # remove the uninterpreted function symbols # TODO: this will be no longer needed once is_model_valid is properly implemented - return re.sub( - r"\(\s*declare-fun\s+evm_(bv[a-z]+)(_[0-9]+)?\b", + smtlib = re.sub( + r"\(\s*declare-fun\s+f_evm_(bv[a-z]+)(_[0-9]+)?\b", r"(declare-fun dummy_\1\2", - query, + smtlib, ) + return SMTQuery(smtlib, query.assertions) def package_result( model: Optional[UnionType[Model, str]], idx: int, result: CheckSatResult, + unsat_core: Optional[List], args: HalmosConfig, ) -> ModelWithContext: if result == unsat: if args.verbose >= 1: print(f" Invalid path; ignored (path id: {idx+1})") - return ModelWithContext(None, None, idx, result) + return ModelWithContext(None, None, idx, result, unsat_core) if result == sat: if args.verbose >= 1: @@ -1166,30 +1244,30 @@ def package_result( is_valid = is_model_valid(model) model = to_str_model(model, args.print_full_model) - return ModelWithContext(model, is_valid, idx, result) + return ModelWithContext(model, is_valid, idx, result, None) else: if args.verbose >= 1: print(f" Timeout (path id: {idx+1})") - return ModelWithContext(None, None, idx, result) + return ModelWithContext(None, None, idx, result, None) def is_model_valid(model: AnyModel) -> bool: - # TODO: evaluate the path condition against the given model after excluding evm_* symbols, - # since the evm_* symbols may still appear in valid models. + # TODO: evaluate the path condition against the given model after excluding f_evm_* symbols, + # since the f_evm_* symbols may still appear in valid models. # model is a filename, containing solver output if isinstance(model, str): with open(model, "r") as f: for line in f: - if "evm_" in line: + if "f_evm_" in line: return False return True # z3 model object else: for decl in model: - if str(decl).startswith("evm_"): + if str(decl).startswith("f_evm_"): return False return True @@ -1279,6 +1357,29 @@ def parse_build_out(args: HalmosConfig) -> Dict: sol_dirname, ) contract_map[contract_name] = (json_out, contract_type, natspec) + + try: + bytecode = contract_map[contract_name][0]["bytecode"]["object"] + contract_mapping_info = Mapper().get_contract_mapping_info_by_name( + contract_name + ) + + if contract_mapping_info is None: + Mapper().add_contract_mapping_info( + contract_name=contract_name, + bytecode=bytecode, + nodes=[], + ) + else: + contract_mapping_info.bytecode = bytecode + + contract_mapping_info = Mapper().get_contract_mapping_info_by_name( + contract_name + ) + Mapper().parse_ast(contract_map[contract_name][0]["ast"]) + + except Exception: + pass except Exception as err: warn_code( PARSING_ERROR, @@ -1502,6 +1603,9 @@ def on_signal(signum, frame): contract_path = f"{contract_json['ast']['absolutePath']}:{contract_name}" print(f"\nRunning {num_found} tests for {contract_path}") + # Set 0xaaaa0001 in DeployAddressMapper + DeployAddressMapper().add_deployed_contract("0xaaaa0001", contract_name) + # support for `/// @custom:halmos` annotations contract_args = with_natspec(args, contract_name, natspec) run_args = RunArgs( diff --git a/src/halmos/cheatcodes.py b/src/halmos/cheatcodes.py index df967c7e..6dcdaa3f 100644 --- a/src/halmos/cheatcodes.py +++ b/src/halmos/cheatcodes.py @@ -9,7 +9,7 @@ from z3 import * from .bytevec import ByteVec -from .exceptions import FailCheatcode, HalmosException +from .exceptions import FailCheatcode, HalmosException, InfeasiblePath from .utils import * @@ -337,6 +337,8 @@ def handle(sevm, ex, arg: ByteVec) -> Optional[ByteVec]: # vm.assume(bool) if funsig == hevm_cheat_code.assume_sig: assume_cond = simplify(is_non_zero(arg.get_word(4))) + if is_false(assume_cond): + raise InfeasiblePath("vm.assume(false)") ex.path.append(assume_cond) return ret diff --git a/src/halmos/config.py b/src/halmos/config.py index 5efd21fa..1c720de3 100644 --- a/src/halmos/config.py +++ b/src/halmos/config.py @@ -13,11 +13,12 @@ internal = "internal" # groups -debugging, solver, build, experimental = ( +debugging, solver, build, experimental, deprecated = ( "Debugging options", "Solver options", "Build options", "Experimental options", + "Deprecated options", ) @@ -343,10 +344,6 @@ class Config: group=solver, ) - solver_parallel: bool = arg( - help="run assertion solvers in parallel", global_default=False, group=solver - ) - solver_threads: int = arg( help="set the number of threads for parallel solvers", metavar="N", @@ -355,6 +352,10 @@ class Config: global_default_str="number of CPUs", ) + cache_solver: bool = arg( + help="cache unsat queries using unsat cores", global_default=False, group=solver + ) + ### Experimental options bytecode: str = arg( @@ -381,6 +382,14 @@ class Config: group=experimental, ) + ### Deprecated + + solver_parallel: bool = arg( + help="(Deprecated; no-op; use --solver-threads instead) run assertion solvers in parallel", + global_default=False, + group=deprecated, + ) + ### Methods def __getattribute__(self, name): diff --git a/src/halmos/exceptions.py b/src/halmos/exceptions.py index a8982f6e..6516c1f5 100644 --- a/src/halmos/exceptions.py +++ b/src/halmos/exceptions.py @@ -8,7 +8,22 @@ """ -class HalmosException(Exception): +class PathEndingException(Exception): + """ + Base class for any exception that should stop the current path exploration. + + Stopping path exploration means stopping not only the current EVM context but also its parent contexts if any. + """ + + pass + + +class HalmosException(PathEndingException): + """ + Base class for unexpected internal errors happening during a test run. + Inherits from RunEndingException because it should stop further path exploration. + """ + pass @@ -16,6 +31,23 @@ class NotConcreteError(HalmosException): pass +class InfeasiblePath(PathEndingException): + """ + Raise when the current path condition turns out to be infeasible. + """ + + pass + + +class FailCheatcode(PathEndingException): + """ + Raised when invoking DSTest's fail() pseudo-cheatcode. + Inherits from RunEndingException because it should stop further path exploration. + """ + + pass + + class EvmException(Exception): """ Base class for all EVM exceptions. @@ -132,14 +164,6 @@ class InvalidContractPrefix(ExceptionalHalt): pass -class FailCheatcode(ExceptionalHalt): - """ - Raised when invoking hevm's vm.fail() cheatcode - """ - - pass - - class AddressCollision(ExceptionalHalt): """ Raised when trying to deploy into a non-empty address diff --git a/src/halmos/mapper.py b/src/halmos/mapper.py new file mode 100644 index 00000000..fa59fdc9 --- /dev/null +++ b/src/halmos/mapper.py @@ -0,0 +1,194 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + + +@dataclass +class AstNode: + node_type: str + id: int + name: str + address: str # TODO: rename it to `selector` or `signature` to better reflect the meaning + visibility: str + + +@dataclass +class ContractMappingInfo: + contract_name: str + bytecode: str + nodes: List[AstNode] + + +class SingletonMeta(type): + _instances: Dict[Type, Any] = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + + return cls._instances[cls] + + +class Mapper(metaclass=SingletonMeta): + """ + Mapping from a contract name to its runtime bytecode and the signatures of functions/events/errors declared in the contract + """ + + _PARSING_IGNORED_NODE_TYPES = [ + "StructDefinition", + "EnumDefinition", + "PragmaDirective", + "ImportDirective", + "Block", + ] + + def __init__(self): + self._contracts: Dict[str, ContractMappingInfo] = {} + + def add_contract_mapping_info( + self, contract_name: str, bytecode: str, nodes: List[AstNode] + ): + if contract_name in self._contracts: + raise ValueError(f"Contract {contract_name} already exists") + + self._contracts[contract_name] = ContractMappingInfo( + contract_name, bytecode, nodes + ) + + def get_contract_mapping_info_by_name( + self, contract_name: str + ) -> Optional[ContractMappingInfo]: + return self._contracts.get(contract_name, None) + + def get_contract_mapping_info_by_bytecode( + self, bytecode: str + ) -> Optional[ContractMappingInfo]: + # TODO: Handle cases for contracts with immutable variables + # Current implementation might not work correctly if the following code is added the test solidity file + # + # address immutable public owner; + # constructor() { + # owner = msg.sender; + # } + + for contract_mapping_info in self._contracts.values(): + # TODO: use regex instaed of `endswith` to better handle immutables or constructors with arguments + if contract_mapping_info.bytecode.endswith(bytecode): + return contract_mapping_info + + return None + + def append_node(self, contract_name: str, node: AstNode): + contract_mapping_info = self.get_contract_mapping_info_by_name(contract_name) + + if contract_mapping_info is None: + raise ValueError(f"Contract {contract_name} not found") + + contract_mapping_info.nodes.append(node) + + def parse_ast(self, node: Dict, contract_name: str = ""): + node_type = node["nodeType"] + + if node_type in self._PARSING_IGNORED_NODE_TYPES: + return + + current_contract = self._get_current_contract(node, contract_name) + + if node_type == "ContractDefinition": + if current_contract not in self._contracts: + self.add_contract_mapping_info( + contract_name=current_contract, bytecode="", nodes=[] + ) + + if self.get_contract_mapping_info_by_name(current_contract).nodes: + return + elif node_type != "SourceUnit": + id, name, address, visibility = self._get_node_info(node, node_type) + + self.append_node( + current_contract, + AstNode(node_type, id, name, address, visibility), + ) + + for child_node in node.get("nodes", []): + self.parse_ast(child_node, current_contract) + + if "body" in node: + self.parse_ast(node["body"], current_contract) + + def _get_node_info(self, node: Dict, node_type: str) -> Dict: + return ( + node.get("id", ""), + node.get("name", ""), + "0x" + self._get_node_address(node, node_type), + node.get("visibility", ""), + ) + + def _get_node_address(self, node: Dict, node_type: str) -> str: + address_fields = { + "VariableDeclaration": "functionSelector", + "FunctionDefinition": "functionSelector", + "EventDefinition": "eventSelector", + "ErrorDefinition": "errorSelector", + } + + return node.get(address_fields.get(node_type, ""), "") + + def _get_current_contract(self, node: Dict, contract_name: str) -> str: + return ( + node.get("name", "") + if node["nodeType"] == "ContractDefinition" + else contract_name + ) + + def find_nodes_by_address(self, address: str, contract_name: str = None): + # if the given signature is declared in the given contract, return its name. + if contract_name: + contract_mapping_info = self.get_contract_mapping_info_by_name( + contract_name + ) + + if contract_mapping_info: + for node in contract_mapping_info.nodes: + if node.address == address: + return node.name + + # otherwise, search for the signature in other contracts, and return all the contracts that declare it. + # note: ambiguity may occur if multiple compilation units exist. + result = "" + for key, contract_info in self._contracts.items(): + matching_nodes = [ + node for node in contract_info.nodes if node.address == address + ] + + for node in matching_nodes: + result += f"{key}.{node.name} " + + return result.strip() if result != "" and address != "0x" else address + + +# TODO: create a new instance or reset for each test +class DeployAddressMapper(metaclass=SingletonMeta): + """ + Mapping from deployed addresses to contract names + """ + + def __init__(self): + self._deployed_contracts: Dict[str, str] = {} + + # Set up some default mappings + self.add_deployed_contract( + "0x7109709ecfa91a80626ff3989d68f67f5b1dd12d", "HEVM_ADDRESS" + ) + self.add_deployed_contract( + "0xf3993a62377bcd56ae39d773740a5390411e8bc9", "SVM_ADDRESS" + ) + + def add_deployed_contract( + self, + address: str, + contract_name: str, + ): + self._deployed_contracts[address] = contract_name + + def get_deployed_contract(self, address: str) -> Optional[str]: + return self._deployed_contracts.get(address, address) diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index e6d7be72..4e2b98cf 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -31,6 +31,7 @@ from .warnings import ( warn_code, LIBRARY_PLACEHOLDER, + INTERNAL_ERROR, ) Steps = Dict[int, Dict[str, Any]] # execution tree @@ -44,32 +45,32 @@ # symbolic states # calldataload(index) -f_calldataload = Function("calldataload", BitVecSort256, BitVecSort256) +f_calldataload = Function("f_calldataload", BitVecSort256, BitVecSort256) # calldatasize() -f_calldatasize = Function("calldatasize", BitVecSort256) +f_calldatasize = Function("f_calldatasize", BitVecSort256) # extcodesize(target address) -f_extcodesize = Function("extcodesize", BitVecSort160, BitVecSort256) +f_extcodesize = Function("f_extcodesize", BitVecSort160, BitVecSort256) # extcodehash(target address) -f_extcodehash = Function("extcodehash", BitVecSort160, BitVecSort256) +f_extcodehash = Function("f_extcodehash", BitVecSort160, BitVecSort256) # blockhash(block number) -f_blockhash = Function("blockhash", BitVecSort256, BitVecSort256) +f_blockhash = Function("f_blockhash", BitVecSort256, BitVecSort256) # gas(cnt) -f_gas = Function("gas", BitVecSort256, BitVecSort256) +f_gas = Function("f_gas", BitVecSort256, BitVecSort256) # gasprice() -f_gasprice = Function("gasprice", BitVecSort256) +f_gasprice = Function("f_gasprice", BitVecSort256) # origin() -f_origin = Function("origin", BitVecSort160) +f_origin = Function("f_origin", BitVecSort160) # uninterpreted arithmetic -f_div = Function("evm_bvudiv", BitVecSort256, BitVecSort256, BitVecSort256) +f_div = Function("f_evm_bvudiv", BitVecSort256, BitVecSort256, BitVecSort256) f_mod = { - 256: Function("evm_bvurem", BitVecSort256, BitVecSort256, BitVecSort256), - 264: Function("evm_bvurem_264", BitVecSort264, BitVecSort264, BitVecSort264), - 512: Function("evm_bvurem_512", BitVecSort512, BitVecSort512, BitVecSort512), + 256: Function("f_evm_bvurem", BitVecSort256, BitVecSort256, BitVecSort256), + 264: Function("f_evm_bvurem_264", BitVecSort264, BitVecSort264, BitVecSort264), + 512: Function("f_evm_bvurem_512", BitVecSort512, BitVecSort512, BitVecSort512), } -f_sdiv = Function("evm_bvsdiv", BitVecSort256, BitVecSort256, BitVecSort256) -f_smod = Function("evm_bvsrem", BitVecSort256, BitVecSort256, BitVecSort256) -f_exp = Function("evm_exp", BitVecSort256, BitVecSort256, BitVecSort256) +f_sdiv = Function("f_evm_bvsdiv", BitVecSort256, BitVecSort256, BitVecSort256) +f_smod = Function("f_evm_bvsrem", BitVecSort256, BitVecSort256, BitVecSort256) +f_exp = Function("f_evm_exp", BitVecSort256, BitVecSort256, BitVecSort256) magic_address: int = 0xAAAA0000 @@ -470,6 +471,12 @@ def __next__(self) -> Tuple[int, int]: raise StopIteration +@dataclass(frozen=True) +class SMTQuery: + smtlib: str + assertions: List # list of assertion ids + + class Path: # a Path object represents a prefix of the path currently being executed # initially, it's an empty path at the beginning of execution @@ -477,28 +484,68 @@ class Path: solver: Solver num_scopes: int # path constraints include both explicit branching conditions and implicit assumptions (eg, no hash collisions) - # TODO: separate these two types of constraints, so that we can display only branching conditions to users - conditions: List - branching: List # indexes of conditions + conditions: Dict # cond -> bool (true if explicit branching conditions) pending: List def __init__(self, solver: Solver): self.solver = solver self.num_scopes = 0 - self.conditions = [] - self.branching = [] + self.conditions = {} self.pending = [] - self.forked = False def __deepcopy__(self, memo): raise NotImplementedError(f"use the branch() method instead of deepcopy()") def __str__(self) -> str: - branching_conds = [self.conditions[idx] for idx in self.branching] return "".join( - [f"- {cond}\n" for cond in branching_conds if str(cond) != "True"] + [ + f"- {cond}\n" + for cond in self.conditions + if self.conditions[cond] and str(cond) != "True" + ] ) + def to_smt2(self, args) -> SMTQuery: + # Serialize self.conditions into the SMTLIB format. + # + # Each `c` in the conditions can be serialized to an SMTLIB assertion: + # `(assert c)` + # + # To compute the unsat-core later, a named assertion is needed: + # `(assert (! c :named id))` where `id` is the unique id of `c` + # + # However, z3.Solver.to_smt2() doesn't serialize into named assertions. Instead, + # - `Solver.add(c)` is serialized as: `(assert c)` + # - `Solver.assert_and_track(c, id)` is serialized as: `(assert (=> |id| c))` + # + # Thus, named assertions can be generated using `to_smt2()` as follows: + # - add constraints using `assert_and_track(c, id)` for each c and id, + # - execute `to_smt2()` to generate implication assertions, `(assert (=> |id| c))`, and + # - generate named assertions, `(assert (! |id| :named ))`, for each id. + # + # The first two steps are performed here. The last step is done in `__main__.solve()`. + # + # NOTE: although both `to_smt2()` and `sexpr()` can generate SMTLIB assertions, + # sexpr()-generated SMTLIB queries are often less efficient to solve than to_smt2(). + # + # TODO: leverage more efficient serialization by representing constraints in pickle-friendly objects, instead of Z3 objects. + + ids = [str(cond.get_id()) for cond in self.conditions] + + if args.cache_solver: + tmp_solver = SolverFor("QF_AUFBV") + for cond in self.conditions: + tmp_solver.assert_and_track(cond, str(cond.get_id())) + query = tmp_solver.to_smt2() + else: + query = self.solver.to_smt2() + query = query.replace("(check-sat)", "") # see __main__.solve() + + return SMTQuery(query, ids) + + def check(self, cond): + return self.solver.check(cond) + def branch(self, cond): if len(self.pending) > 0: raise ValueError("branching from an inactive path", self) @@ -541,17 +588,16 @@ def activate(self): def append(self, cond, branching=False): cond = simplify(cond) - if self.forked: - warn(f"attempting to append cond {cond} to forked path {id(self)}") - if is_true(cond): return - self.solver.add(cond) - self.conditions.append(cond) + if is_false(cond): + # false shouldn't have been added; raise InfeasiblePath before append() if false + warn_code(INTERNAL_ERROR, f"path.append(false)") - if branching: - self.branching.append(len(self.conditions) - 1) + if cond not in self.conditions: + self.solver.add(cond) + self.conditions[cond] = branching def extend(self, conds, branching=False): for cond in conds: @@ -559,7 +605,7 @@ def extend(self, conds, branching=False): def extend_path(self, path): # branching conditions are not preserved - self.extend(path.conditions) + self.extend(path.conditions.keys()) class Exec: # an execution path @@ -737,7 +783,7 @@ def check(self, cond: Any) -> Any: if is_false(cond): return unsat - return self.path.solver.check(cond) + return self.path.check(cond) def select(self, array: Any, key: Word, arrays: Dict) -> Word: if array in arrays: @@ -790,10 +836,12 @@ def sha3_data(self, data: Bytes) -> Word: if isinstance(data, bytes): data = bytes_to_bv_value(data) - f_sha3 = Function(f"sha3_{size * 8}", BitVecSorts[size * 8], BitVecSort256) + f_sha3 = Function( + f"f_sha3_{size * 8}", BitVecSorts[size * 8], BitVecSort256 + ) sha3_expr = f_sha3(data) else: - sha3_expr = BitVec("sha3_0", BitVecSort256) + sha3_expr = BitVec("f_sha3_0", BitVecSort256) # assume hash values are sufficiently smaller than the uint max self.path.append(ULE(sha3_expr, 2**256 - 2**64)) @@ -983,17 +1031,17 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: def decode(cls, loc: Any) -> Any: loc = normalize(loc) # m[k] : hash(k.m) - if loc.decl().name() == "sha3_512": + if loc.decl().name() == "f_sha3_512": args = loc.arg(0) offset = simplify(Extract(511, 256, args)) base = simplify(Extract(255, 0, args)) return cls.decode(base) + (offset, con(0)) # a[i] : hash(a) + i - elif loc.decl().name() == "sha3_256": + elif loc.decl().name() == "f_sha3_256": base = loc.arg(0) return cls.decode(base) + (con(0),) # m[k] : hash(k.m) where |k| != 256-bit - elif loc.decl().name().startswith("sha3_"): + elif loc.decl().name().startswith("f_sha3_"): sha3_input = normalize(loc.arg(0)) if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2: offset = simplify(sha3_input.arg(0)) @@ -1084,12 +1132,12 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: @classmethod def decode(cls, loc: Any) -> Any: loc = normalize(loc) - if loc.decl().name() == "sha3_512": # hash(hi,lo), recursively + if loc.decl().name() == "f_sha3_512": # hash(hi,lo), recursively args = loc.arg(0) hi = cls.decode(simplify(Extract(511, 256, args))) lo = cls.decode(simplify(Extract(255, 0, args))) return cls.simple_hash(Concat(hi, lo)) - elif loc.decl().name().startswith("sha3_"): + elif loc.decl().name().startswith("f_sha3_"): sha3_input = normalize(loc.arg(0)) if sha3_input.decl().name() == "concat": decoded_sha3_input_args = [ @@ -1472,6 +1520,8 @@ def transfer_value( # assume balance is enough; otherwise ignore this path # note: evm requires enough balance even for self-transfer balance_cond = simplify(UGE(ex.balance_of(caller), value)) + if is_false(balance_cond): + raise InfeasiblePath("transfer_value: balance is not enough") ex.path.append(balance_cond) # conditional transfer @@ -1618,7 +1668,7 @@ def call_unknown() -> None: if arg_size > 0: f_call = Function( - "call_" + str(arg_size * 8), + "f_call_" + str(arg_size * 8), BitVecSort256, # cnt BitVecSort256, # gas BitVecSort160, # to @@ -1633,7 +1683,7 @@ def call_unknown() -> None: exit_code = f_call(con(call_id), gas, to, fund, arg_bv) else: f_call = Function( - "call_" + str(arg_size * 8), + "f_call_" + str(arg_size * 8), BitVecSort256, # cnt BitVecSort256, # gas BitVecSort160, # to @@ -1653,7 +1703,7 @@ def call_unknown() -> None: ret = ByteVec() if actual_ret_size > 0: f_ret = Function( - "ret_" + str(actual_ret_size * 8), + "f_ret_" + str(actual_ret_size * 8), BitVecSort256, BitVecSorts[actual_ret_size * 8], ) @@ -2477,6 +2527,10 @@ def finalize(ex: Exec): ex.next_pc() stack.push(ex, step_id) + except InfeasiblePath as err: + # ignore infeasible path + continue + except EvmException as err: ex.halt(data=ByteVec(), error=err) yield from finalize(ex) @@ -2490,6 +2544,12 @@ def finalize(ex: Exec): yield from finalize(ex) continue + except FailCheatcode as err: + # return data shouldn't be None, as it is considered being stuck + ex.halt(data=ByteVec(), error=err) + yield ex # early exit; do not call finalize() + continue + def mk_exec( self, # diff --git a/src/halmos/utils.py b/src/halmos/utils.py index 6c444f6f..0331a632 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: AGPL-3.0 import re - +from functools import partial from timeit import default_timer as timer -from typing import Dict, Tuple, Any, Optional, Union as UnionType +from typing import Any, Dict, Optional, Tuple +from typing import Union as UnionType from z3 import * -from .exceptions import NotConcreteError, HalmosException +from halmos.mapper import Mapper + +from .exceptions import HalmosException, NotConcreteError # order of the secp256k1 curve secp256k1n = ( @@ -61,7 +64,12 @@ def __getitem__(self, size: int) -> BitVecSort: # ecrecover(digest, v, r, s) f_ecrecover = Function( - "ecrecover", BitVecSort256, BitVecSort8, BitVecSort256, BitVecSort256, BitVecSort160 + "f_ecrecover", + BitVecSort256, + BitVecSort8, + BitVecSort256, + BitVecSort256, + BitVecSort160, ) @@ -313,23 +321,25 @@ def decode_hex(hexstring: str) -> Optional[bytes]: return None -def hexify(x): +def hexify(x, contract_name: str = None): if isinstance(x, str): return re.sub(r"\b(\d+)\b", lambda match: hex(int(match.group(1))), x) elif isinstance(x, int): return f"0x{x:02x}" elif isinstance(x, bytes): - return "0x" + x.hex() + return Mapper().find_nodes_by_address("0x" + x.hex(), contract_name) elif hasattr(x, "unwrap"): - return hexify(x.unwrap()) + return hexify(x.unwrap(), contract_name) elif is_bv_value(x): # maintain the byte size of x num_bytes = byte_length(x, strict=False) - return f"0x{x.as_long():0{num_bytes * 2}x}" + return Mapper().find_nodes_by_address( + f"0x{x.as_long():0{num_bytes * 2}x}", contract_name + ) elif is_app(x): - return f"{str(x.decl())}({', '.join(map(hexify, x.children()))})" + return f"{str(x.decl())}({', '.join(map(partial(hexify, contract_name=contract_name), x.children()))})" else: - return hexify(str(x)) + return hexify(str(x), contract_name) def render_uint(x: BitVecRef) -> str: diff --git a/tests/expected/all.json b/tests/expected/all.json index dd4fe9a3..255fc44e 100644 --- a/tests/expected/all.json +++ b/tests/expected/all.json @@ -427,7 +427,7 @@ { "name": "check_call1_fail(uint256)", "exitcode": 1, - "num_models": 2, + "num_models": 1, "models": null, "num_paths": null, "time": null, @@ -766,6 +766,28 @@ "num_bounded_loops": null } ], + "test/Foundry.t.sol:DeepFailer": [ + { + "name": "check_fail_cheatcode()", + "exitcode": 1, + "num_models": 1, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + } + ], + "test/Foundry.t.sol:EarlyFailTest": [ + { + "name": "check_early_fail_cheatcode(uint256)", + "exitcode": 1, + "num_models": 1, + "models": null, + "num_paths": null, + "time": null, + "num_bounded_loops": null + } + ], "test/Foundry.t.sol:FoundryTest": [ { "name": "check_assume(uint256)", diff --git a/tests/ffi/foundry.toml b/tests/ffi/foundry.toml index 3209ea22..57071666 100644 --- a/tests/ffi/foundry.toml +++ b/tests/ffi/foundry.toml @@ -3,9 +3,8 @@ src = 'src' out = 'out' libs = ['../lib', 'lib'] -# See more config options https://github.com/foundry-rs/foundry/tree/master/config -force = false -evm_version = 'shanghai' +evm_version = 'cancun' +force = false # compile options used by halmos (to prevent unnecessary recompilation when running forge test and halmos together) extra_output = ["storageLayout", "metadata"] diff --git a/tests/regression/foundry.toml b/tests/regression/foundry.toml index 3209ea22..57071666 100644 --- a/tests/regression/foundry.toml +++ b/tests/regression/foundry.toml @@ -3,9 +3,8 @@ src = 'src' out = 'out' libs = ['../lib', 'lib'] -# See more config options https://github.com/foundry-rs/foundry/tree/master/config -force = false -evm_version = 'shanghai' +evm_version = 'cancun' +force = false # compile options used by halmos (to prevent unnecessary recompilation when running forge test and halmos together) extra_output = ["storageLayout", "metadata"] diff --git a/tests/regression/test/Foundry.t.sol b/tests/regression/test/Foundry.t.sol index 3cb08f44..a9cc4202 100644 --- a/tests/regression/test/Foundry.t.sol +++ b/tests/regression/test/Foundry.t.sol @@ -16,11 +16,29 @@ contract DeepFailer is Test { } } - function test_fail_cheatcode() public { + function check_fail_cheatcode() public { DeepFailer(address(this)).do_test(0); } } +contract EarlyFailTest is Test { + function do_fail() external { + fail(); + } + + function check_early_fail_cheatcode(uint x) public { + // we want `fail()` to happen in a nested context, + // to test that it ends not just the current context but the whole run + address(this).call(abi.encodeWithSelector(this.do_fail.selector, "")); + + // this shouldn't be reached due to the early fail() semantics. + // if this assertion is executed, two counterexamples will be generated: + // - counterexample caused by fail(): x > 0 + // - counterexample caused by assert(x > 0): x == 0 + assert(x > 0); + } +} + contract FoundryTest is Test { /* TODO: support checkFail prefix function checkFail() public { diff --git a/tests/solver/foundry.toml b/tests/solver/foundry.toml index 3209ea22..57071666 100644 --- a/tests/solver/foundry.toml +++ b/tests/solver/foundry.toml @@ -3,9 +3,8 @@ src = 'src' out = 'out' libs = ['../lib', 'lib'] -# See more config options https://github.com/foundry-rs/foundry/tree/master/config -force = false -evm_version = 'shanghai' +evm_version = 'cancun' +force = false # compile options used by halmos (to prevent unnecessary recompilation when running forge test and halmos together) extra_output = ["storageLayout", "metadata"] diff --git a/tests/test_mapper.py b/tests/test_mapper.py new file mode 100644 index 00000000..ce7f1fba --- /dev/null +++ b/tests/test_mapper.py @@ -0,0 +1,176 @@ +from typing import List + +import pytest + +from halmos.mapper import AstNode, ContractMappingInfo, Mapper, SingletonMeta + + +@pytest.fixture +def ast_nodes() -> List[AstNode]: + return [ + AstNode( + node_type="type1", id=1, name="Node1", address="0x123", visibility="public" + ), + AstNode( + node_type="type2", id=2, name="Node2", address="0x456", visibility="private" + ), + ] + + +@pytest.fixture +def mapper() -> Mapper: + return Mapper() + + +@pytest.fixture(autouse=True) +def reset_singleton(): + SingletonMeta._instances = {} + + +def test_singleton(): + mapper1 = Mapper() + mapper2 = Mapper() + assert mapper1 is mapper2 + + +def test_add_contract_mapping_info(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + contract_info = mapper.get_contract_mapping_info_by_name("ContractA") + assert contract_info is not None + assert contract_info.contract_name == "ContractA" + assert contract_info.bytecode == "bytecodeA" + assert len(contract_info.nodes) == 2 + + +def test_add_contract_mapping_info_already_existence(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + + with pytest.raises(ValueError, match=r"Contract ContractA already exists"): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + + +def test_get_contract_mapping_info_by_name(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + contract_info = mapper.get_contract_mapping_info_by_name("ContractA") + assert contract_info is not None + assert contract_info.contract_name == "ContractA" + + +def test_get_contract_mapping_info_by_name_nonexistent(mapper): + contract_info = mapper.get_contract_mapping_info_by_name("ContractA") + assert contract_info is None + + +def test_get_contract_mapping_info_by_bytecode(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + contract_info = mapper.get_contract_mapping_info_by_bytecode("bytecodeA") + assert contract_info is not None + assert contract_info.bytecode == "bytecodeA" + + +def test_get_contract_mapping_info_by_bytecode_nonexistent(mapper): + contract_info = mapper.get_contract_mapping_info_by_bytecode("bytecodeA") + assert contract_info is None + + +def test_append_node(mapper, ast_nodes): + mapper.add_contract_mapping_info("ContractA", "bytecodeA", ast_nodes) + new_node = AstNode( + node_type="type3", id=3, name="Node3", address="0x789", visibility="public" + ) + mapper.append_node("ContractA", new_node) + contract_info = mapper.get_contract_mapping_info_by_name("ContractA") + assert contract_info is not None + assert len(contract_info.nodes) == 3 + assert contract_info.nodes[-1].id == 3 + + +def test_append_node_to_nonexistent_contract(mapper): + new_node = AstNode( + node_type="type3", id=3, name="Node3", address="0x789", visibility="public" + ) + with pytest.raises(ValueError, match=r"Contract NonexistentContract not found"): + mapper.append_node("NonexistentContract", new_node) + + +def test_parse_simple_ast(mapper): + example_ast = { + "nodeType": "ContractDefinition", + "id": 1, + "name": "ExampleContract", + "nodes": [ + { + "nodeType": "FunctionDefinition", + "id": 2, + "name": "exampleFunction", + "functionSelector": "abcdef", + "visibility": "public", + "nodes": [], + } + ], + } + mapper.parse_ast(example_ast) + contract_info = mapper.get_contract_mapping_info_by_name("ExampleContract") + + assert contract_info is not None + assert contract_info.contract_name == "ExampleContract" + assert len(contract_info.nodes) == 1 + assert contract_info.nodes[0].name == "exampleFunction" + + +def test_parse_complex_ast(mapper): + complex_ast = { + "nodeType": "ContractDefinition", + "id": 1, + "name": "ComplexContract", + "nodes": [ + { + "nodeType": "VariableDeclaration", + "id": 2, + "name": "var1", + "functionSelector": "", + "visibility": "private", + }, + { + "nodeType": "FunctionDefinition", + "id": 3, + "name": "func1", + "functionSelector": "222222", + "visibility": "public", + "nodes": [ + { + "nodeType": "Block", + "id": 4, + "name": "innerBlock", + "functionSelector": "", + "visibility": "", + } + ], + }, + { + "nodeType": "EventDefinition", + "id": 5, + "name": "event1", + "eventSelector": "444444", + "visibility": "public", + }, + { + "nodeType": "ErrorDefinition", + "id": 6, + "name": "error1", + "errorSelector": "555555", + "visibility": "public", + }, + ], + } + mapper.parse_ast(complex_ast) + contract_info = mapper.get_contract_mapping_info_by_name("ComplexContract") + assert contract_info is not None + assert contract_info.contract_name == "ComplexContract" + assert len(contract_info.nodes) == 4 + + node_names = [node.name for node in contract_info.nodes] + assert "var1" in node_names + assert "func1" in node_names + assert "event1" in node_names + assert "error1" in node_names