Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(tests): update eip-2537 bls precompile gas pricing #1032

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

83 changes: 58 additions & 25 deletions tests/prague/eip2537_bls_12_381_precompiles/spec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""
Defines EIP-2537 specification constants and functions.
"""
import json
from dataclasses import dataclass
from typing import Callable, List, Sized, SupportsBytes, Tuple

from .helpers import current_python_script_directory
from enum import Enum, auto
from typing import Callable, Sized, SupportsBytes, Tuple


@dataclass(frozen=True)
Expand Down Expand Up @@ -109,11 +107,6 @@ def __bytes__(self) -> bytes:
return self.x.to_bytes(32, byteorder="big")


with open(current_python_script_directory("msm_discount_table.json")) as f:
MSM_DISCOUNT_TABLE: List[int] = json.load(f)
assert type(MSM_DISCOUNT_TABLE) is list


@dataclass(frozen=True)
class Spec:
"""
Expand All @@ -133,14 +126,14 @@ class Spec:
MAP_FP2_TO_G2 = 0x13

# Gas constants
G1ADD_GAS = 500
G1ADD_GAS = 375
G1MUL_GAS = 12_000
G2ADD_GAS = 800
G2MUL_GAS = 45_000
G2ADD_GAS = 600
G2MUL_GAS = 22_500
MAP_FP_TO_G1_GAS = 5_500
MAP_FP2_TO_G2_GAS = 75_000
PAIRING_BASE_GAS = 65_000
PAIRING_PER_PAIR_GAS = 43_000
MAP_FP2_TO_G2_GAS = 23_800
PAIRING_BASE_GAS = 37_700
PAIRING_PER_PAIR_GAS = 32_600

# Other constants
B_COEFFICIENT = 0x04
Expand All @@ -149,7 +142,30 @@ class Spec:
P = (X - 1) ** 2 * Q // 3 + X
LEN_PER_PAIR = len(PointG1() + PointG2())
MSM_MULTIPLIER = 1_000
MSM_DISCOUNT_TABLE = MSM_DISCOUNT_TABLE
# fmt: off
G1MSM_DISCOUNT_TABLE = [
0,
1000, 949, 848, 797, 764, 750, 738, 728, 719, 712, 705, 698, 692, 687, 682, 677, 673, 669,
665, 661, 658, 654, 651, 648, 645, 642, 640, 637, 635, 632, 630, 627, 625, 623, 621, 619,
617, 615, 613, 611, 609, 608, 606, 604, 603, 601, 599, 598, 596, 595, 593, 592, 591, 589,
588, 586, 585, 584, 582, 581, 580, 579, 577, 576, 575, 574, 573, 572, 570, 569, 568, 567,
566, 565, 564, 563, 562, 561, 560, 559, 558, 557, 556, 555, 554, 553, 552, 551, 550, 549,
548, 547, 547, 546, 545, 544, 543, 542, 541, 540, 540, 539, 538, 537, 536, 536, 535, 534,
533, 532, 532, 531, 530, 529, 528, 528, 527, 526, 525, 525, 524, 523, 522, 522, 521, 520,
520, 519
]
G2MSM_DISCOUNT_TABLE = [
0,
1000, 1000, 923, 884, 855, 832, 812, 796, 782, 770, 759, 749, 740, 732, 724, 717, 711, 704,
699, 693, 688, 683, 679, 674, 670, 666, 663, 659, 655, 652, 649, 646, 643, 640, 637, 634,
632, 629, 627, 624, 622, 620, 618, 615, 613, 611, 609, 607, 606, 604, 602, 600, 598, 597,
595, 593, 592, 590, 589, 587, 586, 584, 583, 582, 580, 579, 578, 576, 575, 574, 573, 571,
570, 569, 568, 567, 566, 565, 563, 562, 561, 560, 559, 558, 557, 556, 555, 554, 553, 552,
552, 551, 550, 549, 548, 547, 546, 545, 545, 544, 543, 542, 541, 541, 540, 539, 538, 537,
537, 536, 535, 535, 534, 533, 532, 532, 531, 530, 530, 529, 528, 528, 527, 526, 526, 525,
524, 524
]
# fmt: on

