From 319c59430b0320e24570a6bcb1dbd4f60d796fe9 Mon Sep 17 00:00:00 2001 From: Niklas Koehnecke Date: Tue, 23 Jan 2024 17:46:45 +0100 Subject: [PATCH] IL-167 Add ids to RetrieverBasedQa answers --- .../use_cases/qa/retriever_based_qa.py | 49 +++++++++++++++++-- tests/use_cases/qa/test_retriever_based_qa.py | 11 +++-- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/src/intelligence_layer/use_cases/qa/retriever_based_qa.py b/src/intelligence_layer/use_cases/qa/retriever_based_qa.py index bef139452..9ba16ac6e 100644 --- a/src/intelligence_layer/use_cases/qa/retriever_based_qa.py +++ b/src/intelligence_layer/use_cases/qa/retriever_based_qa.py @@ -1,3 +1,5 @@ +from typing import Generic, Optional, Sequence + from pydantic import BaseModel from intelligence_layer.connectors.limited_concurrency_client import ( @@ -11,7 +13,7 @@ from intelligence_layer.use_cases.qa.multiple_chunk_qa import ( MultipleChunkQa, MultipleChunkQaInput, - MultipleChunkQaOutput, + Subanswer, ) from intelligence_layer.use_cases.search.search import Search, SearchInput @@ -29,7 +31,25 @@ class RetrieverBasedQaInput(BaseModel): language: Language = Language("en") -class RetrieverBasedQa(Task[RetrieverBasedQaInput, MultipleChunkQaOutput]): +class EnrichedSubanswer(Subanswer, Generic[ID]): + id: ID + + +class RetrieverBasedQaOutput(BaseModel, Generic[ID]): + """The output of a `RetrieverBasedQa` task. + + Attributes: + answer: The answer generated by the task. Can be a string or None (if no answer was found). + subanswers: All the subanswers used to generate the answer. + """ + + answer: Optional[str] + subanswers: Sequence[EnrichedSubanswer[ID]] + + +class RetrieverBasedQa( + Task[RetrieverBasedQaInput, RetrieverBasedQaOutput[ID]], Generic[ID] +): """Answer a question based on documents found by a retriever. `RetrieverBasedQa` is a task that answers a question based on a set of documents. 
@@ -78,7 +98,7 @@ def __init__( def do_run( self, input: RetrieverBasedQaInput, task_span: TaskSpan - ) -> MultipleChunkQaOutput: + ) -> RetrieverBasedQaOutput[ID]: search_output = self._search.run(SearchInput(query=input.question), task_span) multi_chunk_qa_input = MultipleChunkQaInput( @@ -88,4 +108,25 @@ def do_run( question=input.question, language=input.language, ) - return self._multi_chunk_qa.run(multi_chunk_qa_input, task_span) + + result = self._multi_chunk_qa.run(multi_chunk_qa_input, task_span) + + # multi_chunk_qa does not know IDs so we need to rematch them + text_to_id = { + document.document_chunk.text: document.id + for document in search_output.results + } + enriched_answers = [ + EnrichedSubanswer( + answer=answer.answer, + chunk=answer.chunk, + highlights=answer.highlights, + id=text_to_id[answer.chunk], + ) + for answer in result.subanswers + ] + correctly_formatted_output = RetrieverBasedQaOutput( + answer=result.answer, + subanswers=enriched_answers, + ) + return correctly_formatted_output diff --git a/tests/use_cases/qa/test_retriever_based_qa.py b/tests/use_cases/qa/test_retriever_based_qa.py index 9639c6d87..02850fee2 100644 --- a/tests/use_cases/qa/test_retriever_based_qa.py +++ b/tests/use_cases/qa/test_retriever_based_qa.py @@ -2,6 +2,7 @@ from pytest import fixture +from intelligence_layer.connectors.document_index.document_index import DocumentPath from intelligence_layer.connectors.limited_concurrency_client import ( AlephAlphaClientProtocol, ) @@ -41,7 +42,7 @@ def in_memory_retriever_documents() -> Sequence[Document]: def retriever_based_qa_with_in_memory_retriever( client: AlephAlphaClientProtocol, asymmetric_in_memory_retriever: QdrantInMemoryRetriever, -) -> RetrieverBasedQa: +) -> RetrieverBasedQa[int]: return RetrieverBasedQa( client, asymmetric_in_memory_retriever, model="luminous-base-control" ) @@ -50,14 +51,14 @@ def retriever_based_qa_with_in_memory_retriever( @fixture def retriever_based_qa_with_document_index( client: 
AlephAlphaClientProtocol, document_index_retriever: DocumentIndexRetriever -) -> RetrieverBasedQa: +) -> RetrieverBasedQa[DocumentPath]: return RetrieverBasedQa( client, document_index_retriever, model="luminous-base-control" ) def test_retriever_based_qa_using_in_memory_retriever( - retriever_based_qa_with_in_memory_retriever: RetrieverBasedQa, + retriever_based_qa_with_in_memory_retriever: RetrieverBasedQa[int], no_op_tracer: NoOpTracer, ) -> None: question = "When was Robert Moses born?" @@ -65,10 +66,11 @@ def test_retriever_based_qa_using_in_memory_retriever( output = retriever_based_qa_with_in_memory_retriever.run(input, no_op_tracer) assert output.answer assert "1888" in output.answer + assert output.subanswers[0].id == 3 def test_retriever_based_qa_with_document_index( - retriever_based_qa_with_document_index: RetrieverBasedQa, + retriever_based_qa_with_document_index: RetrieverBasedQa[DocumentPath], no_op_tracer: NoOpTracer, ) -> None: question = "When was Robert Moses born?" @@ -76,3 +78,4 @@ def test_retriever_based_qa_with_document_index( output = retriever_based_qa_with_document_index.run(input, no_op_tracer) assert output.answer assert "1888" in output.answer + assert output.subanswers[0].id.document_name == "Robert Moses (Begriffsklärung)"