Skip to content

Commit

Permalink
chore(tests): update G1/add G2 msm discount tables ethereum/EIPs#9116
Browse files Browse the repository at this point in the history
  • Loading branch information
danceratopz committed Dec 18, 2024
1 parent c1efe8a commit 0de9f33
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 43 deletions.

This file was deleted.

71 changes: 52 additions & 19 deletions tests/prague/eip2537_bls_12_381_precompiles/spec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""
Defines EIP-2537 specification constants and functions.
"""
import json
from dataclasses import dataclass
from typing import Callable, List, Sized, SupportsBytes, Tuple

from .helpers import current_python_script_directory
from enum import Enum, auto
from typing import Callable, Sized, SupportsBytes, Tuple


@dataclass(frozen=True)
Expand Down Expand Up @@ -109,11 +107,6 @@ def __bytes__(self) -> bytes:
return self.x.to_bytes(32, byteorder="big")


with open(current_python_script_directory("msm_discount_table.json")) as f:
MSM_DISCOUNT_TABLE: List[int] = json.load(f)
assert type(MSM_DISCOUNT_TABLE) is list


@dataclass(frozen=True)
class Spec:
"""
Expand Down Expand Up @@ -149,7 +142,30 @@ class Spec:
P = (X - 1) ** 2 * Q // 3 + X
LEN_PER_PAIR = len(PointG1() + PointG2())
MSM_MULTIPLIER = 1_000
MSM_DISCOUNT_TABLE = MSM_DISCOUNT_TABLE
# fmt: off
G1MSM_DISCOUNT_TABLE = [
0,
1000, 949, 848, 797, 764, 750, 738, 728, 719, 712, 705, 698, 692, 687, 682, 677, 673, 669,
665, 661, 658, 654, 651, 648, 645, 642, 640, 637, 635, 632, 630, 627, 625, 623, 621, 619,
617, 615, 613, 611, 609, 608, 606, 604, 603, 601, 599, 598, 596, 595, 593, 592, 591, 589,
588, 586, 585, 584, 582, 581, 580, 579, 577, 576, 575, 574, 573, 572, 570, 569, 568, 567,
566, 565, 564, 563, 562, 561, 560, 559, 558, 557, 556, 555, 554, 553, 552, 551, 550, 549,
548, 547, 547, 546, 545, 544, 543, 542, 541, 540, 540, 539, 538, 537, 536, 536, 535, 534,
533, 532, 532, 531, 530, 529, 528, 528, 527, 526, 525, 525, 524, 523, 522, 522, 521, 520,
520, 519
]
G2MSM_DISCOUNT_TABLE = [
0,
1000, 1000, 923, 884, 855, 832, 812, 796, 782, 770, 759, 749, 740, 732, 724, 717, 711, 704,
699, 693, 688, 683, 679, 674, 670, 666, 663, 659, 655, 652, 649, 646, 643, 640, 637, 634,
632, 629, 627, 624, 622, 620, 618, 615, 613, 611, 609, 607, 606, 604, 602, 600, 598, 597,
595, 593, 592, 590, 589, 587, 586, 584, 583, 582, 580, 579, 578, 576, 575, 574, 573, 571,
570, 569, 568, 567, 566, 565, 563, 562, 561, 560, 559, 558, 557, 556, 555, 554, 553, 552,
552, 551, 550, 549, 548, 547, 546, 545, 545, 544, 543, 542, 541, 541, 540, 539, 538, 537,
537, 536, 535, 535, 534, 533, 532, 532, 531, 530, 530, 529, 528, 528, 527, 526, 526, 525,
524, 524
]
# fmt: on

