Skip to content

Commit

Permalink
Unify precomputation of aggregations behind a common API
Browse files Browse the repository at this point in the history
We've had a series of aggregation speedups that use the same strategy:
instead of iterating through documents that match the query
one-by-one, we can look at a Lucene segment and compute the
aggregation directly (if some particular conditions are met).

In every case, we've hooked that into custom logic hijacks the
getLeafCollector method and throws CollectionTerminatedException. This
creates the illusion that we're implementing a custom LeafCollector,
when really we're not collecting at all (which is the whole point).

With this refactoring, the mechanism (hijacking getLeafCollector) is
moved into AggregatorBase. Aggregators that have a strategy to
precompute their answer can override tryPrecomputeAggregationForLeaf,
which is expected to return true if they managed to precompute.

This should also make it easier to keep track of which aggregations
have precomputation approaches (since they override this method).

Signed-off-by: Michael Froh <[email protected]>
  • Loading branch information
msfroh committed Nov 27, 2024
1 parent 5068fad commit 4d5c32b
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.lucene.Lucene;
Expand All @@ -27,8 +26,6 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.search.aggregations.AggregatorFactory;
import org.opensearch.search.aggregations.LeafBucketCollector;
import org.opensearch.search.aggregations.LeafBucketCollectorBase;
import org.opensearch.search.aggregations.metrics.MetricAggregatorFactory;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -177,11 +174,10 @@ public static StarTreeValues getStarTreeValues(LeafReaderContext context, Compos
* Get the star-tree leaf collector
* This collector computes the aggregation prematurely and invokes an early termination collector
*/
public static LeafBucketCollector getStarTreeLeafCollector(
public static void precomputeAggregationFromStarTree(
SearchContext context,
ValuesSource.Numeric valuesSource,
LeafReaderContext ctx,
LeafBucketCollector sub,
CompositeIndexFieldInfo starTree,
String metric,
Consumer<Long> valueConsumer,
Expand Down Expand Up @@ -221,14 +217,6 @@ public static LeafBucketCollector getStarTreeLeafCollector(

// Call the final consumer after processing all entries
finalConsumer.run();

// Return a LeafBucketCollector that terminates collection
return new LeafBucketCollectorBase(sub, valuesSource.doubleValues(ctx)) {
@Override
public void collect(int doc, long bucket) {
throw new CollectionTerminatedException();
}
};
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
package org.opensearch.search.aggregations;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.core.common.breaker.CircuitBreaker;
Expand Down Expand Up @@ -200,6 +201,9 @@ public Map<String, Object> metadata() {

@Override
public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
if (tryPrecomputeAggregationForLeaf(ctx)) {
throw new CollectionTerminatedException();
}
preGetSubLeafCollectors(ctx);
final LeafBucketCollector sub = collectableSubAggregators.getLeafCollector(ctx);
return getLeafCollector(ctx, sub);
Expand All @@ -216,6 +220,10 @@ protected void preGetSubLeafCollectors(LeafReaderContext ctx) throws IOException
*/
protected void doPreCollection() throws IOException {}

protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
return false;
}

@Override
public final void preCollection() throws IOException {
List<BucketCollector> collectors = Arrays.asList(subAggregators);
Expand Down Expand Up @@ -251,8 +259,8 @@ public Aggregator[] subAggregators() {
public Aggregator subAggregator(String aggName) {
if (subAggregatorbyName == null) {
subAggregatorbyName = new HashMap<>(subAggregators.length);
for (int i = 0; i < subAggregators.length; i++) {
subAggregatorbyName.put(subAggregators[i].name(), subAggregators[i]);
for (Aggregator subAggregator : subAggregators) {
subAggregatorbyName.put(subAggregator.name(), subAggregator);
}
}
return subAggregatorbyName.get(aggName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,10 +556,15 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
}

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
if (subAggregators().length == 0) {
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
}
return false;
}

@Override
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
finishLeaf();

boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.util.CollectionUtil;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -161,15 +160,17 @@ public ScoreMode scoreMode() {
return super.scoreMode();
}

@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}

boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
if (optimized) throw new CollectionTerminatedException();

SortedNumericDocValues values = valuesSource.longValues(ctx);
return new LeafBucketCollectorBase(sub, values) {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
package org.opensearch.search.aggregations.bucket.range;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ScoreMode;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.io.stream.StreamInput;
Expand Down Expand Up @@ -310,10 +309,15 @@ public ScoreMode scoreMode() {
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) {
throw new CollectionTerminatedException();
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
if (segmentMatchAll(context, ctx)) {
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
}
return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {

final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
return new LeafBucketCollectorBase(sub, values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
Expand Down Expand Up @@ -166,35 +165,32 @@ public void setWeight(Weight weight) {
@return A LeafBucketCollector implementation with collection termination, since collection is complete
@throws IOException If an I/O error occurs during reading
*/
LeafBucketCollector termDocFreqCollector(
LeafReaderContext ctx,
SortedSetDocValues globalOrds,
BiConsumer<Long, Integer> ordCountConsumer
) throws IOException {
boolean tryCollectFromTermFrequencies(LeafReaderContext ctx, SortedSetDocValues globalOrds, BiConsumer<Long, Integer> ordCountConsumer)
throws IOException {
if (weight == null) {
// Weight not assigned - cannot use this optimization
return null;
return false;
} else {
if (weight.count(ctx) == 0) {
// No documents matches top level query on this segment, we can skip the segment entirely
return LeafBucketCollector.NO_OP_COLLECTOR;
return true;
} else if (weight.count(ctx) != ctx.reader().maxDoc()) {
// weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and
// top-level query matches all docs in the segment
return null;
return false;
}
}

Terms segmentTerms = ctx.reader().terms(this.fieldName);
if (segmentTerms == null) {
// Field is not indexed.
return null;
return false;
}

NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME);
if (docCountValues.nextDoc() != NO_MORE_DOCS) {
// This segment has at least one document with the _doc_count field.
return null;
return false;
}

TermsEnum indexTermsEnum = segmentTerms.iterator();
Expand All @@ -218,31 +214,28 @@ LeafBucketCollector termDocFreqCollector(
ordinalTerm = globalOrdinalTermsEnum.next();
}
}
return new LeafBucketCollector() {
@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
throw new CollectionTerminatedException();
}
};
return true;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
collectionStrategy.globalOrdsReady(globalOrds);

if (collectionStrategy instanceof DenseGlobalOrds
&& this.resultStrategy instanceof StandardTermsResults
&& sub == LeafBucketCollector.NO_OP_COLLECTOR) {
LeafBucketCollector termDocFreqCollector = termDocFreqCollector(
&& subAggregators.length == 0) {
return tryCollectFromTermFrequencies(
ctx,
globalOrds,
(ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount)
);
if (termDocFreqCollector != null) {
return termDocFreqCollector;
}
}
return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
collectionStrategy.globalOrdsReady(globalOrds);

SortedDocValues singleValues = DocValues.unwrapSingleton(globalOrds);
if (singleValues != null) {
Expand Down Expand Up @@ -433,6 +426,24 @@ static class LowCardinality extends GlobalOrdinalsStringTermsAggregator {
this.segmentDocCounts = context.bigArrays().newLongArray(1, true);
}

@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
if (subAggregators.length == 0) {
if (mapping != null) {
mapSegmentCountsToGlobalCounts(mapping);
}
final SortedSetDocValues segmentOrds = valuesSource.ordinalsValues(ctx);
segmentDocCounts = context.bigArrays().grow(segmentDocCounts, 1 + segmentOrds.getValueCount());
mapping = valuesSource.globalOrdinalsMapping(ctx);
return tryCollectFromTermFrequencies(
ctx,
segmentOrds,
(ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
);
}
return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
if (mapping != null) {
Expand All @@ -443,17 +454,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
assert sub == LeafBucketCollector.NO_OP_COLLECTOR;
mapping = valuesSource.globalOrdinalsMapping(ctx);

if (this.resultStrategy instanceof StandardTermsResults) {
LeafBucketCollector termDocFreqCollector = this.termDocFreqCollector(
ctx,
segmentOrds,
(ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
);
if (termDocFreqCollector != null) {
return termDocFreqCollector;
}
}

final SortedDocValues singleValues = DocValues.unwrapSingleton(segmentOrds);
if (singleValues != null) {
segmentsWithSingleValuedOrds++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,27 @@ public ScoreMode scoreMode() {
return valuesSource != null && valuesSource.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
}

@Override
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context);
if (supportedStarTree != null) {
AtomicReference<Double> max = new AtomicReference<>(maxes.get(0));
StarTreeQueryHelper.precomputeAggregationFromStarTree(
context,
valuesSource,
ctx,
supportedStarTree,
MetricStat.MAX.getTypeName(),
value -> {
max.set(Math.max(max.get(), (NumericUtils.sortableLongToDouble(value))));
},
() -> maxes.set(0, max.get())
);
return true;
}
return false;
}

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
Expand All @@ -128,15 +149,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
}
}

CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context);
if (supportedStarTree != null) {
return getStarTreeCollector(ctx, sub, supportedStarTree);
}
return getDefaultLeafCollector(ctx, sub);
}

private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {

final BigArrays bigArrays = context.bigArrays();
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
final NumericDoubleValues values = MultiValueMode.MAX.select(allValues);
Expand All @@ -160,23 +172,6 @@ public void collect(int doc, long bucket) throws IOException {
};
}

public LeafBucketCollector getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
throws IOException {
AtomicReference<Double> max = new AtomicReference<>(maxes.get(0));
return StarTreeQueryHelper.getStarTreeLeafCollector(
context,
valuesSource,
ctx,
sub,
starTree,
MetricStat.MAX.getTypeName(),
value -> {
max.set(Math.max(max.get(), (NumericUtils.sortableLongToDouble(value))));
},
() -> maxes.set(0, max.get())
);
}

@Override
public double metric(long owningBucketOrd) {
if (valuesSource == null || owningBucketOrd >= maxes.size()) {
Expand Down
Loading

0 comments on commit 4d5c32b

Please sign in to comment.