Skip to content

Commit

Permalink
fix: decoding storage mapping with bytes key (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark authored Nov 29, 2023
1 parent 96061cf commit 7761573
Show file tree
Hide file tree
Showing 7 changed files with 382 additions and 45 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/test-external.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ jobs:
dir: "cicada"
cmd: "halmos --contract LibPrimeTest --function testProve --loop 256"
branch: ""
# - repo: "farcasterxyz/contracts"
# dir: "farcaster-contracts"
# cmd: "halmos"
# branch: ""
- repo: "farcasterxyz/contracts"
dir: "farcaster-contracts"
cmd: "halmos"
branch: ""
- repo: "zobront/halmos-solady"
dir: "halmos-solady"
cmd: "halmos --function testCheck"
Expand Down Expand Up @@ -67,5 +67,5 @@ jobs:
run: python -m pip install -e ./halmos

- name: Test external repo
run: ${{ matrix.cmd }} -v -st --error-unknown --solver-timeout-assertion 0
run: ${{ matrix.cmd }} -v -st --error-unknown --solver-timeout-assertion 0 --solver-threads 2
working-directory: ${{ matrix.dir }}
2 changes: 1 addition & 1 deletion .github/workflows/test-long.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ jobs:
run: pip install -e .

- name: Run pytest
run: pytest -v tests/test_halmos.py -k ${{ matrix.testname }} --halmos-options="-v -st --error-unknown --solver-timeout-assertion 0"
run: pytest -x -v tests/test_halmos.py -k ${{ matrix.testname }} --halmos-options="-v -st --error-unknown --solver-timeout-assertion 0 --solver-threads 2" -s --log-cli-level=
2 changes: 1 addition & 1 deletion src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def future_callback(future_model):

if len(future_models) > 0 and args.verbose >= 1:
print(
f"# of potential paths involving assertion violations: {len(future_models)} / {len(result_exs)}"
f"# of potential paths involving assertion violations: {len(future_models)} / {len(result_exs)} (--solver-threads {args.solver_threads})"
)

if args.early_exit:
Expand Down
150 changes: 112 additions & 38 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,51 +998,93 @@ class Storage:
@classmethod
def normalize(cls, expr: Any) -> Any:
# Concat(Extract(255, 8, bvadd(x, y)), bvadd(Extract(7, 0, x), Extract(7, 0, y))) => x + y
if expr.decl().name() == "concat" and expr.num_args() == 2:
arg0 = expr.arg(0) # Extract(255, 8, bvadd(x, y))
arg1 = expr.arg(1) # bvadd(Extract(7, 0, x), Extract(7, 0, y))
def normalize_extract(arg0, arg1):
if (
arg0.decl().name() == "extract"
and arg0.num_args() == 1
and arg0.params() == [255, 8]
):
arg00 = arg0.arg(0) # bvadd(x, y)
if arg00.decl().name() == "bvadd":
x = arg00.arg(0)
y = arg00.arg(1)
if arg1.decl().name() == "bvadd" and arg1.num_args() == 2:
if eq(arg1.arg(0), simplify(Extract(7, 0, x))) and eq(
arg1.arg(1), simplify(Extract(7, 0, y))
):
return x + y
target = arg0.arg(0) # bvadd(x, y)

# this form triggers the partial inward-propagation of extracts in simplify()
# that is, `Extract(7, 0, bvadd(x, y))` => `bvadd(Extract(7, 0, x), Extract(7, 0, y))`, followed by further simplification
target_equivalent = Concat(
Extract(255, 8, target), Extract(7, 0, target)
)

given = Concat(arg0, arg1)

# since target_equivalent and given may not be structurally equal, we compare their fully simplified forms
if eq(simplify(given), simplify(target_equivalent)):
# here we have: given == target_equivalent == target
return target

return None

if expr.decl().name() == "concat" and expr.num_args() >= 2:
new_args = []

i = 0
n = expr.num_args()

# apply normalize_extract for each pair of adjacent arguments
while i < n - 1:
arg0 = expr.arg(i)
arg1 = expr.arg(i + 1)

arg0_arg1 = normalize_extract(arg0, arg1)

if arg0_arg1 is None: # not simplified
new_args.append(arg0)
i += 1
else: # simplified into a single term
new_args.append(arg0_arg1)
i += 2

# handle the last element
if i == n - 1:
new_args.append(expr.arg(i))

return concat(new_args)

return expr


class SolidityStorage(Storage):
@classmethod
def empty(cls, addr: BitVecRef, slot: int, len_keys: int) -> ArrayRef:
def empty(cls, addr: BitVecRef, slot: int, keys: Tuple) -> ArrayRef:
num_keys = len(keys)
size_keys = cls.bitsize(keys)
return Array(
f"storage_{id_str(addr)}_{slot}_{len_keys}_00",
BitVecSorts[len_keys * 256],
f"storage_{id_str(addr)}_{slot}_{num_keys}_{size_keys}_00",
BitVecSorts[size_keys],
BitVecSort256,
)

@classmethod
def init(cls, ex: Exec, addr: Any, slot: int, keys) -> None:
def init(cls, ex: Exec, addr: Any, slot: int, keys: Tuple) -> None:
assert_address(addr)
num_keys = len(keys)
size_keys = cls.bitsize(keys)
if slot not in ex.storage[addr]:
ex.storage[addr][slot] = {}
if len(keys) not in ex.storage[addr][slot]:
if len(keys) == 0:
if num_keys not in ex.storage[addr][slot]:
ex.storage[addr][slot][num_keys] = {}
if size_keys not in ex.storage[addr][slot][num_keys]:
if size_keys == 0:
if ex.symbolic:
label = f"storage_{id_str(addr)}_{slot}_{len(keys)}_00"
ex.storage[addr][slot][len(keys)] = BitVec(label, BitVecSort256)
label = f"storage_{id_str(addr)}_{slot}_{num_keys}_{size_keys}_00"
ex.storage[addr][slot][num_keys][size_keys] = BitVec(
label, BitVecSort256
)
else:
ex.storage[addr][slot][len(keys)] = con(0)
ex.storage[addr][slot][num_keys][size_keys] = con(0)
else:
# do not use z3 const array `K(BitVecSort(len(keys)*256), con(0))` when not ex.symbolic
# do not use z3 const array `K(BitVecSort(size_keys), con(0))` when not ex.symbolic
# instead use normal smt array, and generate emptyness axiom; see load()
ex.storage[addr][slot][len(keys)] = cls.empty(addr, slot, len(keys))
ex.storage[addr][slot][num_keys][size_keys] = cls.empty(
addr, slot, keys
)

@classmethod
def load(cls, ex: Exec, addr: Any, loc: Word) -> Word:
Expand All @@ -1051,16 +1093,18 @@ def load(cls, ex: Exec, addr: Any, loc: Word) -> Word:
raise ValueError(offsets)
slot, keys = int_of(offsets[0], "symbolic storage base slot"), offsets[1:]
cls.init(ex, addr, slot, keys)
if len(keys) == 0:
return ex.storage[addr][slot][0]
num_keys = len(keys)
size_keys = cls.bitsize(keys)
if num_keys == 0:
return ex.storage[addr][slot][num_keys][size_keys]
else:
if not ex.symbolic:
# generate emptyness axiom for each array index, instead of using quantified formula; see init()
ex.path.append(
Select(cls.empty(addr, slot, len(keys)), concat(keys)) == con(0)
Select(cls.empty(addr, slot, keys), concat(keys)) == con(0)
)
return ex.select(
ex.storage[addr][slot][len(keys)], concat(keys), ex.storages
ex.storage[addr][slot][num_keys][size_keys], concat(keys), ex.storages
)

@classmethod
Expand All @@ -1070,30 +1114,44 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:
raise ValueError(offsets)
slot, keys = int_of(offsets[0], "symbolic storage base slot"), offsets[1:]
cls.init(ex, addr, slot, keys)
if len(keys) == 0:
ex.storage[addr][slot][0] = val
num_keys = len(keys)
size_keys = cls.bitsize(keys)
if num_keys == 0:
ex.storage[addr][slot][num_keys][size_keys] = val
else:
new_storage_var = Array(
f"storage_{id_str(addr)}_{slot}_{len(keys)}_{1+len(ex.storages):>02}",
BitVecSorts[len(keys) * 256],
f"storage_{id_str(addr)}_{slot}_{num_keys}_{size_keys}_{1+len(ex.storages):>02}",
BitVecSorts[size_keys],
BitVecSort256,
)
new_storage = Store(ex.storage[addr][slot][len(keys)], concat(keys), val)
new_storage = Store(
ex.storage[addr][slot][num_keys][size_keys], concat(keys), val
)
ex.path.append(new_storage_var == new_storage)
ex.storage[addr][slot][len(keys)] = new_storage_var
ex.storage[addr][slot][num_keys][size_keys] = new_storage_var
ex.storages[new_storage_var] = new_storage

@classmethod
def decode(cls, loc: Any) -> Any:
loc = cls.normalize(loc)
if loc.decl().name() == "sha3_512": # m[k] : hash(k.m)
# m[k] : hash(k.m)
if loc.decl().name() == "sha3_512":
args = loc.arg(0)
offset = simplify(Extract(511, 256, args))
base = simplify(Extract(255, 0, args))
return cls.decode(base) + (offset, con(0))
elif loc.decl().name() == "sha3_256": # a[i] : hash(a)+i
# a[i] : hash(a) + i
elif loc.decl().name() == "sha3_256":
base = loc.arg(0)
return cls.decode(base) + (con(0),)
# m[k] : hash(k.m) where |k| != 256-bit
elif loc.decl().name().startswith("sha3_"):
sha3_input = cls.normalize(loc.arg(0))
if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2:
offset = simplify(sha3_input.arg(0))
base = simplify(sha3_input.arg(1))
if offset.size() != 256 and base.size() == 256:
return cls.decode(base) + (offset, con(0))
elif loc.decl().name() == "bvadd":
# # when len(args) == 2
# arg0 = cls.decode(loc.arg(0))
Expand Down Expand Up @@ -1123,11 +1181,19 @@ def decode(cls, loc: Any) -> Any:
return (con(preimage), con(delta))
else:
return (loc,)
elif is_bv(loc):

if is_bv(loc):
return (loc,)
else:
raise ValueError(loc)

@classmethod
def bitsize(cls, keys: Tuple) -> int:
size = sum([key.size() for key in keys])
if len(keys) > 0 and size == 0:
raise ValueError(keys)
return size


class GenericStorage(Storage):
@classmethod
Expand Down Expand Up @@ -1176,7 +1242,14 @@ def decode(cls, loc: Any) -> Any:
lo = cls.decode(simplify(Extract(255, 0, args)))
return cls.simple_hash(Concat(hi, lo))
elif loc.decl().name().startswith("sha3_"):
return cls.simple_hash(cls.decode(loc.arg(0)))
sha3_input = cls.normalize(loc.arg(0))
if sha3_input.decl().name() == "concat":
decoded_sha3_input_args = [
cls.decode(sha3_input.arg(i)) for i in range(sha3_input.num_args())
]
return cls.simple_hash(concat(decoded_sha3_input_args))
else:
return cls.simple_hash(cls.decode(sha3_input))
elif loc.decl().name() == "bvadd":
args = loc.children()
if len(args) < 2:
Expand All @@ -1188,7 +1261,8 @@ def decode(cls, loc: Any) -> Any:
return cls.add_all([cls.simple_hash(con(preimage)), con(delta)])
else:
return loc
elif is_bv(loc):

if is_bv(loc):
return loc
else:
raise ValueError(loc)
Expand Down
31 changes: 31 additions & 0 deletions tests/expected/all.json
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,37 @@
"num_bounded_loops": null
}
],
"test/Storage3.t.sol:Storage3Test": [
{
"name": "check_set((bytes1,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256,uint256))",
"exitcode": 0,
"num_models": 0,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
}
],
"test/Storage4.t.sol:Storage4Test": [
{
"name": "check_add_1(uint256)",
"exitcode": 0,
"num_models": 0,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
},
{
"name": "check_add_2(uint256)",
"exitcode": 0,
"num_models": 0,
"models": null,
"num_paths": null,
"time": null,
"num_bounded_loops": null
}
],
"test/Store.t.sol:StoreTest": [
{
"name": "check_store_Array(uint256,uint256)",
Expand Down
Loading

0 comments on commit 7761573

Please sign in to comment.