diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py
index cfd7da4d..1e9cd617 100644
--- a/src/halmos/sevm.py
+++ b/src/halmos/sevm.py
@@ -24,6 +24,7 @@
     BitVec,
     BitVecRef,
     BoolVal,
+    CheckSatResult,
     Concat,
     Extract,
     Function,
@@ -94,14 +95,19 @@
     debug,
     extract_bytes,
     f_ecrecover,
+    f_sha3_256_name,
+    f_sha3_512_name,
+    f_sha3_name,
     hexify,
     int_of,
     is_bool,
     is_bv,
     is_bv_value,
     is_concrete,
+    is_f_sha3_name,
     is_non_zero,
     is_zero,
+    match_dynamic_array_overflow_condition,
     restore_precomputed_hashes,
     sha3_inv,
     str_opcode,
@@ -994,15 +1000,35 @@ def dump(self, print_mem=False) -> str:
     def advance_pc(self) -> None:
         self.pc = self.pgm.next_pc(self.pc)
 
-    def check(self, cond: Any) -> Any:
-        cond = simplify(cond)
+    def quick_custom_check(self, cond: BitVecRef) -> CheckSatResult | None:
+        """
+        Quick custom checker for specific known patterns.
+
+        This method checks for certain common conditions that can be evaluated
+        quickly without invoking the full SMT solver.
 
+        Returns:
+            sat if the condition is satisfiable
+            unsat if the condition is unsatisfiable
+            None if the condition requires full SMT solving
+        """
         if is_true(cond):
             return sat
 
         if is_false(cond):
             return unsat
 
