Commit 4d1f5a0
Update benchmarking script
jbaiter committed May 23, 2024 · 1 parent 6558eea · commit 4d1f5a0
Showing 2 changed files with 124 additions and 43 deletions.
example/bench.py: 124 additions & 43 deletions
@@ -1,21 +1,45 @@
#!/usr/bin/env python3
"""Small benchmarking script for Solr highlighting performance.
Generates a set of common two-term phrase queries from the Google 1000 dataset
and runs them against the Dockerized example setup. The script measures the time
spent on query execution and highlighting and prints the results to stdout.
If you want to profile the plugin:
- Download async-profiler: https://github.com/async-profiler/async-profiler
- Mount the async-profiler directory to the same location in the container as
on your system
- Add these lines to the `solr` service in `docker-compose.yml`:
```
security_opt:
- seccomp:unconfined
cap_add:
- SYS_ADMIN
```
- Launch the container
- Find the PID of the Solr process on the host machine (use `ps` or `htop`)
- Launch the profiler: `${ASYNC_PROFILER_DIR}/asprof -d 60 -f /tmp/flamegraph.svg ${SOLRPID}`
"""

import argparse
import json
import os
import random
import statistics
import string
import sys
import xml.etree.ElementTree as etree

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
from pathlib import Path
-from typing import Iterable, Mapping, NamedTuple
+from typing import Iterable, Mapping, NamedTuple, TextIO, cast
from urllib.parse import urlencode
from urllib.request import Request, urlopen
from collections import Counter

NSMAP = {"mets": "http://www.loc.gov/METS/", "mods": "http://www.loc.gov/mods/v3"}
STRIP_PUNCTUATION_TBL = str.maketrans("", "", string.punctuation)


class BenchmarkResult(NamedTuple):
@@ -73,24 +97,28 @@ def parse_hocr(hocr_path: Path) -> Iterable[tuple[str, ...]]:
        words = [w for w in block.findall('.//span[@class="ocrx_word"]')]
        if len(words) == 0:
            continue
-        passage = tuple(w.text for w in words if w is not None)
+        passage = tuple(
+            filtered
+            for filtered in (
+                w.text.translate(STRIP_PUNCTUATION_TBL)
+                for w in words
+                if w is not None and w.text is not None
+            )
+            if filtered
+        )
        if len(passage) == 0:
            continue
        yield passage


-def _queryset_worker(hocr_path: Path) -> Mapping[tuple[str, ...], int]:
-    return analyze_phrases(parse_hocr(hocr_path))


def build_query_set(
    hocr_base_path: Path, min_count=8, max_count=256
) -> Iterable[tuple[str, int]]:
    # Counts in how many documents a phrase occurs
    phrase_counter = Counter()
    with ProcessPoolExecutor(max_workers=cpu_count()) as pool:
        futs = [
-            pool.submit(_queryset_worker, hocr_path)
+            pool.submit(lambda p: analyze_phrases(parse_hocr(p)), hocr_path)
            for hocr_path in hocr_base_path.glob("**/*.hocr")
        ]
        num_completed = 0
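`build_query_set` fans the hOCR parsing out to one process per file and merges the per-document phrase counts into a single `Counter`. `analyze_phrases` itself is not shown in this hunk; a minimal sketch of what it plausibly computes, assuming it counts each distinct two-word phrase at most once per document, follows. Note also that `ProcessPoolExecutor` pickles every submitted callable and CPython's pickler rejects lambdas, which is why a module-level worker like the removed `_queryset_worker` is the usual pattern at this call site.

```python
from collections import Counter
from typing import Iterable, Mapping


def analyze_phrases_sketch(
    passages: Iterable[tuple[str, ...]],
) -> Mapping[tuple[str, ...], int]:
    # Hypothetical reconstruction, not the implementation from this repo:
    # collect every pair of adjacent words, counted at most once per document.
    phrases = set()
    for passage in passages:
        phrases.update(zip(passage, passage[1:]))
    # Each phrase occurred in this document; the caller sums these
    # Counters across documents to get document frequencies.
    return Counter(phrases)
```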
@@ -116,13 +144,15 @@ def build_query_set(
        yield f'"{" ".join(phrase)}"', count


-def run_query(query: str, solr_handler: str) -> tuple[float, float]:
+def run_query(
+    query: str, solr_handler: str, num_rows: int, num_snippets: int
+) -> tuple[float, float]:
    query_params = {
        "q": f"ocr_text:{query}",
        "hl": "on",
        "hl.ocr.fl": "ocr_text",
-        "hl.snippets": 5,
-        "rows": 50,
+        "hl.snippets": num_snippets,
+        "rows": num_rows,
        "debug": "timing",
        "hl.weightMatches": "true",
    }
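The hunk ends before the request is sent, so the actual HTTP and timing-extraction code is elided. For orientation, here is a hedged sketch of how such a query could be issued and timed with the standard library alone; the field names under `debug` (`timing`, the `process` phase, the `highlight` component) follow Solr's stock SearchHandler debug output and are an assumption here, not something this diff shows.

```python
import json
from urllib.parse import urlencode
from urllib.request import urlopen


def time_query(solr_handler: str, query_params: dict) -> tuple[float, float]:
    # Issue the request; recent Solr versions return JSON by default.
    with urlopen(f"{solr_handler}?{urlencode(query_params)}") as resp:
        data = json.load(resp)
    # With debug=timing, Solr reports per-component times in milliseconds.
    timing = data["debug"]["timing"]
    total_ms = timing["time"]
    hl_ms = timing["process"]["highlight"]["time"]
    return total_ms, hl_ms
```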
@@ -138,38 +168,25 @@ def run_benchmark(
    solr_handler: str,
    queries: set[str],
    iterations=3,
-    warmup_iters=1,
    concurrency=1,
+    num_rows=50,
+    num_snippets=5,
) -> BenchmarkResult:
    print(
-        f"Running benchmark with {iterations} iterations and {concurrency} parallel requests",
+        f"Running benchmark for {num_rows} rows with {num_snippets} snippets across {iterations} iterations and {concurrency} parallel requests",
        file=sys.stderr,
    )

    with ThreadPoolExecutor(max_workers=concurrency) as pool:
-        if warmup_iters > 0:
-            print(f"Running {warmup_iters} warmup iterations", file=sys.stderr)
-            for idx in range(warmup_iters):
-                for idx, query in enumerate(queries):
-                    print(
-                        f"Warmup iteration {idx+1}: {idx+1:>4}/{len(queries)}",
-                        file=sys.stderr,
-                        end="\r",
-                    )
-                    run_query(query, solr_handler)

        print(f"Running {iterations} benchmark iterations", file=sys.stderr)
        all_query_times = []
        all_hl_times = []

        def _run_query(query):
-            return query, run_query(query, solr_handler)
+            return query, run_query(query, solr_handler, num_rows, num_snippets)

        for iteration_idx in range(iterations):
-            iter_futs = [
-                pool.submit(_run_query, query)
-                for query in queries
-            ]
+            iter_futs = [pool.submit(_run_query, query) for query in queries]

            query_times = {}
            hl_times = {}
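A design note on the two executors in this script: the HTTP requests run in a `ThreadPoolExecutor` because they are I/O-bound, and threads never pickle their callables, so the `_run_query` closure over `solr_handler`, `num_rows`, and `num_snippets` is unproblematic; the CPU-bound hOCR parsing earlier uses a `ProcessPoolExecutor`, which does pickle. A toy model of the same fan-out/fan-in pattern (names invented for illustration):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def fetch(url: str) -> int:
    return len(url)  # stand-in for an HTTP round trip


urls = [f"http://localhost:8983/solr/ocr/select?q=term{i}" for i in range(4)]
with ThreadPoolExecutor(max_workers=2) as pool:
    futs = [pool.submit(fetch, url) for url in urls]
    for fut in as_completed(futs):  # results arrive in completion order
        print(fut.result())
```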
@@ -181,7 +198,9 @@ def _run_query(query):
                continue
            query_times[query] = query_time
            hl_times[query] = hl_time
-            hl_factor = statistics.mean(hl_times.values()) / statistics.mean(query_times.values())
+            hl_factor = statistics.mean(hl_times.values()) / statistics.mean(
+                query_times.values()
+            )
            print(
                f"Iteration {iteration_idx+1}: {idx+1:>4}/{len(queries)}, "
                f"øq={statistics.mean(query_times.values()):.2f}ms, "
@@ -197,36 +216,98 @@ def _run_query(query):


if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--iterations", type=int, default=3)
-    parser.add_argument("--warmup-iterations", type=int, default=1)
-    parser.add_argument("--concurrency", type=int, default=1)
-    parser.add_argument("--queries-path", type=str, default="./benchqueries.txt")
-    parser.add_argument("--save-results", type=str, default=None)
-    parser.add_argument("--solr-handler", type=str, default="http://localhost:8983/solr/ocr/select")
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=3,
        metavar="N",
        help="Number of benchmark iterations",
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        default=1,
        metavar="N",
        help="Number of concurrent requests",
    )
    parser.add_argument(
        "--queries-path",
        type=str,
        default="./benchqueries.txt.gz",
        metavar="PATH",
        help="Path to the file containing the queries",
    )
    parser.add_argument(
        "--save-results",
        type=str,
        default=None,
        metavar="PATH",
        help="Path to save the results to as a JSON file (optional)",
    )
    parser.add_argument(
        "--num-rows",
        type=int,
        default=50,
        metavar="N",
        help="Number of rows to request from Solr",
    )
    parser.add_argument(
        "--num-snippets",
        type=int,
        default=5,
        metavar="N",
        help="Number of snippets to request from Solr",
    )
    parser.add_argument(
        "--solr-handler",
        type=str,
        default="http://localhost:8983/solr/ocr/select",
        help="URL to the Solr handler",
    )
    args = parser.parse_args()

    if os.path.exists(args.queries_path):
-        with open(args.queries_path, "r") as f:
-            queries = set(x for x in f.read().split("\n") if x.strip())
        if args.queries_path.endswith(".gz"):
            import gzip

            with gzip.open(args.queries_path, "rt") as f:
                queries = set(
                    q for q in (line.strip() for line in cast(TextIO, f)) if q
                )
        else:
            with open(args.queries_path, "rt") as f:
                queries = set(q for q in (line.strip() for line in f) if q)
    else:
        hocr_base_path = Path("./data/google1000")
        queries = set(q for q, _ in build_query_set(hocr_base_path))
-        with open(args.queries_path, "w") as f:
-            f.write("\n".join(queries))
        if args.queries_path.endswith(".gz"):
            import gzip

            with cast(TextIO, gzip.open(args.queries_path, "wt", compresslevel=9)) as f:
                f.write("\n".join(queries))
        else:
            with open(args.queries_path, "w") as f:
                f.write("\n".join(queries))

    results = run_benchmark(
        args.solr_handler,
        queries,
        iterations=args.iterations,
-        warmup_iters=args.warmup_iterations,
        concurrency=args.concurrency,
+        num_rows=args.num_rows,
+        num_snippets=args.num_snippets,
    )

print("\n\n=====================================")
print(f"Mean query time: {results.mean_query_time():.2f}ms")
print(f"Mean highlighting time: {results.mean_hl_time():.2f}ms")
print(f"Percent of time spent on highlighting: {results.mean_hl_time() / results.mean_query_time() * 100:.2f}%")
print(
f"Percent of time spent on highlighting: {results.mean_hl_time() / results.mean_query_time() * 100:.2f}%"
)
print("=====================================\n\n")

    if args.save_results:
        with open(args.save_results, "w") as f:
@@ -236,4 +317,4 @@ def _run_query(query):
"hl_times": results.hl_times_ms,
},
f,
)
)
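Beyond the new CLI flags, the main addition in this commit is transparent gzip support for the query file (the default `--queries-path` now points at `benchqueries.txt.gz`). A self-contained sketch of the same round trip, with the file name assumed for illustration:

```python
import gzip

queries = {'"sample phrase"', '"another phrase"'}

# Write the query set as gzip-compressed text, one query per line.
with gzip.open("benchqueries.txt.gz", "wt", compresslevel=9) as f:
    f.write("\n".join(queries))

# Read it back, skipping blank lines.
with gzip.open("benchqueries.txt.gz", "rt") as f:
    restored = set(q for q in (line.strip() for line in f) if q)

assert restored == queries
```

With the defaults above, a typical invocation against the Dockerized example would be something like `python3 example/bench.py --iterations 3 --concurrency 4`, with every flag optional.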
example/benchqueries.txt.gz: binary file added (contents not shown)
