Implement concurrent highlighting on multiple threads (#64) #429

Merged
merged 8 commits on May 24, 2024

Changes from 2 commits
239 changes: 239 additions & 0 deletions example/bench.py
@@ -0,0 +1,239 @@
#!/usr/bin/env python3
import argparse
import json
import os
import random
import statistics
import sys
import xml.etree.ElementTree as etree

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
from pathlib import Path
from typing import Iterable, Mapping, NamedTuple
from urllib.parse import urlencode
from urllib.request import Request, urlopen
from collections import Counter

NSMAP = {"mets": "http://www.loc.gov/METS/", "mods": "http://www.loc.gov/mods/v3"}


class BenchmarkResult(NamedTuple):
    query_times_ms: list[Mapping[str, float]]
    hl_times_ms: list[Mapping[str, float]]

    def mean_query_time(self) -> float:
        return statistics.mean(
            statistics.mean(query_times.values()) for query_times in self.query_times_ms
        )

    def mean_query_times(self) -> tuple[float, ...]:
        return tuple(
            statistics.mean(query_times.values()) for query_times in self.query_times_ms
        )

    def mean_hl_time(self) -> float:
        return statistics.mean(
            statistics.mean(hl_times.values()) for hl_times in self.hl_times_ms
        )

    def mean_hl_times(self) -> tuple[float, ...]:
        return tuple(
            statistics.mean(hl_times.values()) for hl_times in self.hl_times_ms
        )


def analyze_phrases(
    sections: Iterable[tuple[str, ...]],
    min_len=2,
    max_len=8,
    min_word_len=4,
    sample_size: int = 512,
) -> Mapping[tuple[str, ...], int]:
    counter: Counter[tuple[str, ...]] = Counter()
    for words in sections:
        for length in range(min_len, max_len + 1):
            for i in range(len(words) - length + 1):
                phrase = tuple(words[i : i + length])
                if all(len(word) < min_word_len for word in phrase):
                    continue
                counter[phrase] += 1
    filtered = [(phrase, count) for phrase, count in counter.items() if count > 4]
    return dict(
        random.sample(
            filtered,
            min(sample_size, len(filtered)),
        )
    )


def parse_hocr(hocr_path: Path) -> Iterable[tuple[str, ...]]:
    tree = etree.parse(hocr_path)
    for block in tree.findall('.//div[@class="ocrx_block"]'):
        words = block.findall('.//span[@class="ocrx_word"]')
        if len(words) == 0:
            continue
        # Skip words without text content (w.text, not w, can be None here)
        passage = tuple(w.text for w in words if w.text is not None)
        if len(passage) == 0:
            continue
        yield passage


def _queryset_worker(hocr_path: Path) -> Mapping[tuple[str, ...], int]:
    return analyze_phrases(parse_hocr(hocr_path))


def build_query_set(
    hocr_base_path: Path, min_count=8, max_count=256
) -> Iterable[tuple[str, int]]:
    # Count in how many documents each phrase occurs
    phrase_counter = Counter()
    with ProcessPoolExecutor(max_workers=cpu_count()) as pool:
        futs = [
            pool.submit(_queryset_worker, hocr_path)
            for hocr_path in hocr_base_path.glob("**/*.hocr")
        ]
        num_completed = 0
        for fut in as_completed(futs):
            num_completed += 1
            print(
                f"Analyzed {num_completed:>5}/{len(futs)} documents",
                file=sys.stderr,
                end="\r",
            )
            for phrase, count in fut.result().items():
                if count < 4:
                    continue
                phrase_counter[phrase] += 1

    # Select phrases that occur in at least min_count (and at most max_count) documents
    for phrase, count in phrase_counter.items():
        if count < min_count or count > max_count:
            continue
        if len(phrase) == 1:
            yield phrase[0], count
        else:
            yield f'"{" ".join(phrase)}"', count


def run_query(query: str, solr_handler: str) -> tuple[float, float]:
    query_params = {
        "q": f"ocr_text:{query}",
        "hl": "on",
        "hl.ocr.fl": "ocr_text",
        "hl.snippets": 5,
        "rows": 50,
        "debug": "timing",
        "hl.weightMatches": "true",
    }
    req = Request(f"{solr_handler}?{urlencode(query_params)}")
    with urlopen(req) as http_resp:
        solr_resp = json.load(http_resp)
    hl_duration = solr_resp["debug"]["timing"]["process"]["ocrHighlight"]["time"]
    query_duration = solr_resp["debug"]["timing"]["time"]
    return query_duration, hl_duration


def run_benchmark(
    solr_handler: str,
    queries: set[str],
    iterations=3,
    warmup_iters=1,
    concurrency=1,
) -> BenchmarkResult:
    print(
        f"Running benchmark with {iterations} iterations and {concurrency} parallel requests",
        file=sys.stderr,
    )

    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        if warmup_iters > 0:
            print(f"Running {warmup_iters} warmup iterations", file=sys.stderr)
            for warmup_idx in range(warmup_iters):
                for idx, query in enumerate(queries):
                    print(
                        f"Warmup iteration {warmup_idx+1}: {idx+1:>4}/{len(queries)}",
                        file=sys.stderr,
                        end="\r",
                    )
                    run_query(query, solr_handler)

        print(f"Running {iterations} benchmark iterations", file=sys.stderr)
        all_query_times = []
        all_hl_times = []

        def _run_query(query):
            return query, run_query(query, solr_handler)

        for iteration_idx in range(iterations):
            iter_futs = [pool.submit(_run_query, query) for query in queries]

            query_times = {}
            hl_times = {}
            for idx, fut in enumerate(as_completed(iter_futs)):
                try:
                    query, (query_time, hl_time) = fut.result()
                except Exception as e:
                    print(f"\nError: {e}", file=sys.stderr)
                    continue
                query_times[query] = query_time
                hl_times[query] = hl_time
                hl_factor = statistics.mean(hl_times.values()) / statistics.mean(
                    query_times.values()
                )
                print(
                    f"Iteration {iteration_idx+1}: {idx+1:>4}/{len(queries)}, "
                    f"øq={statistics.mean(query_times.values()):.2f}ms, "
                    f"øhl={statistics.mean(hl_times.values()):.2f}ms, "
                    f"hl_share={hl_factor:.2f}",
                    file=sys.stderr,
                    end="\r",
                )
            all_query_times.append(query_times)
            all_hl_times.append(hl_times)

    return BenchmarkResult(all_query_times, all_hl_times)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--iterations", type=int, default=3)
    parser.add_argument("--warmup-iterations", type=int, default=1)
    parser.add_argument("--concurrency", type=int, default=1)
    parser.add_argument("--queries-path", type=str, default="./benchqueries.txt")
    parser.add_argument("--save-results", type=str, default=None)
    parser.add_argument("--solr-handler", type=str, default="http://localhost:8983/solr/ocr/select")
    args = parser.parse_args()

    if os.path.exists(args.queries_path):
        with open(args.queries_path, "r") as f:
            queries = set(x for x in f.read().split("\n") if x.strip())
    else:
        hocr_base_path = Path("./data/google1000")
        queries = set(q for q, _ in build_query_set(hocr_base_path))
        with open(args.queries_path, "w") as f:
            f.write("\n".join(queries))

    results = run_benchmark(
        args.solr_handler,
        queries,
        iterations=args.iterations,
        warmup_iters=args.warmup_iterations,
        concurrency=args.concurrency,
    )

    print("\n\n=====================================")
    print(f"Mean query time: {results.mean_query_time():.2f}ms")
    print(f"Mean highlighting time: {results.mean_hl_time():.2f}ms")
    print(f"Percent of time spent on highlighting: {results.mean_hl_time() / results.mean_query_time() * 100:.2f}%")

    if args.save_results:
        with open(args.save_results, "w") as f:
            json.dump(
                {
                    "query_times": results.query_times_ms,
                    "hl_times": results.hl_times_ms,
                },
                f,
            )
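
For reference, the file written via --save-results can be post-processed offline. A minimal sketch, assuming a results file named results.json (a hypothetical name); the key layout mirrors the json.dump call in bench.py above:

import json
import statistics

# "results.json" is an assumed file name; per bench.py, each key holds a
# list with one entry per iteration, each a mapping of query -> time in ms.
with open("results.json") as f:
    results = json.load(f)

mean_query_ms = statistics.mean(
    statistics.mean(times.values()) for times in results["query_times"]
)
mean_hl_ms = statistics.mean(
    statistics.mean(times.values()) for times in results["hl_times"]
)
print(
    f"query: {mean_query_ms:.2f}ms, highlighting: {mean_hl_ms:.2f}ms "
    f"({mean_hl_ms / mean_query_ms:.1%} of request time)"
)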
4 changes: 3 additions & 1 deletion example/docker-compose.yml
@@ -11,10 +11,12 @@ services:
       - index-data:/var/solr
       - ./data:/data
       - ../target:/build
+      - ./flightrecords:/flightrecords
     environment:
       - ENABLE_REMOTE_JMX_OPTS=true
       - SOLR_HEAP=4g
-      - ADDITIONAL_CMD_OPTS=-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:1044
+      - ADDITIONAL_CMD_OPTS=-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:1044 -XX:StartFlightRecording=settings=profile,filename=/flightrecords/profile.jfr,maxage=30m -XX:+UnlockDiagnosticVMOptions -XX:+DebugNonSafepoints -XX:+PreserveFramePointer
+
       - SOLR_SECURITY_MANAGER_ENABLED=false
     entrypoint:
       - docker-entrypoint.sh
File renamed without changes.
@@ -12,6 +12,7 @@
       <str>query</str>
       <str>ocrHighlight</str>
       <str>highlight</str>
+      <str>debug</str>
     </arr>
   </requestHandler>
 </config>
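
Enabling the debug search component here is what exposes the timing data that bench.py requests via debug=timing. A rough sketch of the response fragment the script reads, inferred purely from the access paths in run_query above (the numbers are made up, and the exact nesting may vary across Solr versions):

# Inferred shape of the timing block consumed by run_query() in example/bench.py.
solr_resp = {
    "debug": {
        "timing": {
            "time": 42.0,  # total request processing time (ms)
            "process": {
                "ocrHighlight": {"time": 30.0},  # time spent in OCR highlighting (ms)
            },
        }
    }
}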
17 changes: 5 additions & 12 deletions pom.xml
@@ -64,7 +64,7 @@
     <version.maven-javadoc-plugin>3.6.3</version.maven-javadoc-plugin>
     <version.maven-shade-plugin>3.5.1</version.maven-shade-plugin>
     <version.maven-source-plugin>3.3.0</version.maven-source-plugin>
-    <version.junit-platform-maven>1.1.7</version.junit-platform-maven>
+    <version.surefire>3.2.5</version.surefire>
     <version.nexus-staging-maven-plugin>1.6.13</version.nexus-staging-maven-plugin>
   </properties>

@@ -181,18 +181,11 @@
         </executions>
       </plugin>
       <plugin>
-        <groupId>de.sormuras.junit</groupId>
-        <artifactId>junit-platform-maven-plugin</artifactId>
-        <version>${version.junit-platform-maven}</version>
-        <extensions>true</extensions>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-surefire-plugin</artifactId>
+        <version>${version.surefire}</version>
         <configuration>
-          <executor>JAVA</executor>
-          <javaOptions>
-            <inheritIO>true</inheritIO>
-            <additionalOptions>
-              <additionalOption>-Djava.security.egd=file:/dev/./urandom</additionalOption>
-            </additionalOptions>
-          </javaOptions>
+          <argLine>-Djava.security.egd=file:/dev/./urandom</argLine>
         </configuration>
       </plugin>
       <plugin>
@@ -90,7 +90,7 @@ public int read(char[] cbuf, int off, int len) throws IOException {
     while (idx < (off + numRead)) {
       // Check for invalid entities and try to fix them
       while (advancedFixing && idx < (off + numRead)) {
-        int match = multiIndexOf(cbuf, idx, '<', '&');
+        int match = multiIndexOf(cbuf, idx, (off + numRead), '<', '&');
         if (match < 0 || match > (off + numRead)) {
           // Nothing to do in this buffer
           break outer;
@@ -99,13 +99,13 @@ public int read(char[] cbuf, int off, int len) throws IOException {
           // Start of element, no more entities to check
           break;
         }
-        int entityEnd = multiIndexOf(cbuf, match + 1, '<', ';');
+        int entityEnd = multiIndexOf(cbuf, match + 1, (off + numRead), '<', ';');
 
         if (entityEnd < match + 1) {
           // Not enough data to determine entity end, we may have to carry over
           // FIXME: This code is largely identical to the carry-over case below, find a way to
           // deduplicate
-          int carryOverSize = numRead - (match + 1);
+          int carryOverSize = (off + numRead) - (match + 1);
           if (carryOverSize == numRead) {
             // Can't carry-over the whole buffer, since the read would return 0
             // We bail out of validation since there's no way we can force the caller to request
@@ -205,7 +205,7 @@ public int read(char[] cbuf, int off, int len) throws IOException {
               (cbuf[startElem + 1] == '/' || cbuf[startElem + 1] == '?')
                   ? startElem + 2
                   : startElem + 1;
-      int endTag = multiIndexOf(cbuf, startTag, ' ', '\n', '\t');
+      int endTag = multiIndexOf(cbuf, startTag, (off + numRead), ' ', '\n', '\t');
       if (endTag > endElem || endTag < 0) {
         endTag = cbuf[endElem - 1] == '/' ? endElem - 1 : endElem;
       }
@@ -290,14 +290,14 @@ public int read(char[] cbuf, int off, int len) throws IOException {
    * Variant of {@link org.apache.commons.lang3.ArrayUtils#indexOf(char[], char)} that supports
    * looking for multiple values.
    */
-  private static int multiIndexOf(final char[] array, int startIndex, final char... valuesToFind) {
+  private static int multiIndexOf(final char[] array, int startIndex, int limit, final char... valuesToFind) {
     if (array == null) {
       return -1;
     }
     if (startIndex < 0) {
       startIndex = 0;
     }
-    for (int i = startIndex; i < array.length; i++) {
+    for (int i = startIndex; i < limit; i++) {
       for (char value : valuesToFind) {
         if (value == array[i]) {
           return i;
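
The new limit parameter bounds every scan to the chars actually produced by the current read() call instead of the full backing array. A minimal sketch of the issue in Python (not the project's Java), assuming the buffer is reused across reads and may hold stale chars beyond off + numRead:

def multi_index_of(buf, start, limit, *values_to_find):
    """Return the first index in buf[start:limit] holding any sought char."""
    for i in range(start, limit):
        if buf[i] in values_to_find:
            return i
    return -1

buf = list("<b>&amp;")  # a previous read() filled all eight slots
num_read = 3            # the current read() only produced "<b>"
# An unbounded scan (the old behavior) "finds" the stale '&' at index 3:
assert multi_index_of(buf, 0, len(buf), "&") == 3
# Bounded to the chars actually read, it correctly reports no match:
assert multi_index_of(buf, 0, num_read, "&") == -1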