Skip to content

Commit

Permalink
Improve benchmarking script
Browse files Browse the repository at this point in the history
  • Loading branch information
0xddom committed Aug 29, 2023
1 parent 2d4203b commit 2545e71
Showing 1 changed file with 79 additions and 46 deletions.
125 changes: 79 additions & 46 deletions scripts/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os.path
import re
import subprocess
import tempfile
import time
from pathlib import Path
from tempfile import TemporaryDirectory
Expand All @@ -27,26 +28,48 @@ def setup_dirs(src: str, out: str) -> Tuple[str, str, TemporaryDirectory]:
return src, out, tmp


def check_is_missing_feature(message: str) -> Match[str] | None:
check = re.compile(r"thread 'main' panicked at 'not yet implemented', (.*):(.*):(.*)")
return check.match(message)
def check_is_missing_feature(fd) -> Match[str] | None:
fd.seek(0)
for line in fd:
check = re.compile(r"thread 'main' panicked at 'not yet implemented', (.*):(.*):(.*)")
match = check.match(line.decode("utf-8"))
if match is not None:
return match
return


def check_is_other_panic(message: str) -> Match[str] | None:
check = re.compile(r"thread 'main' panicked at (.*), (.*):(.*):(.*)")
return check.match(message)
def check_is_other_panic(fd) -> Match[str] | None:
fd.seek(0)
for line in fd:
check = re.compile(r"thread 'main' panicked at (.*), (.*):(.*):(.*)")
match = check.match(line.decode("utf-8"))
if match is not None:
return match
return


def is_circom_error(message: str) -> bool:
return re.compile(r"^error\[.*\]:").match(message) is not None
def is_circom_error(fd) -> bool:
fd.seek(0)
for line in fd:
if re.compile(r"^error\[.*\]:").match(line.decode("utf-8")) is not None:
return True
return False


def is_llvm_validation_error(message: str) -> bool:
return "LLVM Module verification failed" in message
def is_llvm_validation_error(fd) -> bool:
fd.seek(0)
for line in fd:
if "LLVM Module verification failed" in line.decode("utf-8"):
return True
return False


def non_constant_id(message: str) -> bool:
return "is_constant_int()" in message
def non_constant_id(fd) -> bool:
fd.seek(0)
for line in fd:
if "is_constant_int()" in line.decode("utf-8"):
return True
return False


class TimedOutExecution:
Expand All @@ -73,12 +96,13 @@ def extract_number_templates(message) -> int:


class Report:
def __init__(self, src: str, cmd: List[str], execution: Union[subprocess.CompletedProcess | TimedOutExecution], run_time: float):
def __init__(self, src: str, cmd: List[str], execution: Union[subprocess.CompletedProcess | TimedOutExecution], run_time: float, stderr):
self.src = src
self.cmd = cmd
self.execution = execution
self.run_time = run_time
self.test_id = None
self._stderr = stderr

@property
def successful(self):
Expand Down Expand Up @@ -107,14 +131,15 @@ def has_panic(self):

@property
def stderr(self):
if self.execution.stderr:
return escape_ansi(self.execution.stderr.decode("utf-8"))
return ""
#if self.execution.stderr:
# return escape_ansi(self.execution.stderr.decode("utf-8"))
#return ""
return self._stderr

@property
def stdout(self):
if self.execution.stdout:
return escape_ansi(self.execution.stdout.decode("utf-8"))
#if self.execution.stdout:
# return escape_ansi(self.execution.stdout.decode("utf-8"))
return ""

@property
Expand All @@ -141,14 +166,12 @@ def to_dict(self) -> dict:
'cmd': self.cmd,
'return_code': self.execution.returncode,
'successful': self.successful,
'stdout': self.stdout,
'stderr': self.stderr,
'run_time': self.run_time,
'missing_feature': self.missing_feature,
'error_class': self.error_class,
'panicked': self.has_panic,
'test_id': self.test_id,
'template_instances': extract_number_templates(self.stdout)
#'template_instances': extract_number_templates(self.stdout)
}


Expand All @@ -157,9 +180,20 @@ def escape_ansi(line: str) -> str:
return ansi_escape.sub('', line)


def tail(f):
execution = subprocess.run(['tail', f.name], capture_output=True)
if execution.returncode == 0:
if execution.stdout:
print(escape_ansi(execution.stdout.decode("utf-8")))
else:
print(execution.stderr.decode("utf-8"))
exit(1)