+        # Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64
+        if match_dynamic_array_overflow_condition(cond):
+            return unsat
+
+    def check(self, cond: Any) -> Any:
+        cond = simplify(cond)
+
+        # use quick custom checker for common patterns before falling back to SMT solver
+        if result := self.quick_custom_check(cond):
+            return result
+
         return self.path.check(cond)
 
     def select(
@@ -1063,7 +1089,7 @@ def sha3_data(self, data: Bytes) -> Word:
                 data = bytes_to_bv_value(data)
 
             f_sha3 = Function(
-                f"f_sha3_{size * 8}", BitVecSorts[size * 8], BitVecSort256
+                f_sha3_name(size * 8), BitVecSorts[size * 8], BitVecSort256
             )
             sha3_expr = f_sha3(data)
         else:
@@ -1288,17 +1314,17 @@ def get_key_structure(cls, loc) -> tuple:
     def decode(cls, loc: Any) -> Any:
         loc = normalize(loc)
         # m[k] : hash(k.m)
-        if loc.decl().name() == "f_sha3_512":
+        if loc.decl().name() == f_sha3_512_name:
             args = loc.arg(0)
             offset = simplify(Extract(511, 256, args))
             base = simplify(Extract(255, 0, args))
             return cls.decode(base) + (offset, ZERO)
         # a[i] : hash(a) + i
-        elif loc.decl().name() == "f_sha3_256":
+        elif loc.decl().name() == f_sha3_256_name:
             base = loc.arg(0)
             return cls.decode(base) + (ZERO,)
         # m[k] : hash(k.m) where |k| != 256-bit
-        elif loc.decl().name().startswith("f_sha3_"):
+        elif is_f_sha3_name(loc.decl().name()):
             sha3_input = normalize(loc.arg(0))
             if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2:
                 offset = simplify(sha3_input.arg(0))
@@ -1417,12 +1443,12 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
     @classmethod
     def decode(cls, loc: Any) -> Any:
         loc = normalize(loc)
-        if loc.decl().name() == "f_sha3_512":  # hash(hi,lo), recursively
+        if loc.decl().name() == f_sha3_512_name:  # hash(hi,lo), recursively
             args = loc.arg(0)
             hi = cls.decode(simplify(Extract(511, 256, args)))
             lo = cls.decode(simplify(Extract(255, 0, args)))
             return cls.simple_hash(Concat(hi, lo))
-        elif loc.decl().name().startswith("f_sha3_"):
+        elif is_f_sha3_name(loc.decl().name()):
             sha3_input = normalize(loc.arg(0))
             if sha3_input.decl().name() == "concat":
                 decoded_sha3_input_args = [
@@ -2359,6 +2385,12 @@ def jumpi(
             follow_false = visited[False] < self.options.loop
             if not (follow_true and follow_false):
                 self.logs.bounded_loops.append(jid)
+                if self.options.debug:
+                    debug(f"\nloop id: {jid}")
+                    debug(f"loop condition: {cond}")
+                    debug(f"calldata: {ex.calldata()}")
+                    debug("path condition:")
+                    debug(ex.path)
         else:
             # for constant-bounded loops
             follow_true = potential_true
diff --git a/src/halmos/utils.py b/src/halmos/utils.py
index 47469bbe..05a13f7f 100644
--- a/src/halmos/utils.py
+++ b/src/halmos/utils.py
@@ -7,7 +7,9 @@
 from typing import Any
 
 from z3 import (
+    Z3_OP_BADD,
     Z3_OP_CONCAT,
+    Z3_OP_ULEQ,
     BitVecNumRef,
     BitVecRef,
     BitVecSort,
@@ -21,11 +23,13 @@
     SignExt,
     SolverFor,
     ZeroExt,
+    eq,
     is_app,
     is_app_of,
     is_bool,
     is_bv,
     is_bv_value,
+    is_not,
     simplify,
 )
 
@@ -94,6 +98,18 @@ def __getitem__(self, size: int) -> BitVecSort:
         )
 
 
+def is_f_sha3_name(name: str) -> bool:
+    return name.startswith("f_sha3_")
+
+
+def f_sha3_name(bitsize: int) -> str:
+    return f"f_sha3_{bitsize}"
+
+
+f_sha3_256_name = f_sha3_name(256)
+f_sha3_512_name = f_sha3_name(512)
+
+
 def wrap(x: Any) -> Word:
     if is_bv(x):
         return x
@@ -349,6 +365,47 @@ def byte_length(x: Any, strict=True) -> int:
     raise TypeError(f"byte_length({x}) of type {type(x)}")
 
 
+def match_dynamic_array_overflow_condition(cond: BitVecRef) -> bool:
+    """
+    Check if `cond` matches the following pattern:
+        Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))), where offset < 2**64
+
+    This condition is satisfied when a dynamic array at `slot` exceeds the storage limit.
+    Since such an overflow is highly unlikely in practice, we assume that this condition is unsat.
+
+    Note: we already assume that any sha3 hash output is smaller than 2**256 - 2**64 (see SEVM.sha3_data()).
+    However, the smt solver may not be able to solve this condition within the branching timeout.
+    In such cases, this explicit pattern serves as a fallback to avoid exploring practically infeasible paths.
+
+    We don't need to handle the negation of this condition, because unknown conditions are conservatively assumed to be sat.
+
+    The sum must have exactly two operands: z3 may represent a sum as an n-ary application
+    (e.g. after simplify()), and `offset + f_sha3_N(slot) + x` is not covered by the assumption above.
+    """
+
+    # Not(ule)
+    if not is_not(cond):
+        return False
+    ule = cond.arg(0)
+
+    # Not(ULE(left, right))
+    if not is_app_of(ule, Z3_OP_ULEQ):
+        return False
+    left, right = ule.arg(0), ule.arg(1)
+
+    # Not(ULE(f_sha3_N(slot), offset + base)), where the sum has exactly two operands
+    if not (
+        is_f_sha3_name(left.decl().name())
+        and is_app_of(right, Z3_OP_BADD)
+        and right.num_args() == 2
+    ):
+        return False
+    offset, base = right.arg(0), right.arg(1)
+
+    # Not(ULE(f_sha3_N(slot), offset + f_sha3_N(slot))) and offset < 2**64
+    return eq(left, base) and is_bv_value(offset) and offset.as_long() < 2**64
+
+
 def stripped(hexstring: str) -> str:
     """Remove 0x prefix from hexstring"""
     return hexstring[2:] if hexstring.startswith("0x") else hexstring
diff --git a/tests/expected/all.json b/tests/expected/all.json
index 3e9e2fde..056ec17f 100644
--- a/tests/expected/all.json
+++ b/tests/expected/all.json
@@ -2365,6 +2365,17 @@
                 "num_bounded_loops": null
             }
         ],
