WIP: Implement concurrent highlighting on multiple threads
- Concurrency is per-(doc, field) combination, i.e. when highlighting
  2 fields in 10 documents, 20 threads should be used (see the sketch below)
- Thread limit is currently hardcoded; it will be made configurable later on
- Includes a benchmarking suite that runs against the example index
  (`example/bench.py`)
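
The per-(doc, field) fan-out described in the first bullet follows the familiar executor pattern: submit one highlighting task per (document, field) combination and collect the futures. The following is a minimal sketch of that pattern, not the plugin's actual implementation; the class and the highlightField helper are hypothetical stand-ins.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

class ConcurrentHighlightSketch {

  // One task per (document, field) combination, executed on a bounded pool.
  static List<String> highlightAll(
      List<Integer> docIds, List<String> fields, ExecutorService pool)
      throws InterruptedException, ExecutionException {
    List<Future<String>> futures = new ArrayList<>();
    for (int docId : docIds) {
      for (String field : fields) {
        // e.g. 2 fields across 10 documents -> 20 tasks submitted to the pool
        futures.add(pool.submit(() -> highlightField(docId, field)));
      }
    }
    List<String> snippets = new ArrayList<>();
    for (Future<String> fut : futures) {
      snippets.add(fut.get()); // collect results in submission order
    }
    return snippets;
  }

  // Hypothetical stand-in for the real per-field highlighting work.
  static String highlightField(int docId, String field) {
    return "snippet for doc " + docId + ", field " + field;
  }
}

Note that because the executor is bounded, 20 tasks do not necessarily run on 20 live threads; the pool size caps the actual parallelism.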
jbaiter committed May 10, 2024
1 parent b4df5e7 commit cd48938
Showing 15 changed files with 501 additions and 88 deletions.
239 changes: 239 additions & 0 deletions example/bench.py
@@ -0,0 +1,239 @@
#!/usr/bin/env python3
import argparse
import json
import os
import random
import statistics
import sys
import xml.etree.ElementTree as etree

from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
from pathlib import Path
from typing import Iterable, Mapping, NamedTuple
from urllib.parse import urlencode
from urllib.request import Request, urlopen
from collections import Counter

NSMAP = {"mets": "http://www.loc.gov/METS/", "mods": "http://www.loc.gov/mods/v3"}


class BenchmarkResult(NamedTuple):
    query_times_ms: list[Mapping[str, float]]
    hl_times_ms: list[Mapping[str, float]]

    def mean_query_time(self) -> float:
        return statistics.mean(
            statistics.mean(query_times.values()) for query_times in self.query_times_ms
        )

    def mean_query_times(self) -> tuple[float, ...]:
        return tuple(
            statistics.mean(query_times.values()) for query_times in self.query_times_ms
        )

    def mean_hl_time(self) -> float:
        return statistics.mean(
            statistics.mean(hl_times.values()) for hl_times in self.hl_times_ms
        )

    def mean_hl_times(self) -> tuple[float, ...]:
        return tuple(
            statistics.mean(hl_times.values()) for hl_times in self.hl_times_ms
        )


def analyze_phrases(
    sections: Iterable[tuple[str, ...]],
    min_len=2,
    max_len=8,
    min_word_len=4,
    sample_size: int = 512,
) -> Mapping[tuple[str, ...], int]:
    counter: Counter[tuple[str, ...]] = Counter()
    for words in sections:
        for length in range(min_len, max_len + 1):
            for i in range(len(words) - length + 1):
                phrase = tuple(words[i : i + length])
                if all(len(word) < min_word_len for word in phrase):
                    continue
                counter[phrase] += 1
    filtered = [(phrase, count) for phrase, count in counter.items() if count > 4]
    return dict(
        random.sample(
            filtered,
            min(sample_size, len(filtered)),
        )
    )


def parse_hocr(hocr_path: Path) -> Iterable[tuple[str, ...]]:
    tree = etree.parse(hocr_path)
    for block in tree.findall('.//div[@class="ocrx_block"]'):
        words = [w for w in block.findall('.//span[@class="ocrx_word"]')]
        if len(words) == 0:
            continue
        # Skip word spans without any text content
        passage = tuple(w.text for w in words if w.text is not None)
        if len(passage) == 0:
            continue
        yield passage


def _queryset_worker(hocr_path: Path) -> Mapping[tuple[str, ...], int]:
    return analyze_phrases(parse_hocr(hocr_path))


def build_query_set(
    hocr_base_path: Path, min_count=8, max_count=256
) -> Iterable[tuple[str, int]]:
    # Count in how many documents each phrase occurs
    phrase_counter = Counter()
    with ProcessPoolExecutor(max_workers=cpu_count()) as pool:
        futs = [
            pool.submit(_queryset_worker, hocr_path)
            for hocr_path in hocr_base_path.glob("**/*.hocr")
        ]
        num_completed = 0
        for fut in as_completed(futs):
            num_completed += 1
            print(
                f"Analyzed {num_completed:>5}/{len(futs)} documents",
                file=sys.stderr,
                end="\r",
            )
            for phrase, count in fut.result().items():
                if count < 4:
                    continue
                phrase_counter[phrase] += 1

    # Select phrases that occur in at least min_count and at most max_count documents
    for phrase, count in phrase_counter.items():
        if count < min_count or count > max_count:
            continue
        if len(phrase) == 1:
            yield phrase[0], count
        else:
            yield f'"{" ".join(phrase)}"', count


def run_query(query: str, solr_handler: str) -> tuple[float, float]:
    query_params = {
        "q": f"ocr_text:{query}",
        "hl": "on",
        "hl.ocr.fl": "ocr_text",
        "hl.snippets": 5,
        "rows": 50,
        "debug": "timing",
        "hl.weightMatches": "true",
    }
    req = Request(f"{solr_handler}?{urlencode(query_params)}")
    with urlopen(req) as http_resp:
        solr_resp = json.load(http_resp)
    hl_duration = solr_resp["debug"]["timing"]["process"]["ocrHighlight"]["time"]
    query_duration = solr_resp["debug"]["timing"]["time"]
    return query_duration, hl_duration


def run_benchmark(
    solr_handler: str,
    queries: set[str],
    iterations=3,
    warmup_iters=1,
    concurrency=1,
) -> BenchmarkResult:
    print(
        f"Running benchmark with {iterations} iterations and {concurrency} parallel requests",
        file=sys.stderr,
    )

    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        if warmup_iters > 0:
            print(f"Running {warmup_iters} warmup iterations", file=sys.stderr)
            for warmup_idx in range(warmup_iters):
                for idx, query in enumerate(queries):
                    print(
                        f"Warmup iteration {warmup_idx+1}: {idx+1:>4}/{len(queries)}",
                        file=sys.stderr,
                        end="\r",
                    )
                    run_query(query, solr_handler)

        print(f"Running {iterations} benchmark iterations", file=sys.stderr)
        all_query_times = []
        all_hl_times = []

        def _run_query(query):
            return query, run_query(query, solr_handler)

        for iteration_idx in range(iterations):
            iter_futs = [
                pool.submit(_run_query, query)
                for query in queries
            ]

            query_times = {}
            hl_times = {}
            for idx, fut in enumerate(as_completed(iter_futs)):
                try:
                    query, (query_time, hl_time) = fut.result()
                except Exception as e:
                    print(f"\nError: {e}", file=sys.stderr)
                    continue
                query_times[query] = query_time
                hl_times[query] = hl_time
                hl_factor = statistics.mean(hl_times.values()) / statistics.mean(query_times.values())
                print(
                    f"Iteration {iteration_idx+1}: {idx+1:>4}/{len(queries)}, "
                    f"øq={statistics.mean(query_times.values()):.2f}ms, "
                    f"øhl={statistics.mean(hl_times.values()):.2f}ms, "
                    f"hl_share={hl_factor:.2f}",
                    file=sys.stderr,
                    end="\r",
                )
            all_query_times.append(query_times)
            all_hl_times.append(hl_times)

    return BenchmarkResult(all_query_times, all_hl_times)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--iterations", type=int, default=3)
    parser.add_argument("--warmup-iterations", type=int, default=1)
    parser.add_argument("--concurrency", type=int, default=1)
    parser.add_argument("--queries-path", type=str, default="./benchqueries.txt")
    parser.add_argument("--save-results", type=str, default=None)
    parser.add_argument("--solr-handler", type=str, default="http://localhost:8983/solr/ocr/select")
    args = parser.parse_args()

    if os.path.exists(args.queries_path):
        with open(args.queries_path, "r") as f:
            queries = set(x for x in f.read().split("\n") if x.strip())
    else:
        hocr_base_path = Path("./data/google1000")
        queries = set(q for q, _ in build_query_set(hocr_base_path))
        with open(args.queries_path, "w") as f:
            f.write("\n".join(queries))

    results = run_benchmark(
        args.solr_handler,
        queries,
        iterations=args.iterations,
        warmup_iters=args.warmup_iterations,
        concurrency=args.concurrency,
    )

    print("\n\n=====================================")
    print(f"Mean query time: {results.mean_query_time():.2f}ms")
    print(f"Mean highlighting time: {results.mean_hl_time():.2f}ms")
    print(f"Percent of time spent on highlighting: {results.mean_hl_time() / results.mean_query_time() * 100:.2f}%")

    if args.save_results:
        with open(args.save_results, "w") as f:
            json.dump(
                {
                    "query_times": results.query_times_ms,
                    "hl_times": results.hl_times_ms,
                },
                f,
            )
4 changes: 3 additions & 1 deletion example/docker-compose.yml
@@ -11,10 +11,12 @@ services:
      - index-data:/var/solr
      - ./data:/data
      - ../target:/build
      - ./flightrecords:/flightrecords
    environment:
      - ENABLE_REMOTE_JMX_OPTS=true
      - SOLR_HEAP=4g
      - ADDITIONAL_CMD_OPTS=-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:1044
      - ADDITIONAL_CMD_OPTS=-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=*:1044 -XX:StartFlightRecording=settings=profile,filename=/flightrecords/profile.jfr,maxage=30m -XX:+UnlockDiagnosticVMOptions -XX:+DebugNonSafepoints -XX:+PreserveFramePointer

      - SOLR_SECURITY_MANAGER_ENABLED=false
    entrypoint:
      - docker-entrypoint.sh
File renamed without changes.
@@ -12,6 +12,7 @@
      <str>query</str>
      <str>ocrHighlight</str>
      <str>highlight</str>
      <str>debug</str>
    </arr>
  </requestHandler>
</config>
17 changes: 5 additions & 12 deletions pom.xml
@@ -64,7 +64,7 @@
    <version.maven-javadoc-plugin>3.6.3</version.maven-javadoc-plugin>
    <version.maven-shade-plugin>3.5.1</version.maven-shade-plugin>
    <version.maven-source-plugin>3.3.0</version.maven-source-plugin>
    <version.junit-platform-maven>1.1.7</version.junit-platform-maven>
    <version.surefire>3.2.5</version.surefire>
    <version.nexus-staging-maven-plugin>1.6.13</version.nexus-staging-maven-plugin>
  </properties>

@@ -181,18 +181,11 @@
      </executions>
    </plugin>
    <plugin>
      <groupId>de.sormuras.junit</groupId>
      <artifactId>junit-platform-maven-plugin</artifactId>
      <version>${version.junit-platform-maven}</version>
      <extensions>true</extensions>
      <groupId>org.apache.maven.plugins</groupId>
      <artifactId>maven-surefire-plugin</artifactId>
      <version>${version.surefire}</version>
      <configuration>
        <executor>JAVA</executor>
        <javaOptions>
          <inheritIO>true</inheritIO>
          <additionalOptions>
            <additionalOption>-Djava.security.egd=file:/dev/./urandom</additionalOption>
          </additionalOptions>
        </javaOptions>
        <argLine>-Djava.security.egd=file:/dev/./urandom</argLine>
      </configuration>
    </plugin>
    <plugin>
@@ -25,12 +25,16 @@
package com.github.dbmdz.solrocr.solr;

import com.github.dbmdz.solrocr.model.OcrHighlightResult;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.search.Query;
import org.apache.solr.common.params.HighlightParams;
import org.apache.solr.common.params.SolrParams;
@@ -47,6 +51,28 @@
public class SolrOcrHighlighter extends UnifiedSolrHighlighter {
  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  private final ThreadPoolExecutor hlThreadPool;

  public SolrOcrHighlighter() {
    this(Runtime.getRuntime().availableProcessors(), 8);
  }

  public SolrOcrHighlighter(int numHlThreads, int maxQueuedPerThread) {
    super();
    this.hlThreadPool =
        new ThreadPoolExecutor(
            numHlThreads,
            numHlThreads,
            120L,
            TimeUnit.SECONDS,
            new LinkedBlockingQueue<>(numHlThreads * maxQueuedPerThread),
            new ThreadFactoryBuilder().setNameFormat("OcrHighlighter-%d").build());
  }

  public void shutdownThreadPool() {
    hlThreadPool.shutdown();
  }

  public NamedList<Object> doHighlighting(
      DocList docs, Query query, SolrQueryRequest req, Map<String, Object> respHeader)
      throws IOException {
@@ -75,7 +101,8 @@ public NamedList<Object> doHighlighting(
    OcrHighlighter ocrHighlighter =
        new OcrHighlighter(req.getSearcher(), req.getSchema().getIndexAnalyzer(), req);
    OcrHighlightResult[] ocrSnippets =
        ocrHighlighter.highlightOcrFields(ocrFieldNames, query, docIDs, maxPassagesOcr, respHeader);
        ocrHighlighter.highlightOcrFields(
            ocrFieldNames, query, docIDs, maxPassagesOcr, respHeader, hlThreadPool);

    // Assemble output data
    SimpleOrderedMap<Object> out = new SimpleOrderedMap<>();
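
One property of the pool constructed above is worth noting (standard ThreadPoolExecutor semantics, not something this commit adds): with core and maximum size both equal to numHlThreads and a LinkedBlockingQueue bounded at numHlThreads * maxQueuedPerThread, a submit() that arrives while every thread is busy and the queue is full fails with RejectedExecutionException under the default AbortPolicy. A self-contained sketch with illustrative numbers; it substitutes the JDK's default thread factory for Guava's ThreadFactoryBuilder to stay dependency-free.

import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

class PoolSizingSketch {
  public static void main(String[] args) throws InterruptedException {
    int numThreads = 2;
    int maxQueuedPerThread = 8;
    ThreadPoolExecutor pool =
        new ThreadPoolExecutor(
            numThreads,
            numThreads,
            120L,
            TimeUnit.SECONDS,
            new LinkedBlockingQueue<>(numThreads * maxQueuedPerThread),
            Executors.defaultThreadFactory());
    int accepted = 0;
    try {
      // 2 running + 16 queued = 18 tasks fit; the 19th submit is rejected.
      for (int i = 0; i < 32; i++) {
        pool.submit(() -> sleep(200));
        accepted++;
      }
    } catch (RejectedExecutionException e) {
      System.out.println("Rejected after " + accepted + " accepted tasks");
    }
    pool.shutdown();
    pool.awaitTermination(10, TimeUnit.SECONDS);
  }

  private static void sleep(long millis) {
    try {
      Thread.sleep(millis);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  }
}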
3 changes: 2 additions & 1 deletion src/main/java/solrocr/ExternalUtf8ContentFilterFactory.java
@@ -63,7 +63,8 @@ public Reader create(Reader input) {
    pointer.sources.forEach(this::validateSource);

    // Regions contained in source pointers are defined by byte offsets.
    // We need to convert these to Java character offsets so they can be used by the filter.
    // We need to convert these to Java character offsets, so they can be used by the filter.
    // This is very expensive, but we need this since all IO from here on out is character-based.
    toCharOffsets(pointer);

    Reader r;
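
For context on the comment above: converting a UTF-8 byte offset into a Java char offset means decoding the byte prefix and counting the resulting UTF-16 code units, which is why the step is expensive. A minimal sketch of the idea; it assumes the offset falls on a character boundary and is not the filter's actual streaming implementation.

import java.nio.charset.StandardCharsets;

class Utf8OffsetSketch {
  // Decode everything up to byteOffset and count UTF-16 code units; the full
  // decode of the prefix is what makes this conversion costly.
  static int byteOffsetToCharOffset(byte[] utf8, int byteOffset) {
    return new String(utf8, 0, byteOffset, StandardCharsets.UTF_8).length();
  }
}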
10 changes: 9 additions & 1 deletion src/main/java/solrocr/OcrHighlightComponent.java
@@ -80,7 +80,14 @@ public void prepare(ResponseBuilder rb) throws IOException {

  @Override
  public void inform(SolrCore core) {
    this.ocrHighlighter = new SolrOcrHighlighter();
    int numHlThreads =
        Integer.parseInt(
            info.attributes.getOrDefault(
                "numHighlightingThreads",
                String.valueOf(Runtime.getRuntime().availableProcessors())));
    int maxQueuedPerThread =
        Integer.parseInt(info.attributes.getOrDefault("maxQueuedPerThread", "8"));
    this.ocrHighlighter = new SolrOcrHighlighter(numHlThreads, maxQueuedPerThread);
    if ("true".equals(info.attributes.getOrDefault("enablePreload", "false"))) {
      PageCacheWarmer.enable(
          Integer.parseInt(info.attributes.getOrDefault("preloadReadSize", "32768")),
@@ -95,6 +102,7 @@ public void preClose(SolrCore core) {}

          @Override
          public void postClose(SolrCore core) {
            ocrHighlighter.shutdownThreadPool();
            PageCacheWarmer.getInstance().ifPresent(PageCacheWarmer::shutdown);
          }
        });