Skip to content

Commit

Permalink
fix: assert value is 0 at the end of load_bytecode (kkrt-labs#1290)
Browse files Browse the repository at this point in the history
<!--- Please provide a general summary of your changes in the title
above -->

<!-- Give an estimate of the time you spent on this PR in terms of work
days.
Did you spend 0.5 days on this PR or rather 2 days?  -->

Time spent on this PR:

## Pull request type

<!-- Please try to limit your pull request to one type,
submit multiple pull requests if needed. -->

Please check the type of change your PR introduces:

- [x] Bugfix
- [ ] Feature
- [ ] Code style update (formatting, renaming)
- [ ] Refactoring (no functional changes, no api changes)
- [ ] Build related changes
- [ ] Documentation content changes
- [ ] Other (please describe):

## What is the current behavior?

Do not assert the value is 0 at the end of load_bytecode leading to
prover being able to

Resolves kkrt-labs#1279, resolves kkrt-labs#1293

## What is the new behavior?
assert value is 0 at the end of load_bytecode

<!-- Reviewable:start -->
- - -
This change is [<img src="https://reviewable.io/review_button.svg"
height="34" align="absmiddle"
alt="Reviewable"/>](https://reviewable.io/reviews/kkrt-labs/kakarot/1290)
<!-- Reviewable:end -->

---------

Co-authored-by: Clément Walter <[email protected]>
  • Loading branch information
obatirou and ClementWalter authored Jul 23, 2024
1 parent c366211 commit 29e3930
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 64 deletions.
6 changes: 6 additions & 0 deletions src/kakarot/accounts/library.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -652,11 +652,17 @@ namespace Internals {

jmp cond if remaining_bytes != 0;

with_attr error_message("Value is not empty") {
assert value = 0;
}
let bytecode = cast([fp], felt*);
return (bytecode=bytecode);

cond:
jmp body if count != 0;
with_attr error_message("Value is not empty") {
assert value = 0;
}
jmp read;
}
}
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def seed(request):
pytest_plugins = ["tests.fixtures.starknet"]

settings.register_profile("ci", deadline=None, max_examples=1000)
settings.register_profile("dev", max_examples=10)
settings.register_profile("dev", deadline=None, max_examples=10)
settings.register_profile("debug", max_examples=10, verbosity=Verbosity.verbose)
settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "default"))
logger.info(f"Using Hypothesis profile: {os.getenv('HYPOTHESIS_PROFILE', 'default')}")
41 changes: 23 additions & 18 deletions tests/fixtures/starknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,7 @@ def cairo_compile(path):


@pytest.fixture(scope="module")
def cairo_run(request) -> list:
"""
Run the cairo program corresponding to the python test file at a given entrypoint with given program inputs as kwargs.
Returns the output of the cairo program put in the output memory segment.
When --profile-cairo is passed, the cairo program is run with the tracer enabled and the resulting trace is dumped.
Logic is mainly taken from starkware.cairo.lang.vm.cairo_run with minor updates like the addition of the output segment.
"""
def cairo_program(request) -> list:
cairo_file = Path(request.node.fspath).with_suffix(".cairo")
if not cairo_file.exists():
raise ValueError(f"Missing cairo file: {cairo_file}")
Expand All @@ -89,23 +81,36 @@ def cairo_run(request) -> list:
program = cairo_compile(cairo_file)
stop = perf_counter()
logger.info(f"{cairo_file} compiled in {stop - start:.2f}s")
return program


@pytest.fixture(scope="module")
def cairo_run(request, cairo_program) -> list:
"""
Run the cairo program corresponding to the python test file at a given entrypoint with given program inputs as kwargs.
Returns the output of the cairo program put in the output memory segment.
When --profile-cairo is passed, the cairo program is run with the tracer enabled and the resulting trace is dumped.
Logic is mainly taken from starkware.cairo.lang.vm.cairo_run with minor updates like the addition of the output segment.
"""