def run_test(src: str, circom: str, debug: bool, cwd: str, libs_path: Optional[str], timeout: int) -> Report:
src = os.path.realpath(src)
tmp = TemporaryDirectory()
stderr = tempfile.NamedTemporaryFile()
cmd = [
circom,
'--llvm',
Expand All @@ -169,29 +203,28 @@ def run_test(src: str, circom: str, debug: bool, cwd: str, libs_path: Optional[s
if libs_path:
cmd.extend(['-l', libs_path])
cmd.append(src)
if debug:
print("Source file:", src)
print("CMD:", ' '.join(cmd))
print("Source file:", src)
try:
start = time.time()
execution = subprocess.run(cmd, capture_output=True, cwd=cwd, timeout=timeout)
execution = subprocess.run(cmd, stderr=stderr, stdout=subprocess.DEVNULL, cwd=cwd, timeout=timeout)
end = time.time()
if execution.returncode == 0:
print("Success!")
else:
print("Failure!")
print("CMD:", ' '.join(cmd))
if debug:
if execution.returncode == 0:
print("Success!")
else:
print("Failure!")
# if execution.stdout:
# print("Circom stdout:\n", escape_ansi(execution.stdout.decode("utf-8")))
if execution.stderr:
print("Circom stderr:\n", execution.stderr.decode("utf-8"))
print("Circom stderr:\n")
tail(stderr)
print("Execution time in seconds:", end - start)

return Report(src, cmd, execution, end - start)
return Report(src, cmd, execution, end - start, stderr)
except subprocess.TimeoutExpired as e:
if debug:
print("Test timed out!")
return Report(src, cmd, TimedOutExecution(e), timeout)
print("Test timed out!")
print("CMD:", ' '.join(cmd))
return Report(src, cmd, TimedOutExecution(e), timeout, stderr)


def check_link_libraries(data: dict) -> bool:
Expand All @@ -207,8 +240,7 @@ def run_setup(data: dict):
os.system(cmd)


def evaluate_test(test_path: str, circom: str, debug: bool, libs_path: str, timeout: int) -> List[Report]:
reports = []
def evaluate_test(test_path: str, circom: str, debug: bool, libs_path: str, timeout: int):
with open(test_path) as f:
test_data = json.load(f)
test_cwd = Path(test_path).parent
Expand All @@ -222,9 +254,11 @@ def evaluate_test(test_path: str, circom: str, debug: bool, libs_path: str, time
main_circom_file = test_cwd.joinpath(test['main'])
report = run_test(str(main_circom_file), circom, debug, str(test_cwd), libs_path, timeout)
report.test_id = f"{test_data['id']}_{n}"
reports.append(report)
return reports
yield report

def get_reports(tests, circom, debug, src, timeout):
for test in tests:
yield from evaluate_test(test, circom, debug, str(src.joinpath("tests/libs")), timeout)

@click.command()
@click.option('--src', help='Path where the benchmark is located.')
Expand All @@ -233,28 +267,27 @@ def evaluate_test(test_path: str, circom: str, debug: bool, libs_path: str, time
@click.option('--debug', help="Print debug information", is_flag=True)
@click.option('--timeout', help="Timeout for stopping the compilation", default=600)
def main(src, out, circom, debug, timeout):
reports = []
src = Path(src)
tests = glob.glob(str(src.joinpath(GLOB)), recursive=True)
for test in tests:
reports.extend(evaluate_test(test, circom, debug, str(src.joinpath("tests/libs")), timeout))
reports = get_reports(tests, circom, debug, src, timeout)

with open(out, 'w') as out_csv:
print('test_id,successful,error_class,message,file,line,column,run_time,template_instances', file=out_csv)
print('test_id,successful,error_class,message,file,line,column,run_time', file=out_csv)
for report in reports:
report = report.to_dict()
if report['successful']:
print(report['test_id'], report['successful'], '', '', '', '', '', report['run_time'],
report['template_instances'], sep=',', file=out_csv)
sep=',', file=out_csv)
else:
if report['panicked']:
print(report['test_id'], report['successful'], report['error_class'],
f"\"{report['panicked']['message']}\"",
report['panicked']['file'], report['panicked']['line'], report['panicked']['column'],
report['run_time'], report['template_instances'], sep=',', file=out_csv)
report['run_time'], sep=',', file=out_csv)
else:
print(report['test_id'], report['successful'], report['error_class'], '', '', '', '',
report['run_time'], report['template_instances'], sep=',', file=out_csv)
report['run_time'], sep=',', file=out_csv)
out_csv.flush()


if __name__ == "__main__":
Expand Down

0 comments on commit 2545e71

Please sign in to comment.