diff --git a/.semversioner/next-release/patch-20240821235154401001.json b/.semversioner/next-release/patch-20240821235154401001.json new file mode 100644 index 0000000000..b544e564cf --- /dev/null +++ b/.semversioner/next-release/patch-20240821235154401001.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Add context data to query API responses." +} diff --git a/graphrag/query/api.py b/graphrag/query/api.py index 050af3b729..788de8c0ca 100644 --- a/graphrag/query/api.py +++ b/graphrag/query/api.py @@ -9,9 +9,9 @@ Contains the following functions: - global_search: Perform a global search. - - global_search_streaming: Perform a global search and stream results via a generator. + - global_search_streaming: Perform a global search and stream results back. - local_search: Perform a local search. - - local_search_streaming: Perform a local search and stream results via a generator. + - local_search_streaming: Perform a local search and stream results back. WARNING: This API is under development and may undergo changes in future releases. Backwards compatibility is not guaranteed at this time. @@ -26,6 +26,7 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.progress.types import PrintProgressReporter from graphrag.model.entity import Entity +from graphrag.query.structured_search.base import SearchResult # noqa: TCH001 from graphrag.vector_stores.lancedb import LanceDBVectorStore from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType @@ -51,8 +52,11 @@ async def global_search( community_level: int, response_type: str, query: str, -) -> str | dict[str, Any] | list[dict[str, Any]]: - """Perform a global search. +) -> tuple[ + str | dict[str, Any] | list[dict[str, Any]], + str | list[pd.DataFrame] | dict[str, pd.DataFrame], +]: + """Perform a global search and return the context data and response. Parameters ---------- @@ -80,9 +84,10 @@ async def global_search( entities=_entities, response_type=response_type, ) - result = await search_engine.asearch(query=query) - reporter.success(f"Global Search Response: {result.response}") - return result.response + result: SearchResult = await search_engine.asearch(query=query) + response = result.response + context_data = _reformat_context_data(result.context_data) # type: ignore + return response, context_data @validate_call(config={"arbitrary_types_allowed": True}) @@ -95,7 +100,7 @@ async def global_search_streaming( response_type: str, query: str, ) -> AsyncGenerator: - """Perform a global search and return results as a generator. + """Perform a global search and return the context data and response via a generator. Context data is returned as a dictionary of lists, with one list entry for each record. @@ -152,8 +157,11 @@ async def local_search( community_level: int, response_type: str, query: str, -) -> str | dict[str, Any] | list[dict[str, Any]]: - """Perform a local search. +) -> tuple[ + str | dict[str, Any] | list[dict[str, Any]], + str | list[pd.DataFrame] | dict[str, pd.DataFrame], +]: + """Perform a local search and return the context data and response. Parameters ---------- @@ -203,9 +211,10 @@ async def local_search( response_type=response_type, ) - result = await search_engine.asearch(query=query) - reporter.success(f"Local Search Response: {result.response}") - return result.response + result: SearchResult = await search_engine.asearch(query=query) + response = result.response + context_data = _reformat_context_data(result.context_data) # type: ignore + return response, context_data @validate_call(config={"arbitrary_types_allowed": True}) @@ -221,7 +230,7 @@ async def local_search_streaming( response_type: str, query: str, ) -> AsyncGenerator: - """Perform a local search and return results as a generator. + """Perform a local search and return the context data and response via a generator. Parameters ---------- diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 59eb4b454a..de625e912a 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -78,7 +78,7 @@ async def run_streaming_search(): return asyncio.run(run_streaming_search()) # not streaming - return asyncio.run( + response, context_data = asyncio.run( api.global_search( config=config, nodes=final_nodes, @@ -89,6 +89,10 @@ async def run_streaming_search(): query=query, ) ) + reporter.success(f"Global Search Response:\n{response}") + # NOTE: we return the response and context data here purely as a complete demonstration of the API. + # External users should use the API directly to get the response and context data. + return response, context_data def run_local_search( @@ -156,7 +160,7 @@ async def run_streaming_search(): return asyncio.run(run_streaming_search()) # not streaming - return asyncio.run( + response, context_data = asyncio.run( api.local_search( config=config, nodes=final_nodes, @@ -170,6 +174,10 @@ async def run_streaming_search(): query=query, ) ) + reporter.success(f"Local Search Response:\n{response}") + # NOTE: we return the response and context data here purely as a complete demonstration of the API. + # External users should use the API directly to get the response and context data. + return response, context_data def _configure_paths_and_settings(