diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index 82d991fe..b80ad513 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -253,9 +253,20 @@ def wstore_bytes( raise ValueError(size, arr) wextend(mem, loc, size) for i in range(size): - if not eq(arr[i].sort(), BitVecSort8): - raise ValueError(arr) - mem[loc + i] = arr[i] + val = arr[i] + if not is_byte(val): + raise ValueError(val) + + mem[loc + i] = val + + +def is_byte(x: Any) -> bool: + if is_bv(x): + return eq(x.sort(), BitVecSort8) + elif isinstance(x, int): + return 0 <= x < 256 + else: + return False def normalize(expr: Any) -> Any: diff --git a/src/halmos/utils.py b/src/halmos/utils.py index fd57dab7..9356a7b0 100644 --- a/src/halmos/utils.py +++ b/src/halmos/utils.py @@ -7,7 +7,7 @@ from z3 import * -from .exceptions import NotConcreteError +from .exceptions import NotConcreteError, HalmosException Word = Any # z3 expression (including constants) @@ -148,7 +148,8 @@ def extract_bytes_argument(calldata: BitVecRef, arg_idx: int) -> bytes: if length == 0: return b"" - return bv_value_to_bytes(extract_bytes(calldata, 4 + offset + 32, length)) + bytes = extract_bytes(calldata, 4 + offset + 32, length) + return bv_value_to_bytes(bytes) if is_bv_value(bytes) else bytes def extract_string_argument(calldata: BitVecRef, arg_idx: int): @@ -215,16 +216,16 @@ def int_of(x: Any, err: str = "expected concrete value but got") -> int: raise NotConcreteError(f"{err}: {x}") -def byte_length(x: Any) -> int: +def byte_length(x: Any, strict=True) -> int: if is_bv(x): - if x.size() % 8 != 0: - raise ValueError(x) - return x.size() >> 3 + if x.size() % 8 != 0 and strict: + raise HalmosException(f"byte_length({x}) with bit size {x.size()}") + return math.ceil(x.size() / 8) if isinstance(x, bytes): return len(x) - raise ValueError(x) + raise HalmosException(f"byte_length({x}) of type {type(x)}") def decode_hex(hexstring: str) -> Optional[bytes]: @@ -246,24 +247,32 @@ def hexify(x): return "0x" + x.hex() elif is_bv_value(x): # maintain the byte size of x - num_bytes = byte_length(x) + num_bytes = byte_length(x, strict=False) return f"0x{x.as_long():0{num_bytes * 2}x}" + elif is_app(x): + return f"{str(x.decl())}({', '.join(map(hexify, x.children()))})" else: return hexify(str(x)) def render_uint(x: BitVecRef) -> str: - val = int_of(x) - return f"0x{val:0{byte_length(x) * 2}x} ({val})" + if is_bv_value(x): + val = int_of(x) + return f"0x{val:0{byte_length(x, strict=False) * 2}x} ({val})" + + return hexify(x) def render_int(x: BitVecRef) -> str: - val = x.as_signed_long() - return f"0x{x.as_long():0{byte_length(x) * 2}x} ({val})" + if is_bv_value(x): + val = x.as_signed_long() + return f"0x{x.as_long():0{byte_length(x, strict=False) * 2}x} ({val})" + + return hexify(x) def render_bool(b: BitVecRef) -> str: - return str(b.as_long() != 0).lower() + return str(b.as_long() != 0).lower() if is_bv_value(b) else hexify(b) def render_string(s: BitVecRef) -> str: @@ -273,13 +282,16 @@ def render_string(s: BitVecRef) -> str: def render_bytes(b: UnionType[BitVecRef, bytes]) -> str: if is_bv(b): - return f'hex"{hex(b.as_long())[2:]}"' + return hexify(b) + f" ({byte_length(b, strict=False)} bytes)" else: return f'hex"{b.hex()[2:]}"' def render_address(a: BitVecRef) -> str: - return f"0x{a.as_long():040x}" + if is_bv_value(a): + return f"0x{a.as_long():040x}" + + return hexify(a) def stringify(symbol_name: str, val: Any): diff --git a/tests/test_sevm.py b/tests/test_sevm.py index d19152fc..9828c8d8 100644 --- a/tests/test_sevm.py +++ b/tests/test_sevm.py @@ -23,6 +23,7 @@ iter_bytes, wload, wstore, + wstore_bytes, Path, ) @@ -392,3 +393,19 @@ def test_wload_bad_byte(): with pytest.raises(ValueError): wload([512], 0, 1, prefer_concrete=False) + + +def test_wstore_bytes_concrete(): + mem = [0] * 4 + wstore_bytes(mem, 0, 4, bytes.fromhex("12345678")) + assert mem == [0x12, 0x34, 0x56, 0x78] + + +def test_wstore_bytes_concolic(): + mem1 = [0] * 4 + wstore(mem1, 0, 4, con(0x12345678, 32)) + + mem2 = [0] * 4 + wstore_bytes(mem2, 0, 4, mem1) + + assert mem2 == [0x12, 0x34, 0x56, 0x78]