def _factory(entrypoint, **kwargs) -> list:
implicit_args = list(
program.identifiers.get_by_full_name(
cairo_program.identifiers.get_by_full_name(
ScopedName(path=["__main__", entrypoint, "ImplicitArgs"])
).members.keys()
)
args = list(
program.identifiers.get_by_full_name(
cairo_program.identifiers.get_by_full_name(
ScopedName(path=["__main__", entrypoint, "Args"])
).members.keys()
)
return_data = program.identifiers.get_by_full_name(
return_data = cairo_program.identifiers.get_by_full_name(
ScopedName(path=["__main__", entrypoint, "Return"])
)
# Fix builtins runner based on the implicit args since the compiler doesn't find them
program.builtins = [
cairo_program.builtins = [
builtin
# This list is extracted from the builtin runners
# Builtins have to be declared in this order
Expand All @@ -125,7 +130,7 @@ def _factory(entrypoint, **kwargs) -> list:

memory = MemoryDict()
runner = CairoRunner(
program=program,
program=cairo_program,
layout=request.config.getoption("layout"),
memory=memory,
proof_mode=request.config.getoption("proof_mode"),
Expand Down Expand Up @@ -165,7 +170,7 @@ def _factory(entrypoint, **kwargs) -> list:
runner.execution_public_memory = list(range(len(stack)))

runner.initialize_state(
entrypoint=program.identifiers.get_by_full_name(
entrypoint=cairo_program.identifiers.get_by_full_name(
ScopedName(path=["__main__", entrypoint])
).pc,
stack=stack,
Expand All @@ -178,7 +183,7 @@ def _factory(entrypoint, **kwargs) -> list:
"syscall_handler": SyscallHandler(),
},
static_locals={
"debug_info": debug_info(program),
"debug_info": debug_info(cairo_program),
"serde": serde,
"Opcodes": Opcodes,
},
Expand Down Expand Up @@ -234,7 +239,7 @@ def _factory(entrypoint, **kwargs) -> list:
)
if request.config.getoption("profile_cairo"):
tracer_data = TracerData(
program=program,
program=cairo_program,
memory=runner.relocated_memory,
trace=runner.relocated_trace,
debug_info=runner.get_relocated_debug_info(),
Expand All @@ -253,7 +258,7 @@ def _factory(entrypoint, **kwargs) -> list:
write_binary_memory(
fp,
runner.relocated_memory,
math.ceil(program.prime.bit_length() / 8),
math.ceil(cairo_program.prime.bit_length() / 8),
)

rc_min, rc_max = runner.get_perm_range_check_limits()
Expand Down
114 changes: 69 additions & 45 deletions tests/src/kakarot/accounts/test_account_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import rlp
from eth_account.account import Account
from eth_utils import keccak
from hypothesis import given, settings
from hypothesis.strategies import binary
from starkware.starknet.public.abi import (
get_selector_from_name,
get_storage_var_address,
Expand All @@ -15,6 +17,7 @@
from tests.utils.constants import CHAIN_ID, TRANSACTION_GAS_LIMIT, TRANSACTIONS
from tests.utils.errors import cairo_error
from tests.utils.helpers import generate_random_private_key, rlp_encode_signed_data
from tests.utils.hints import patch_hint
from tests.utils.syscall_handler import SyscallHandler
from tests.utils.uint256 import int_to_uint256

Expand Down Expand Up @@ -92,7 +95,7 @@ def test_should_write_bytecode(self, cairo_run, bytecode):
SyscallHandler.mock_storage.assert_has_calls(calls)

class TestBytecode:
@pytest.fixture

def storage(self, bytecode):
chunks = wrap(bytecode.hex(), 2 * 31)

Expand All @@ -105,9 +108,9 @@ def _storage(address):

return _storage

def test_should_read_bytecode(self, cairo_run, bytecode, storage):
def test_should_read_bytecode(self, cairo_run, bytecode):
with patch.object(
SyscallHandler, "mock_storage", side_effect=storage
SyscallHandler, "mock_storage", side_effect=self.storage(bytecode)
) as mock_storage:
output_len, output = cairo_run("test__bytecode")
chunk_counts, remainder = divmod(len(bytecode), 31)
Expand All @@ -116,6 +119,24 @@ def test_should_read_bytecode(self, cairo_run, bytecode, storage):
mock_storage.assert_has_calls(calls)
assert output[:output_len] == list(bytecode)

@given(bytecode=binary(min_size=1, max_size=400))
@settings(max_examples=5)
def test_should_raise_when_read_bytecode_zellic_issue_1279(
self, cairo_program, cairo_run, bytecode
):
with (
patch_hint(
cairo_program,
"memory[ids.output] = res = (int(ids.value) % PRIME) % ids.base\nassert res < ids.bound, f'split_int(): Limb {res} is out of range.'",
"memory[ids.output] = res = (int(ids.value) % PRIME + 1) % ids.base\nassert res < ids.bound, f'split_int(): Limb {res} is out of range.'",
),
patch.object(
SyscallHandler, "mock_storage", side_effect=self.storage(bytecode)
),
):
with cairo_error(message="Value is not empty"):
output_len, output = cairo_run("test__bytecode")

class TestNonce:
@SyscallHandler.patch("Ownable_owner", 0xDEAD)
def test_should_assert_only_owner(self, cairo_run):
Expand Down Expand Up @@ -367,8 +388,9 @@ def test_should_raise_invalid_signature_for_invalid_chain_id_when_tx_type0_not_p
signed.v,
]

with cairo_error(message="Invalid signature."), SyscallHandler.patch(
"Account_evm_address", address
with (
cairo_error(message="Invalid signature."),
SyscallHandler.patch("Account_evm_address", address),
):
cairo_run(
"test__execute_from_outside",
Expand Down Expand Up @@ -403,8 +425,9 @@ def test_should_raise_invalid_chain_id_tx_type_different_from_0(
signed.v,
]

with cairo_error(message="Invalid chain id"), SyscallHandler.patch(
"Account_evm_address", address
with (
cairo_error(message="Invalid chain id"),
SyscallHandler.patch("Account_evm_address", address),
):
cairo_run(
"test__execute_from_outside",
Expand All @@ -429,8 +452,9 @@ def test_should_raise_invalid_nonce(self, cairo_run, transaction):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
tx_data = list(encoded_unsigned_tx)

with cairo_error(message="Invalid nonce"), SyscallHandler.patch(
"Account_evm_address", address
with (
cairo_error(message="Invalid nonce"),
SyscallHandler.patch("Account_evm_address", address),
):
cairo_run(
"test__execute_from_outside",
Expand All @@ -454,12 +478,10 @@ def test_raise_not_enough_ETH_balance(self, cairo_run, transaction):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
tx_data = list(encoded_unsigned_tx)

with cairo_error(
message="Not enough ETH to pay msg.value + max gas fees"
), SyscallHandler.patch(
"Account_evm_address", address
), SyscallHandler.patch(
"Account_nonce", transaction.get("nonce", 0)
with (
cairo_error(message="Not enough ETH to pay msg.value + max gas fees"),
SyscallHandler.patch("Account_evm_address", address),
SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)),
):
cairo_run(
"test__execute_from_outside",
Expand All @@ -486,12 +508,10 @@ def test_raise_transaction_gas_limit_too_high(self, cairo_run, transaction):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
tx_data = list(encoded_unsigned_tx)

with cairo_error(
message="Transaction gas_limit > Block gas_limit"
), SyscallHandler.patch(
"Account_evm_address", address
), SyscallHandler.patch(
"Account_nonce", transaction.get("nonce", 0)
with (
cairo_error(message="Transaction gas_limit > Block gas_limit"),
SyscallHandler.patch("Account_evm_address", address),
SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)),
):
cairo_run(
"test__execute_from_outside",
Expand Down Expand Up @@ -523,9 +543,11 @@ def test_raise_max_fee_per_gas_too_low(self, cairo_run, transaction):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
tx_data = list(encoded_unsigned_tx)

with cairo_error(message="Max fee per gas too low"), SyscallHandler.patch(
"Account_evm_address", address
), SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)):
with (
cairo_error(message="Max fee per gas too low"),
SyscallHandler.patch("Account_evm_address", address),
SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)),
):
cairo_run(
"test__execute_from_outside",
tx_data=tx_data,
Expand Down Expand Up @@ -565,12 +587,10 @@ def test_raise_max_priority_fee_too_high(self, cairo_run):
signed.v,
]

with cairo_error(
message="Max priority fee greater than max fee per gas"
), SyscallHandler.patch(
"Account_evm_address", address
), SyscallHandler.patch(
"Account_nonce", transaction["nonce"]
with (
cairo_error(message="Max priority fee greater than max fee per gas"),
SyscallHandler.patch("Account_evm_address", address),
SyscallHandler.patch("Account_nonce", transaction["nonce"]),
):
cairo_run(
"test__execute_from_outside",
Expand Down Expand Up @@ -606,15 +626,17 @@ def test_pass_authorized_pre_eip155_transaction(self, cairo_run):
int.from_bytes(keccak(encoded_unsigned_tx), "big")
)

with SyscallHandler.patch(
"Account_evm_address", int(ARACHNID_PROXY_DEPLOYER, 16)
), SyscallHandler.patch(
"Account_authorized_message_hashes",
tx_hash_low,
tx_hash_high,
0x1,
), SyscallHandler.patch(
"Account_nonce", 0
with (
SyscallHandler.patch(
"Account_evm_address", int(ARACHNID_PROXY_DEPLOYER, 16)
),
SyscallHandler.patch(
"Account_authorized_message_hashes",
tx_hash_low,
tx_hash_high,
0x1,
),
SyscallHandler.patch("Account_nonce", 0),
):
output_len, output = cairo_run(
"test__execute_from_outside",
Expand Down Expand Up @@ -661,9 +683,10 @@ def test_pass_all_transactions_types(self, cairo_run, seed, transaction):
encoded_unsigned_tx = rlp_encode_signed_data(transaction)
tx_data = list(encoded_unsigned_tx)

with SyscallHandler.patch(
"Account_evm_address", address
), SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)):
with (
SyscallHandler.patch("Account_evm_address", address),
SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)),
):
output_len, output = cairo_run(
"test__execute_from_outside",
tx_data=tx_data,
Expand Down Expand Up @@ -713,9 +736,10 @@ def test_should_pass_all_data_len(self, cairo_run, bytecode):
signed.v,
]

with SyscallHandler.patch(
"Account_evm_address", address
), SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)):
with (
SyscallHandler.patch("Account_evm_address", address),
SyscallHandler.patch("Account_nonce", transaction.get("nonce", 0)),
):
output_len, output = cairo_run(
"test__execute_from_outside",
tx_data=tx_data,
Expand Down
24 changes: 24 additions & 0 deletions tests/utils/hints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections import defaultdict
from contextlib import contextmanager
from unittest.mock import patch

from starkware.cairo.common.dict import DictTracker
from starkware.cairo.lang.compiler.program import CairoHint


def debug_info(program):
Expand Down Expand Up @@ -45,3 +48,24 @@ def new_default_dict(
current_ptr=base,
)
return base


@contextmanager
def patch_hint(program, hint, new_hint):
patched_hints = {
k: [
(
hint_
if hint_.code != hint
else CairoHint(
accessible_scopes=hint_.accessible_scopes,
flow_tracking_data=hint_.flow_tracking_data,
code=new_hint,
)
)
for hint_ in v
]
for k, v in program.hints.items()
}
with patch.object(program, "hints", new=patched_hints):
yield

0 comments on commit 29e3930

Please sign in to comment.