# Test constants (from https://github.com/ethereum/bls12-381-tests/tree/eip-2537)
P1 = PointG1( # random point in G1
Expand Down Expand Up @@ -217,17 +233,34 @@ class Spec:
INVALID = b""


def msm_discount(k: int) -> int:
class BLS12Group(Enum):
"""
Returns the discount for the G1MSM and G2MSM precompiles.
Helper enum to specify the BLS12 group in discount table helpers.
"""
return Spec.MSM_DISCOUNT_TABLE[min(k, 128)]

G1 = auto()
G2 = auto()

def msm_gas_func_gen(len_per_pair: int, multiplication_cost: int) -> Callable[[int], int]:

def msm_discount(group: BLS12Group, k: int) -> int:
"""
Returns the discount for the G1MSM and G2MSM precompiles.
"""
assert k >= 1, "k must be greater than or equal to 1"
match group:
case BLS12Group.G1:
return Spec.G1MSM_DISCOUNT_TABLE[min(k, 128)]
case BLS12Group.G2:
return Spec.G2MSM_DISCOUNT_TABLE[min(k, 128)]
case _:
raise ValueError(f"Unsupported group: {group}")


def msm_gas_func_gen(
group: BLS12Group, len_per_pair: int, multiplication_cost: int
) -> Callable[[int], int]:
"""
Generates a function that calculates the gas cost for the G1MSM and G2MSM
precompiles.
Generate a function that calculates the gas cost for the G1MSM and G2MSM precompiles.
"""

def msm_gas(input_length: int) -> int:
Expand All @@ -238,7 +271,7 @@ def msm_gas(input_length: int) -> int:
if k == 0:
return 0

gas_cost = k * multiplication_cost * msm_discount(k) // Spec.MSM_MULTIPLIER
gas_cost = k * multiplication_cost * msm_discount(group, k) // Spec.MSM_MULTIPLIER

return gas_cost

Expand All @@ -256,10 +289,10 @@ def pairing_gas(input_length: int) -> int:
GAS_CALCULATION_FUNCTION_MAP = {
Spec.G1ADD: lambda _: Spec.G1ADD_GAS,
Spec.G1MUL: lambda _: Spec.G1MUL_GAS,
Spec.G1MSM: msm_gas_func_gen(len(PointG1() + Scalar()), Spec.G1MUL_GAS),
Spec.G1MSM: msm_gas_func_gen(BLS12Group.G1, len(PointG1() + Scalar()), Spec.G1MUL_GAS),
Spec.G2ADD: lambda _: Spec.G2ADD_GAS,
Spec.G2MUL: lambda _: Spec.G2MUL_GAS,
Spec.G2MSM: msm_gas_func_gen(len(PointG2() + Scalar()), Spec.G2MUL_GAS),
Spec.G2MSM: msm_gas_func_gen(BLS12Group.G2, len(PointG2() + Scalar()), Spec.G2MUL_GAS),
Spec.PAIRING: pairing_gas,
Spec.MAP_FP_TO_G1: lambda _: Spec.MAP_FP_TO_G1_GAS,
Spec.MAP_FP2_TO_G2: lambda _: Spec.MAP_FP2_TO_G2_GAS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
vectors_from_file("multiexp_G1_bls.json")
+ [
pytest.param(
(Spec.P1 + Scalar(Spec.Q)) * (len(Spec.MSM_DISCOUNT_TABLE) - 1),
(Spec.P1 + Scalar(Spec.Q)) * (len(Spec.G1MSM_DISCOUNT_TABLE) - 1),
Spec.INF_G1,
id="max_discount",
),
pytest.param(
(Spec.P1 + Scalar(Spec.Q)) * len(Spec.MSM_DISCOUNT_TABLE),
(Spec.P1 + Scalar(Spec.Q)) * len(Spec.G1MSM_DISCOUNT_TABLE),
Spec.INF_G1,
id="max_discount_plus_1",
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@ def call_contract_code(
"precompile_gas_list,precompile_data_length_list",
[
pytest.param(
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="exact_gas_full_discount_table",
),
pytest.param(
[
G1_GAS(i * G1_MSM_K_INPUT_LENGTH) + 1
for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))
for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))
],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="one_extra_gas_full_discount_table",
),
],
Expand Down Expand Up @@ -149,9 +149,9 @@ def test_valid_gas_g1msm(
pytest.param(
[
G1_GAS(i * G1_MSM_K_INPUT_LENGTH) - 1
for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))
for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))
],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="insufficient_gas_full_discount_table",
),
],
Expand Down Expand Up @@ -188,13 +188,13 @@ def test_invalid_gas_g1msm(
id="zero_length_input",
),
pytest.param(
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[(i * G1_MSM_K_INPUT_LENGTH) - 1 for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
[(i * G1_MSM_K_INPUT_LENGTH) - 1 for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="input_one_byte_too_short_full_discount_table",
),
pytest.param(
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[(i * G1_MSM_K_INPUT_LENGTH) + 1 for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
[(i * G1_MSM_K_INPUT_LENGTH) + 1 for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="input_one_byte_too_long_full_discount_table",
),
],
Expand Down Expand Up @@ -226,22 +226,22 @@ def test_invalid_length_g1msm(
"precompile_gas_list,precompile_data_length_list",
[
pytest.param(
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="exact_gas_full_discount_table",
),
pytest.param(
[
G2_GAS(i * G2_MSM_K_INPUT_LENGTH) + 1
for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))
for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))
],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="one_extra_gas_full_discount_table",
),
],
)
@pytest.mark.parametrize("expected_output", [PointG2()], ids=[""])
@pytest.mark.parametrize("tx_gas_limit", [100_000_000], ids=[""])
@pytest.mark.parametrize("tx_gas_limit", [110_000_000], ids=[""])
@pytest.mark.parametrize("precompile_address", [Spec.G2MSM])
def test_valid_gas_g2msm(
state_test: StateTestFiller,
Expand Down Expand Up @@ -274,9 +274,9 @@ def test_valid_gas_g2msm(
pytest.param(
[
G2_GAS(i * G2_MSM_K_INPUT_LENGTH) - 1
for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))
for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))
],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="insufficient_gas_full_discount_table",
),
],
Expand Down Expand Up @@ -313,13 +313,13 @@ def test_invalid_gas_g2msm(
id="zero_length_input",
),
pytest.param(
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[(i * G2_MSM_K_INPUT_LENGTH) - 1 for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
[(i * G2_MSM_K_INPUT_LENGTH) - 1 for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="input_one_byte_too_short_full_discount_table",
),
pytest.param(
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[(i * G2_MSM_K_INPUT_LENGTH) + 1 for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
[(i * G2_MSM_K_INPUT_LENGTH) + 1 for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="input_one_byte_too_long_full_discount_table",
),
],
Expand Down
2 changes: 2 additions & 0 deletions whitelist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ filelock
filesystem
fillvalue
firstlineno
fmt
fn
fname
forkchoice
Expand Down Expand Up @@ -371,6 +372,7 @@ rpc
ruleset
runtestloop
runtime
S12
sandboxed
secp256k1
secp256k1n
Expand Down
Loading