# Test constants (from https://github.com/ethereum/bls12-381-tests/tree/eip-2537)
P1 = PointG1( # random point in G1
Expand Down Expand Up @@ -217,17 +233,34 @@ class Spec:
INVALID = b""


def msm_discount(k: int) -> int:
class BLS12Group(Enum):
"""
Returns the discount for the G1MSM and G2MSM precompiles.
Helper enum to specify the BLS12 group in discount table helpers.
"""
return Spec.MSM_DISCOUNT_TABLE[min(k, 128)]

G1 = auto()
G2 = auto()

def msm_gas_func_gen(len_per_pair: int, multiplication_cost: int) -> Callable[[int], int]:

def msm_discount(group: BLS12Group, k: int) -> int:
"""
Returns the discount for the G1MSM and G2MSM precompiles.
"""
assert k >= 1, "k must be greater than or equal to 1"
match group:
case BLS12Group.G1:
return Spec.G1MSM_DISCOUNT_TABLE[min(k, 128)]
case BLS12Group.G2:
return Spec.G2MSM_DISCOUNT_TABLE[min(k, 128)]
case _:
raise ValueError(f"Unsupported group: {group}")


def msm_gas_func_gen(
group: BLS12Group, len_per_pair: int, multiplication_cost: int
) -> Callable[[int], int]:
"""
Generates a function that calculates the gas cost for the G1MSM and G2MSM
precompiles.
Generate a function that calculates the gas cost for the G1MSM and G2MSM precompiles.
"""

def msm_gas(input_length: int) -> int:
Expand All @@ -238,7 +271,7 @@ def msm_gas(input_length: int) -> int:
if k == 0:
return 0

gas_cost = k * multiplication_cost * msm_discount(k) // Spec.MSM_MULTIPLIER
gas_cost = k * multiplication_cost * msm_discount(group, k) // Spec.MSM_MULTIPLIER

return gas_cost

Expand All @@ -256,10 +289,10 @@ def pairing_gas(input_length: int) -> int:
GAS_CALCULATION_FUNCTION_MAP = {
Spec.G1ADD: lambda _: Spec.G1ADD_GAS,
Spec.G1MUL: lambda _: Spec.G1MUL_GAS,
Spec.G1MSM: msm_gas_func_gen(len(PointG1() + Scalar()), Spec.G1MUL_GAS),
Spec.G1MSM: msm_gas_func_gen(BLS12Group.G1, len(PointG1() + Scalar()), Spec.G1MUL_GAS),
Spec.G2ADD: lambda _: Spec.G2ADD_GAS,
Spec.G2MUL: lambda _: Spec.G2MUL_GAS,
Spec.G2MSM: msm_gas_func_gen(len(PointG2() + Scalar()), Spec.G2MUL_GAS),
Spec.G2MSM: msm_gas_func_gen(BLS12Group.G2, len(PointG2() + Scalar()), Spec.G2MUL_GAS),
Spec.PAIRING: pairing_gas,
Spec.MAP_FP_TO_G1: lambda _: Spec.MAP_FP_TO_G1_GAS,
Spec.MAP_FP2_TO_G2: lambda _: Spec.MAP_FP2_TO_G2_GAS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
vectors_from_file("multiexp_G1_bls.json")
+ [
pytest.param(
(Spec.P1 + Scalar(Spec.Q)) * (len(Spec.MSM_DISCOUNT_TABLE) - 1),
(Spec.P1 + Scalar(Spec.Q)) * (len(Spec.G1MSM_DISCOUNT_TABLE) - 1),
Spec.INF_G1,
id="max_discount",
),
pytest.param(
(Spec.P1 + Scalar(Spec.Q)) * len(Spec.MSM_DISCOUNT_TABLE),
(Spec.P1 + Scalar(Spec.Q)) * len(Spec.G1MSM_DISCOUNT_TABLE),
Spec.INF_G1,
id="max_discount_plus_1",
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,16 @@ def call_contract_code(
"precompile_gas_list,precompile_data_length_list",
[
pytest.param(
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="exact_gas_full_discount_table",
),
pytest.param(
[
G1_GAS(i * G1_MSM_K_INPUT_LENGTH) + 1
for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))
for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))
],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="one_extra_gas_full_discount_table",
),
],
Expand Down Expand Up @@ -149,9 +149,9 @@ def test_valid_gas_g1msm(
pytest.param(
[
G1_GAS(i * G1_MSM_K_INPUT_LENGTH) - 1
for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))
for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))
],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G1_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="insufficient_gas_full_discount_table",
),
],
Expand Down Expand Up @@ -188,13 +188,13 @@ def test_invalid_gas_g1msm(
id="zero_length_input",
),
pytest.param(
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[(i * G1_MSM_K_INPUT_LENGTH) - 1 for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
[(i * G1_MSM_K_INPUT_LENGTH) - 1 for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="input_one_byte_too_short_full_discount_table",
),
pytest.param(
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[(i * G1_MSM_K_INPUT_LENGTH) + 1 for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G1_GAS(i * G1_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
[(i * G1_MSM_K_INPUT_LENGTH) + 1 for i in range(1, len(Spec.G1MSM_DISCOUNT_TABLE))],
id="input_one_byte_too_long_full_discount_table",
),
],
Expand Down Expand Up @@ -226,22 +226,22 @@ def test_invalid_length_g1msm(
"precompile_gas_list,precompile_data_length_list",
[
pytest.param(
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="exact_gas_full_discount_table",
),
pytest.param(
[
G2_GAS(i * G2_MSM_K_INPUT_LENGTH) + 1
for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))
for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))
],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="one_extra_gas_full_discount_table",
),
],
)
@pytest.mark.parametrize("expected_output", [PointG2()], ids=[""])
@pytest.mark.parametrize("tx_gas_limit", [100_000_000], ids=[""])
@pytest.mark.parametrize("tx_gas_limit", [110_000_000], ids=[""])
@pytest.mark.parametrize("precompile_address", [Spec.G2MSM])
def test_valid_gas_g2msm(
state_test: StateTestFiller,
Expand Down Expand Up @@ -274,9 +274,9 @@ def test_valid_gas_g2msm(
pytest.param(
[
G2_GAS(i * G2_MSM_K_INPUT_LENGTH) - 1
for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))
for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))
],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[i * G2_MSM_K_INPUT_LENGTH for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="insufficient_gas_full_discount_table",
),
],
Expand Down Expand Up @@ -313,13 +313,13 @@ def test_invalid_gas_g2msm(
id="zero_length_input",
),
pytest.param(
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[(i * G2_MSM_K_INPUT_LENGTH) - 1 for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
[(i * G2_MSM_K_INPUT_LENGTH) - 1 for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="input_one_byte_too_short_full_discount_table",
),
pytest.param(
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[(i * G2_MSM_K_INPUT_LENGTH) + 1 for i in range(1, len(Spec.MSM_DISCOUNT_TABLE))],
[G2_GAS(i * G2_MSM_K_INPUT_LENGTH) for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
[(i * G2_MSM_K_INPUT_LENGTH) + 1 for i in range(1, len(Spec.G2MSM_DISCOUNT_TABLE))],
id="input_one_byte_too_long_full_discount_table",
),
],
Expand Down
2 changes: 2 additions & 0 deletions whitelist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ filelock
filesystem
fillvalue
firstlineno
fmt
fn
fname
forkchoice
Expand Down Expand Up @@ -371,6 +372,7 @@ rpc
ruleset
runtestloop
runtime
S12
sandboxed
secp256k1
secp256k1n
Expand Down

0 comments on commit 0de9f33

Please sign in to comment.