Skip to content

Commit

Permalink
Doing some refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 4, 2024
1 parent 9340557 commit a19de09
Show file tree
Hide file tree
Showing 17 changed files with 601 additions and 103 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.ExplainResponseProcessor;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.ProcessorExplainPublisherFactory;
import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
Expand Down Expand Up @@ -185,8 +185,8 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchRespon
return Map.of(
RerankProcessor.TYPE,
new RerankProcessorFactory(clientAccessor, parameters.searchPipelineService.getClusterService()),
ExplainResponseProcessor.TYPE,
new ProcessorExplainPublisherFactory()
ExplanationResponseProcessor.TYPE,
new ExplanationResponseProcessorFactory()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ private void initialize(TotalHits totalHits, List<TopDocs> topDocs, boolean isSo
public CompoundTopDocs(final QuerySearchResult querySearchResult) {
final TopDocs topDocs = querySearchResult.topDocs().topDocs;
final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget();
SearchShard searchShard = new SearchShard(
searchShardTarget.getIndex(),
searchShardTarget.getShardId().id(),
searchShardTarget.getNodeId()
);
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
boolean isSortEnabled = false;
if (topDocs instanceof TopFieldDocs) {
isSortEnabled = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.neuralsearch.processor.explain.CombinedExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto;
import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
Expand All @@ -24,13 +24,16 @@
import java.util.Objects;

import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLAIN_RESPONSE_KEY;
import static org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR;
import static org.opensearch.neuralsearch.processor.explain.ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR;

/**
* Processor to add explanation details to search response
*/
@Getter
@AllArgsConstructor
public class ExplainResponseProcessor implements SearchResponseProcessor {
public class ExplanationResponseProcessor implements SearchResponseProcessor {

public static final String TYPE = "explain_response_processor";
public static final String TYPE = "explanation_response_processor";

private final String description;
private final String tag;
Expand All @@ -46,10 +49,10 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
if (Objects.isNull(requestContext) || (Objects.isNull(requestContext.getAttribute(EXPLAIN_RESPONSE_KEY)))) {
return response;
}
ProcessorExplainDto processorExplainDto = (ProcessorExplainDto) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ProcessorExplainDto.ExplanationType, Object> explainPayload = processorExplainDto.getExplainPayload();
ExplanationResponse explanationResponse = (ExplanationResponse) requestContext.getAttribute(EXPLAIN_RESPONSE_KEY);
Map<ExplanationResponse.ExplanationType, Object> explainPayload = explanationResponse.getExplainPayload();
if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) {
Explanation processorExplanation = processorExplainDto.getExplanation();
Explanation processorExplanation = explanationResponse.getExplanation();
if (Objects.isNull(processorExplanation)) {
return response;
}
Expand All @@ -62,7 +65,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
for (int i = 0; i < searchHitsArray.length; i++) {
SearchHit searchHit = searchHitsArray[i];
SearchShardTarget searchShardTarget = searchHit.getShard();
SearchShard searchShard = SearchShard.create(searchShardTarget);
SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget);
searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i);
explainsByShardCount.putIfAbsent(searchShard, -1);
}
Expand All @@ -73,7 +76,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
List<CombinedExplainDetails>>) explainPayload.get(NORMALIZATION_PROCESSOR);

for (SearchHit searchHit : searchHitsArray) {
SearchShard searchShard = SearchShard.create(searchHit.getShard());
SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard());
int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1;
CombinedExplainDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard);
Explanation normalizedExplanation = Explanation.match(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
Expand All @@ -53,19 +53,27 @@ public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
doProcessStuff(searchPhaseResult, searchPhaseContext, Optional.empty());
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
PipelineProcessingContext requestContext
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
doProcessStuff(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

private <Result extends SearchPhaseResult> void doProcessStuff(
private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ProcessorExplainDto;
import org.opensearch.neuralsearch.processor.explain.ExplanationResponse;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -146,13 +146,13 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
}
});

ProcessorExplainDto processorExplainDto = ProcessorExplainDto.builder()
ExplanationResponse explanationResponse = ExplanationResponse.builder()
.explanation(topLevelExplanationForTechniques)
.explainPayload(Map.of(ProcessorExplainDto.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain))
.explainPayload(Map.of(ExplanationResponse.ExplanationType.NORMALIZATION_PROCESSOR, combinedExplain))
.build();
// store explain object to pipeline context
PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext();
pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, processorExplainDto);
pipelineProcessingContext.setAttribute(EXPLAIN_RESPONSE_KEY, explanationResponse);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
@Builder
@AllArgsConstructor
@Getter
/**
* DTO class to hold request parameters for normalization and combination
*/
public class NormalizationProcessorWorkflowExecuteRequest {
final List<QuerySearchResult> querySearchResults;
final Optional<FetchSearchResult> fetchSearchResultOptional;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@

import org.opensearch.search.SearchShardTarget;

/**
* DTO class to store index, shardId and nodeId for a search shard.
*/
public record SearchShard(String index, int shardId, String nodeId) {

public static SearchShard create(SearchShardTarget searchShardTarget) {
/**
* Create SearchShard from SearchShardTarget
* @param searchShardTarget
* @return SearchShard
*/
public static SearchShard createSearchShard(final SearchShardTarget searchShardTarget) {
return new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import lombok.Builder;
import lombok.Getter;

/**
* DTO class to hold explain details for normalization and combination
*/
@AllArgsConstructor
@Builder
@Getter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import org.opensearch.neuralsearch.processor.SearchShard;

/**
* Data class to store docId and search shard for a query.
* DTO class to store docId and search shard for a query.
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
* @param docId
* @param searchShard
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
package org.opensearch.neuralsearch.processor.explain;

/**
* Data class to store value and description for explain details.
* DTO class to store value and description for explain details.
* Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
* @param value
* @param description
*/
public record ExplainDetails(float value, String description, int docId) {

public record ExplainDetails(int docId, float value, String description) {
public ExplainDetails(float value, String description) {
this(value, description, -1);
this(-1, value, description);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ public static ExplainDetails getScoreCombinationExplainDetailsForDocument(
) {
float combinedScore = combinedNormalizedScoresByDocId.get(docId);
return new ExplainDetails(
docId,
combinedScore,
String.format(
Locale.ROOT,
"normalized scores: %s combined to a final score: %s",
Arrays.toString(normalizedScoresPerDoc),
combinedScore
),
docId
)
);
}

Expand Down Expand Up @@ -96,5 +96,4 @@ public static Explanation topLevelExpalantionForCombinedScore(

return Explanation.match(0.0f, explanationDetailsMessage);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@

import java.util.Map;

/**
* DTO class to hold explain details for normalization and combination
*/
@AllArgsConstructor
@Builder
@Getter
public class ProcessorExplainDto {
public class ExplanationResponse {
Explanation explanation;
Map<ExplanationType, Object> explainPayload;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
*/
package org.opensearch.neuralsearch.processor.factory;

import org.opensearch.neuralsearch.processor.ExplainResponseProcessor;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import java.util.Map;

public class ProcessorExplainPublisherFactory implements Processor.Factory<SearchResponseProcessor> {
/**
* Factory class for creating ExplanationResponseProcessor
*/
public class ExplanationResponseProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

@Override
public SearchResponseProcessor create(
Expand All @@ -21,6 +24,6 @@ public SearchResponseProcessor create(
Map<String, Object> config,
Processor.PipelineContext pipelineContext
) throws Exception {
return new ExplainResponseProcessor(description, tag, ignoreFailure);
return new ExplanationResponseProcessor(description, tag, ignoreFailure);
}
}

This file was deleted.

Loading

0 comments on commit a19de09

Please sign in to comment.