Skip to content

Commit

Permalink
Change response format, switch to hierarchical structure
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 13, 2024
1 parent 72c0ac3 commit d04f21f
Show file tree
Hide file tree
Showing 19 changed files with 959 additions and 600 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import org.apache.lucene.search.Explanation;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
Expand Down Expand Up @@ -45,7 +46,11 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) {
public SearchResponse processResponse(
final SearchRequest request,
final SearchResponse response,
final PipelineProcessingContext requestContext
) {
if (Objects.isNull(requestContext)
|| (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))
|| requestContext.getAttribute(EXPLAIN_RESPONSE_KEY) instanceof ExplanationPayload == false) {
Expand All @@ -54,10 +59,6 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
Explanation processorExplanation = explanationPayload.getExplanation();
if (Objects.isNull(processorExplanation)) {
return response;
}
SearchHits searchHits = response.getHits();
SearchHit[] searchHitsArray = searchHits.getHits();
// create a map of searchShard and list of indexes of search hit objects in search hits array
Expand All @@ -73,29 +74,33 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
}
if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map<?, ?>) {
@SuppressWarnings("unchecked")
Map<SearchShard, List<CombinedExplainDetails>> combinedExplainDetails = (Map<
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map<
SearchShard,
List<CombinedExplainDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);

for (SearchHit searchHit : searchHitsArray) {
SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
Explanation normalizedExplanation = Explanation.match(
combinedExplainDetail.getNormalizationExplain().value(),
combinedExplainDetail.getNormalizationExplain().description()
);
Explanation combinedExplanation = Explanation.match(
combinedExplainDetail.getCombinationExplain().value(),
combinedExplainDetail.getCombinationExplain().description()
);

CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
Explanation queryLevelExplanation = searchHit.getExplanation();
ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
normalizedExplanation[i] = Explanation.match(
// normalized score
normalizationExplanation.scoreDetails().get(i).getKey(),
// description of normalized score
normalizationExplanation.scoreDetails().get(i).getValue(),
// shard level details
queryLevelExplanation.getDetails()[i]
);
}
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
processorExplanation.getDescription(),
normalizedExplanation,
combinedExplanation,
searchHit.getExplanation()
// combination level explanation is always a single detail
combinationExplanation.scoreDetails().get(0).getValue(),
normalizedExplanation
);
searchHit.explanation(finalExplanation);
explainsByShardCount.put(searchShard, explanationIndexByShard);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
*/
package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand All @@ -14,7 +13,6 @@
import java.util.Optional;
import java.util.stream.Collectors;

import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Sort;
Expand All @@ -24,7 +22,7 @@
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
Expand All @@ -42,7 +40,6 @@

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.topLevelExpalantionForCombinedScore;
import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria;

/**
Expand Down Expand Up @@ -113,16 +110,9 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
if (!request.isExplain()) {
return;
}
Explanation topLevelExplanationForTechniques = topLevelExpalantionForCombinedScore(
(ExplainableTechnique) request.getNormalizationTechnique(),
(ExplainableTechnique) request.getCombinationTechnique()
);

// build final result object with all explain related information
if (Objects.nonNull(request.getPipelineProcessingContext())) {

Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs);

Map<DocIdAtSearchShard, ExplanationDetails> normalizationExplain = scoreNormalizer.explain(
queryTopDocs,
(ExplainableTechnique) request.getNormalizationTechnique()
Expand All @@ -132,23 +122,18 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
request.getCombinationTechnique(),
sortForQuery
);
Map<SearchShard, List<CombinedExplainDetails>> combinedExplain = new HashMap<>();

combinationExplain.forEach((searchShard, explainDetails) -> {
for (ExplanationDetails explainDetail : explainDetails) {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), searchShard);
ExplanationDetails normalizedExplanationDetails = normalizationExplain.get(docIdAtSearchShard);
CombinedExplainDetails combinedExplainDetails = CombinedExplainDetails.builder()
.normalizationExplain(normalizedExplanationDetails)
.combinationExplain(explainDetail)
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = combinationExplain.entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> {
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), entry.getKey());
return CombinedExplanationDetails.builder()
.normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
.combinationExplanations(explainDetail)
.build();
combinedExplain.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(combinedExplainDetails);
}
});
}).collect(Collectors.toList())));

ExplanationPayload explanationPayload = ExplanationPayload.builder()
.explanation(topLevelExplanationForTechniques)
.explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplain))
.explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplanations))
.build();
// store explain object to pipeline context
PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;

/**
* Abstracts combination of scores based on arithmetic mean method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;

/**
* Abstracts combination of scores based on geometrical mean method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS;
import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.describeCombinationTechnique;
import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;

/**
* Abstracts combination of scores based on harmonic mean method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.Objects;
Expand All @@ -16,6 +17,7 @@

import java.util.stream.Collectors;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.ScoreDoc;
Expand All @@ -27,10 +29,9 @@

import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.SearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;

import static org.opensearch.neuralsearch.processor.explain.ExplainationUtils.getScoreCombinationExplainDetailsForDocument;

/**
* Abstracts combination of scores in query search results.
*/
Expand Down Expand Up @@ -360,10 +361,14 @@ private List<ExplanationDetails> explainByShard(

List<ExplanationDetails> listOfExplanations = sortedDocsIds.stream()
.map(
docId -> getScoreCombinationExplainDetailsForDocument(
docId -> new ExplanationDetails(
docId,
combinedNormalizedScoresByDocId,
normalizedScoresPerDoc.get(docId)
List.of(
Pair.of(
combinedNormalizedScoresByDocId.get(docId),
String.format(Locale.ROOT, "%s combination of:", ((ExplainableTechnique) scoreCombinationTechnique).describe())
)
)
)
)
.toList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
@AllArgsConstructor
@Builder
@Getter
public class CombinedExplainDetails {
private ExplanationDetails normalizationExplain;
private ExplanationDetails combinationExplain;
public class CombinedExplanationDetails {
private ExplanationDetails normalizationExplanations;
private ExplanationDetails combinationExplanations;
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
*/
package org.opensearch.neuralsearch.processor.explain;

import org.apache.commons.lang3.tuple.Pair;

import java.util.List;

/**
* DTO class to store value and description for explain details.
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
* @param value
* @param description
* @param docId iterator based id of the document
* @param scoreDetails list of score details for the document, each Pair object contains score and description of the score
*/
public record ExplanationDetails(int docId, float value, String description) {
public ExplanationDetails(float value, String description) {
this(-1, value, description);
public record ExplanationDetails(int docId, List<Pair<Float, String>> scoreDetails) {

public ExplanationDetails(List<Pair<Float, String>> scoreDetails) {
this(-1, scoreDetails);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import org.apache.lucene.search.Explanation;

import java.util.Map;

Expand All @@ -18,7 +17,6 @@
@Builder
@Getter
public class ExplanationPayload {
private final Explanation explanation;
private final Map<PayloadType, Object> explainPayload;

public enum PayloadType {
Expand Down
Loading

0 comments on commit d04f21f

Please sign in to comment.