diff --git a/src/ethereum_test_base_types/composite_types.py b/src/ethereum_test_base_types/composite_types.py index dbba9896fe..8d689c63e0 100644 --- a/src/ethereum_test_base_types/composite_types.py +++ b/src/ethereum_test_base_types/composite_types.py @@ -1,6 +1,7 @@ """ Base composite types for Ethereum test cases. """ + from dataclasses import dataclass from typing import Any, ClassVar, Dict, List, SupportsBytes, Type, TypeAlias @@ -24,6 +25,7 @@ class Storage(RootModel[Dict[StorageKeyValueType, StorageKeyValueType]]): root: Dict[StorageKeyValueType, StorageKeyValueType] = Field(default_factory=dict) _current_slot: int = PrivateAttr(0) + _hint_map: Dict[StorageKeyValueType, str] = PrivateAttr(default_factory=dict) StorageDictType: ClassVar[TypeAlias] = Dict[ str | int | bytes | SupportsBytes, str | int | bytes | SupportsBytes @@ -92,13 +94,15 @@ class KeyValueMismatch(Exception): key: int want: int got: int + hint: str - def __init__(self, address: Address, key: int, want: int, got: int, *args): + def __init__(self, address: Address, key: int, want: int, got: int, hint: str = "", *args): super().__init__(args) self.address = address self.key = key self.want = want self.got = got + self.hint = hint def __str__(self): """Print exception string""" @@ -107,7 +111,7 @@ def __str__(self): label_str = f" ({self.address.label})" return ( f"incorrect value in address {self.address}{label_str} for " - + f"key {Hash(self.key)}:" + + f"key {Hash(self.key)}{f' ({self.hint})' if self.hint else ''}:" + f" want {HexNumber(self.want)} (dec:{int(self.want)})," + f" got {HexNumber(self.got)} (dec:{int(self.got)})" ) @@ -182,7 +186,7 @@ def items(self): return self.root.items() def store_next( - self, value: StorageKeyValueTypeConvertible | StorageKeyValueType | bool + self, value: StorageKeyValueTypeConvertible | StorageKeyValueType | bool, hint: str = "" ) -> StorageKeyValueType: """ Stores a value in the storage and returns the key where the value is stored. @@ -192,6 +196,8 @@ def store_next( """ slot = StorageKeyValueTypeAdapter.validate_python(self._current_slot) self._current_slot += 1 + if hint: + self._hint_map[slot] = hint self[slot] = StorageKeyValueTypeAdapter.validate_python(value) return slot @@ -230,7 +236,11 @@ def must_contain(self, address: Address, other: "Storage"): raise Storage.MissingKey(key=key) elif self[key] != other[key]: raise Storage.KeyValueMismatch( - address=address, key=key, want=self[key], got=other[key] + address=address, + key=key, + want=self[key], + got=other[key], + hint=self._hint_map.get(key, ""), ) def must_be_equal(self, address: Address, other: "Storage | None"): @@ -243,17 +253,33 @@ def must_be_equal(self, address: Address, other: "Storage | None"): for key in self.keys() & other.keys(): if self[key] != other[key]: raise Storage.KeyValueMismatch( - address=address, key=key, want=self[key], got=other[key] + address=address, + key=key, + want=self[key], + got=other[key], + hint=self._hint_map.get(key, ""), ) # Test keys contained in either one of the storage objects for key in self.keys() ^ other.keys(): if key in self: if self[key] != 0: - raise Storage.KeyValueMismatch(address=address, key=key, want=self[key], got=0) + raise Storage.KeyValueMismatch( + address=address, + key=key, + want=self[key], + got=0, + hint=self._hint_map.get(key, ""), + ) elif other[key] != 0: - raise Storage.KeyValueMismatch(address=address, key=key, want=0, got=other[key]) + raise Storage.KeyValueMismatch( + address=address, + key=key, + want=0, + got=other[key], + hint=self._hint_map.get(key, ""), + ) def canary(self) -> "Storage": """