From 90987810004b22ce8bb65b1d726d14a9dd49e47b Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Thu, 14 Nov 2024 21:35:04 -0800
Subject: [PATCH] Address review comments

Signed-off-by: Martin Gaievski
---
 .../ExplanationResponseProcessor.java         | 17 +++-
 .../NormalizationProcessorWorkflow.java       | 20 ++--
 .../processor/combination/ScoreCombiner.java  | 26 ++---
 .../processor/explain/ExplanationDetails.java |  1 +
 .../processor/explain/ExplanationUtils.java   | 23 +++--
 .../query/HybridQueryExplainIT.java           | 97 +++++++++++++++++++
 6 files changed, 154 insertions(+), 30 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
index 01c1516d2..01cdfcb0d 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
@@ -40,11 +40,17 @@ public class ExplanationResponseProcessor implements SearchResponseProcessor {
     private final String tag;
     private final boolean ignoreFailure;
 
+    /**
+     * Adds explanation details to the search response if they are present in the request context
+     */
     @Override
     public SearchResponse processResponse(SearchRequest request, SearchResponse response) {
         return processResponse(request, response, null);
     }
 
+    /**
+     * Combines explanations from the processor with search-hit-level explanations and adds them to the search response
+     */
     @Override
     public SearchResponse processResponse(
         final SearchRequest request,
@@ -56,15 +62,20 @@ public SearchResponse processResponse(
             || requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) {
             return response;
         }
+        // Extract explanation payload from context
         ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY);
         Map<ExplanationPayload.PayloadType, Object> explainPayload = explanationPayload.getExplainPayload();
         if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
+            // for score normalization, processor-level explanations are sorted within the scope of each shard,
+            // and we merge them into a single sorted list
             SearchHits searchHits = response.getHits();
             SearchHit[] searchHitsArray = searchHits.getHits();
             // create a map of searchShard and list of indexes of search hit objects in search hits array
             // the list will keep original order of sorting as per final search results
             Map<SearchShard, List<Integer>> searchHitsByShard = new HashMap<>();
+            // we keep an index for each shard, where the index is a position in the searchHitsByShard list
             Map<SearchShard, Integer> explainsByShardCount = new HashMap<>();
+            // Build initial shard mappings
             for (int i = 0; i < searchHitsArray.length; i++) {
                 SearchHit searchHit = searchHitsArray[i];
                 SearchShardTarget searchShardTarget = searchHit.getShard();
@@ -72,19 +83,22 @@ public SearchResponse processResponse(
                 searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i);
                 explainsByShardCount.putIfAbsent(searchShard, -1);
             }
+            // Process normalization details if available in correct format
             if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map) {
                 @SuppressWarnings("unchecked")
                 Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = (Map<
                     SearchShard,
                     List<CombinedExplanationDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);
-
+                // Process each search hit to add processor-level explanations
                 for (SearchHit searchHit : searchHitsArray) {
                     SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
                     int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
                     CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
+                    // Extract various explanation components
                     Explanation queryLevelExplanation = searchHit.getExplanation();
                     ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations();
                     ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations();
+                    // Create normalized explanations for each detail
                     Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length];
                     for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
                         normalizedExplanation[i] = Explanation.match(
@@ -96,6 +110,7 @@ public SearchResponse processResponse(
                             queryLevelExplanation.getDetails()[i]
                         );
                     }
+                    // Create and set final explanation combining all components
                     Explanation finalExplanation = Explanation.match(
                         searchHit.getScore(),
                         // combination level explanation is always a single detail
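A quick illustration of the merge in the hunk above: hits arrive globally sorted, while the per-shard explanation lists stay in shard-local order, so the explainsByShardCount map acts as a per-shard cursor and the "previous index + 1" bookkeeping picks the matching element. The following toy sketch is self-contained and illustrative only; none of these names or data come from the patch or the OpenSearch codebase:

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class PerShardCursorSketch {
        public static void main(String[] args) {
            // hypothetical data: the shard each hit came from, in final (global) sort order
            String[] hitShards = { "shard0", "shard1", "shard0", "shard1" };
            // per-shard explanation payloads, each list in shard-local sort order
            Map<String, List<String>> explanationsByShard = Map.of(
                "shard0", List.of("s0-explain-0", "s0-explain-1"),
                "shard1", List.of("s1-explain-0", "s1-explain-1")
            );
            // same "previous index + 1" bookkeeping as explainsByShardCount in the patch
            Map<String, Integer> cursorByShard = new HashMap<>();
            List<String> merged = new ArrayList<>();
            for (String shard : hitShards) {
                int next = cursorByShard.getOrDefault(shard, -1) + 1;
                cursorByShard.put(shard, next);
                merged.add(explanationsByShard.get(shard).get(next));
            }
            // prints [s0-explain-0, s1-explain-0, s0-explain-1, s1-explain-1]
            System.out.println(merged);
        }
    }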
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index 078c68aff..f2699d967 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -4,6 +4,7 @@
  */
 package org.opensearch.neuralsearch.processor;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
@@ -106,6 +107,10 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request)
         updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
     }
 
+    /**
+     * Collects explanations from normalization and combination techniques and saves them into the pipeline context. Later that
+     * information will be read by the response processor and added to the search response
+     */
     private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<CompoundTopDocs> queryTopDocs) {
         if (!request.isExplain()) {
             return;
@@ -122,15 +127,19 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
             request.getCombinationTechnique(),
             sortForQuery
         );
-        Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = combinationExplain.entrySet()
-            .stream()
-            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> {
+        Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = new HashMap<>();
+        for (Map.Entry<SearchShard, List<ExplanationDetails>> entry : combinationExplain.entrySet()) {
+            List<CombinedExplanationDetails> combinedDetailsList = new ArrayList<>();
+            for (ExplanationDetails explainDetail : entry.getValue()) {
                 DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.getDocId(), entry.getKey());
-                return CombinedExplanationDetails.builder()
+                CombinedExplanationDetails combinedDetail = CombinedExplanationDetails.builder()
                     .normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
                     .combinationExplanations(explainDetail)
                     .build();
-            }).collect(Collectors.toList())));
+                combinedDetailsList.add(combinedDetail);
+            }
+            combinedExplanations.put(entry.getKey(), combinedDetailsList);
+        }
 
         ExplanationPayload explanationPayload = ExplanationPayload.builder()
             .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplanations))
@@ -139,7 +148,6 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
         PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
         pipelineProcessingContext.setAttribute(EXPLANATION_RESPONSE_KEY, explanationPayload);
     }
-
 }
 
 /**
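The explain() rewrite above pairs each shard's combination details with the normalization details stored under a (docId, shard) key. A minimal sketch of that pairing with simplified stand-in types (DocKey and Combined are hypothetical, not types from the codebase):

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class CombineExplanationsSketch {
        // hypothetical stand-ins for DocIdAtSearchShard and CombinedExplanationDetails
        record DocKey(int docId, String shard) {}
        record Combined(String normalization, String combination) {}

        public static void main(String[] args) {
            // combination details grouped per shard (doc ids in per-shard sorted order)
            Map<String, List<Integer>> combinationByShard = Map.of("shard0", List.of(3, 7));
            // normalization details keyed by (docId, shard)
            Map<DocKey, String> normalizationByDoc = Map.of(
                new DocKey(3, "shard0"), "min_max normalization of: 0.5",
                new DocKey(7, "shard0"), "min_max normalization of: 1.0"
            );
            // zip the two sources into one per-shard list, mirroring explain() above
            Map<String, List<Combined>> combined = new HashMap<>();
            for (Map.Entry<String, List<Integer>> entry : combinationByShard.entrySet()) {
                List<Combined> details = new ArrayList<>();
                for (int docId : entry.getValue()) {
                    String normalization = normalizationByDoc.get(new DocKey(docId, entry.getKey()));
                    details.add(new Combined(normalization, "arithmetic_mean combination of:"));
                }
                combined.put(entry.getKey(), details);
            }
            System.out.println(combined);
        }
    }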
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
index cbc3f485b..1779f20f7 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
@@ -359,19 +359,19 @@ private List<ExplanationDetails> explainByShard(
         // sort combined scores as per sorting criteria - either score desc or field sorting
         Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);
 
-        List<ExplanationDetails> listOfExplanations = sortedDocsIds.stream()
-            .map(
-                docId -> new ExplanationDetails(
-                    docId,
-                    List.of(
-                        Pair.of(
-                            combinedNormalizedScoresByDocId.get(docId),
-                            String.format(Locale.ROOT, "%s combination of:", ((ExplainableTechnique) scoreCombinationTechnique).describe())
-                        )
-                    )
-                )
-            )
-            .toList();
+        List<ExplanationDetails> listOfExplanations = new ArrayList<>();
+        String combinationDescription = String.format(
+            Locale.ROOT,
+            "%s combination of:",
+            ((ExplainableTechnique) scoreCombinationTechnique).describe()
+        );
+        for (int docId : sortedDocsIds) {
+            ExplanationDetails explanation = new ExplanationDetails(
+                docId,
+                List.of(Pair.of(combinedNormalizedScoresByDocId.get(docId), combinationDescription))
+            );
+            listOfExplanations.add(explanation);
+        }
         return listOfExplanations;
     }
 
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
index e577e6f43..2816a348b 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
@@ -21,6 +21,7 @@ public class ExplanationDetails {
     List<Pair<Float, String>> scoreDetails;
 
     public ExplanationDetails(List<Pair<Float, String>> scoreDetails) {
+        // pass docId as -1 to match docId in SearchHit https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchHit.java#L170
        this(-1, scoreDetails);
     }
 }
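The ScoreCombiner change above is behavior-preserving: it hoists the String.format call out of the per-document loop, since the combination description does not depend on the doc id. A toy, self-contained sketch of the same hoisting (all names and data illustrative):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Locale;

    public class HoistedDescriptionSketch {
        public static void main(String[] args) {
            String technique = "arithmetic_mean"; // stand-in for describe()
            // formatted once, outside the loop, like combinationDescription above
            String combinationDescription = String.format(Locale.ROOT, "%s combination of:", technique);
            List<String> explanations = new ArrayList<>();
            for (int docId : new int[] { 4, 2, 9 }) {
                explanations.add(docId + " -> " + combinationDescription);
            }
            System.out.println(explanations);
        }
    }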
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
index b4c5cd557..c6ac0500b 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
@@ -6,12 +6,13 @@
 
 import org.apache.commons.lang3.tuple.Pair;
 
+import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
-import java.util.stream.Collectors;
 
 /**
  * Utility class for explain functionality
@@ -27,15 +28,17 @@ public static Map<DocIdAtSearchShard, ExplanationDetails> getDocIdAtQueryForNormalization(
         final Map<DocIdAtSearchShard, List<Float>> normalizedScores,
         final ExplainableTechnique technique
     ) {
-        Map<DocIdAtSearchShard, ExplanationDetails> explain = normalizedScores.entrySet()
-            .stream()
-            .collect(Collectors.toMap(Map.Entry::getKey, entry -> {
-                List<Float> normScores = normalizedScores.get(entry.getKey());
-                List<Pair<Float, String>> explanations = normScores.stream()
-                    .map(score -> Pair.of(score, String.format(Locale.ROOT, "%s normalization of:", technique.describe())))
-                    .collect(Collectors.toList());
-                return new ExplanationDetails(explanations);
-            }));
+        Map<DocIdAtSearchShard, ExplanationDetails> explain = new HashMap<>();
+        for (Map.Entry<DocIdAtSearchShard, List<Float>> entry : normalizedScores.entrySet()) {
+            List<Float> normScores = normalizedScores.get(entry.getKey());
+            List<Pair<Float, String>> explanations = new ArrayList<>();
+            for (float score : normScores) {
+                String description = String.format(Locale.ROOT, "%s normalization of:", technique.describe());
+                explanations.add(Pair.of(score, description));
+            }
+            explain.put(entry.getKey(), new ExplanationDetails(explanations));
+        }
+        return explain;
     }
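For reference, this is the shape of data getDocIdAtQueryForNormalization produces: every normalized sub-query score of a document becomes a (score, "<technique> normalization of:") pair. The sketch below imitates that with plain JDK types; Map.Entry stands in for the commons-lang3 Pair and the String key stands in for DocIdAtSearchShard, so it is illustrative only:

    import java.util.AbstractMap.SimpleEntry;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Locale;
    import java.util.Map;

    public class NormalizationExplainSketch {
        public static void main(String[] args) {
            String technique = "min_max"; // stand-in for technique.describe()
            // the String key is a hypothetical stand-in for DocIdAtSearchShard
            Map<String, List<Float>> normalizedScores = Map.of("doc1@shard0", List.of(1.0f, 0.25f));
            Map<String, List<Map.Entry<Float, String>>> explain = new HashMap<>();
            for (Map.Entry<String, List<Float>> entry : normalizedScores.entrySet()) {
                List<Map.Entry<Float, String>> explanations = new ArrayList<>();
                for (float score : entry.getValue()) {
                    // one (score, description) pair per normalized sub-query score
                    explanations.add(new SimpleEntry<>(score, String.format(Locale.ROOT, "%s normalization of:", technique)));
                }
                explain.put(entry.getKey(), explanations);
            }
            // prints {doc1@shard0=[1.0=min_max normalization of:, 0.25=min_max normalization of:]}
            System.out.println(explain);
        }
    }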
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
index a7656912c..3b1d6cfba 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java
@@ -17,6 +17,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.IntStream;
@@ -37,6 +38,7 @@ public class HybridQueryExplainIT extends BaseNeuralSearchIT {
     private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-hybrid-vector-doc-field-index";
     private static final String TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME = "test-hybrid-multi-doc-nested-fields-index";
     private static final String TEST_MULTI_DOC_INDEX_NAME = "test-hybrid-multi-doc-index";
+    private static final String TEST_LARGE_DOCS_INDEX_NAME = "test-hybrid-large-docs-index";
     private static final String TEST_QUERY_TEXT3 = "hello";
     private static final String TEST_QUERY_TEXT4 = "place";
@@ -459,6 +461,64 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe
         }
     }
 
+    @SneakyThrows
+    public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() {
+        try {
+            initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME);
+            // create search pipeline with both normalization processor and explain response processor
+            createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);
+
+            TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
+            HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+            hybridQueryBuilder.add(termQueryBuilder);
+
+            Map<String, Object> searchResponseAsMap = search(
+                TEST_LARGE_DOCS_INDEX_NAME,
+                hybridQueryBuilder,
+                null,
+                1000,
+                Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString())
+            );
+
+            List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
+            assertNotNull(hitsNestedList);
+            assertFalse(hitsNestedList.isEmpty());
+
+            // Verify total hits
+            Map<String, Object> total = getTotalHits(searchResponseAsMap);
+            assertNotNull(total.get("value"));
+            assertTrue((int) total.get("value") > 0);
+            assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+
+            // Sanity checks for each hit's explanation
+            for (Map<String, Object> hit : hitsNestedList) {
+                // Verify score is positive
+                double score = (double) hit.get("_score");
+                assertTrue("Score should be positive", score > 0.0);
+
+                // Basic explanation structure checks
+                Map<String, Object> explanation = (Map<String, Object>) hit.get("_explanation");
+                assertNotNull(explanation);
+                assertEquals("arithmetic_mean combination of:", explanation.get("description"));
+                Map<String, Object> hitDetailsForHit = getListOfValues(explanation, "details").get(0);
+                assertTrue((double) hitDetailsForHit.get("value") > 0.0f);
+                assertEquals("min_max normalization of:", hitDetailsForHit.get("description"));
+                Map<String, Object> subQueryDetailsForHit = getListOfValues(hitDetailsForHit, "details").get(0);
+                assertTrue((double) subQueryDetailsForHit.get("value") > 0.0f);
+                assertFalse(subQueryDetailsForHit.get("description").toString().isEmpty());
+                assertEquals(1, getListOfValues(subQueryDetailsForHit, "details").size());
+            }
+            // Verify scores are properly ordered
+            List<Double> scores = new ArrayList<>();
+            for (Map<String, Object> hit : hitsNestedList) {
+                scores.add((Double) hit.get("_score"));
+            }
+            assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1)));
+        } finally {
+            wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE);
+        }
+    }
+
     @SneakyThrows
     private void initializeIndexIfNotExist(String indexName) {
         if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) {
@@ -521,6 +581,43 @@ private void initializeIndexIfNotExist(String indexName) {
             );
             addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME);
         }
+
+        if (TEST_LARGE_DOCS_INDEX_NAME.equals(indexName) && !indexExists(TEST_LARGE_DOCS_INDEX_NAME)) {
+            prepareKnnIndex(
+                TEST_LARGE_DOCS_INDEX_NAME,
+                List.of(
+                    new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE),
+                    new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_2, TEST_DIMENSION, TEST_SPACE_TYPE)
+                )
+            );
+
+            // Index 1000 documents
+            for (int i = 0; i < 1000; i++) {
+                String docText;
+                if (i % 5 == 0) {
+                    docText = TEST_DOC_TEXT1; // "Hello world"
+                } else if (i % 7 == 0) {
+                    docText = TEST_DOC_TEXT2; // "Hi to this place"
+                } else if (i % 11 == 0) {
+                    docText = TEST_DOC_TEXT3; // "We would like to welcome everyone"
+                } else {
+                    docText = String.format(Locale.ROOT, "Document %d with random content", i);
+                }
+
+                addKnnDoc(
+                    TEST_LARGE_DOCS_INDEX_NAME,
+                    String.valueOf(i),
+                    List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2),
+                    List.of(
+                        Floats.asList(createRandomVector(TEST_DIMENSION)).toArray(),
+                        Floats.asList(createRandomVector(TEST_DIMENSION)).toArray()
+                    ),
+                    Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
+                    Collections.singletonList(docText)
+                );
+            }
+            assertEquals(1000, getDocCount(TEST_LARGE_DOCS_INDEX_NAME));
+        }
     }
 
     private void addDocsToIndex(final String testMultiDocIndexName) {
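A note on the ordering assertion at the end of the new IT: with score sorting, hits must be non-increasing by _score, and the IntStream.noneMatch check verifies exactly that while allowing ties. A standalone sketch of the same check, with toy scores that are not from the test:

    import java.util.List;
    import java.util.stream.IntStream;

    public class ScoreOrderCheckSketch {
        public static void main(String[] args) {
            List<Double> scores = List.of(0.9, 0.9, 0.4, 0.1); // illustrative scores, ties allowed
            boolean sortedDesc = IntStream.range(0, scores.size() - 1)
                .noneMatch(i -> scores.get(i) < scores.get(i + 1));
            System.out.println(sortedDesc); // true
        }
    }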