+        "test/Solver.t.sol:SolverTest": [
+            {
+                "name": "check_dynamic_array_overflow()",
+                "exitcode": 0,
+                "num_models": 0,
+                "models": null,
+                "num_paths": null,
+                "time": null,
+                "num_bounded_loops": null
+            }
+        ],
         "test/StaticContexts.t.sol:StaticContextsTest": [
             {
                 "name": "check_create2_fails()",
diff --git a/tests/regression/test/Solver.t.sol b/tests/regression/test/Solver.t.sol
new file mode 100644
index 00000000..776b970a
--- /dev/null
+++ b/tests/regression/test/Solver.t.sol
@@ -0,0 +1,13 @@
+// SPDX-License-Identifier: AGPL-3.0
+pragma solidity >=0.8.0 <0.9.0;
+
+import "forge-std/Test.sol";
+import {SymTest} from "halmos-cheatcodes/SymTest.sol";
+
+contract SolverTest is SymTest, Test {
+    uint[] numbers;
+
+    function check_dynamic_array_overflow() public {
+        numbers = new uint[](5); // shouldn't generate loop bounds warning
+    }
+}
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 00000000..cdb214ad
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,63 @@
+from z3 import (
+    ULE,
+    BitVec,
+    BitVecSort,
+    BitVecVal,
+    Function,
+    Not,
+    simplify,
+)
+
+from halmos.utils import f_sha3_256_name, match_dynamic_array_overflow_condition
+
+
+def test_match_dynamic_array_overflow_condition():
+    # Create Z3 objects
+    f_sha3_256 = Function(f_sha3_256_name, BitVecSort(256), BitVecSort(256))
+    slot = BitVec("slot", 256)
+    offset = BitVecVal(1000, 256)  # Less than 2**64
+
+    # Test the function
+    cond = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot)))
+    assert match_dynamic_array_overflow_condition(cond)
+
+    # Test with opposite order of addition
+    opposite_order_cond = Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset))
+    assert not match_dynamic_array_overflow_condition(opposite_order_cond)
+
+    # Test with opposite order after simplification
+    simplified_opposite_order_cond = simplify(
+        Not(ULE(f_sha3_256(slot), f_sha3_256(slot) + offset))
+    )
+    assert match_dynamic_array_overflow_condition(simplified_opposite_order_cond)
+
+    # Test with offset = 2**64 - 1 (should match)
+    max_valid_offset = BitVecVal(2**64 - 1, 256)
+    max_valid_cond = Not(ULE(f_sha3_256(slot), max_valid_offset + f_sha3_256(slot)))
+    assert match_dynamic_array_overflow_condition(max_valid_cond)
+
+    # Test with offset >= 2**64
+    large_offset = BitVecVal(2**64, 256)
+    large_offset_cond = Not(ULE(f_sha3_256(slot), large_offset + f_sha3_256(slot)))
+    assert not match_dynamic_array_overflow_condition(large_offset_cond)
+
+    # Test with a different function
+    different_func = Function("different_func", BitVecSort(256), BitVecSort(256))
+    non_matching_cond = Not(ULE(different_func(slot), offset + different_func(slot)))
+    assert not match_dynamic_array_overflow_condition(non_matching_cond)
+
+    # Test with just ULE, not Not(ULE(...))
+    ule_only = ULE(f_sha3_256(slot), offset + f_sha3_256(slot))
+    assert not match_dynamic_array_overflow_condition(ule_only)
+
+    # Test with mismatched slots
+    slot2 = BitVec("slot2", 256)
+    mismatched_slots = Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot2)))
+    assert not match_dynamic_array_overflow_condition(mismatched_slots)
+
+    # Test with an extra addend: the sum must have exactly two operands
+    extra = BitVec("extra", 256)
+    extra_addend_cond = simplify(
+        Not(ULE(f_sha3_256(slot), offset + f_sha3_256(slot) + extra))
+    )
+    assert not match_dynamic_array_overflow_condition(extra_addend_cond)