Skip to content

Commit

Permalink
Add context data to query responses (#1003)
Browse files: browse the repository at this point in the history
* add context data to query responses

* add semversioner file

* ignore typechecking ruff suggestion
  • Loading branch information
jgbradley1 authored Aug 22, 2024
1 parent 9c6f5e0 commit 4b9fdc0
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 16 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240821235154401001.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add context data to query API responses."
}
37 changes: 23 additions & 14 deletions graphrag/query/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
Contains the following functions:
- global_search: Perform a global search.
- global_search_streaming: Perform a global search and stream results via a generator.
- global_search_streaming: Perform a global search and stream results back.
- local_search: Perform a local search.
- local_search_streaming: Perform a local search and stream results via a generator.
- local_search_streaming: Perform a local search and stream results back.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
Expand All @@ -26,6 +26,7 @@
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.progress.types import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType

Expand All @@ -51,8 +52,11 @@ async def global_search(
community_level: int,
response_type: str,
query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
"""Perform a global search.
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a global search and return the response and context data.
Parameters
----------
Expand Down Expand Up @@ -80,9 +84,10 @@ async def global_search(
entities=_entities,
response_type=response_type,
)
result = await search_engine.asearch(query=query)
reporter.success(f"Global Search Response: {result.response}")
return result.response
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = _reformat_context_data(result.context_data) # type: ignore
return response, context_data


@validate_call(config={"arbitrary_types_allowed": True})
Expand All @@ -95,7 +100,7 @@ async def global_search_streaming(
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a global search and return results as a generator.
"""Perform a global search and return the context data and response via a generator.
Context data is returned as a dictionary of lists, with one list entry for each record.
Expand Down Expand Up @@ -152,8 +157,11 @@ async def local_search(
community_level: int,
response_type: str,
query: str,
) -> str | dict[str, Any] | list[dict[str, Any]]:
"""Perform a local search.
) -> tuple[
str | dict[str, Any] | list[dict[str, Any]],
str | list[pd.DataFrame] | dict[str, pd.DataFrame],
]:
"""Perform a local search and return the response and context data.
Parameters
----------
Expand Down Expand Up @@ -203,9 +211,10 @@ async def local_search(
response_type=response_type,
)

result = await search_engine.asearch(query=query)
reporter.success(f"Local Search Response: {result.response}")
return result.response
result: SearchResult = await search_engine.asearch(query=query)
response = result.response
context_data = _reformat_context_data(result.context_data) # type: ignore
return response, context_data


@validate_call(config={"arbitrary_types_allowed": True})
Expand All @@ -221,7 +230,7 @@ async def local_search_streaming(
response_type: str,
query: str,
) -> AsyncGenerator:
"""Perform a local search and return results as a generator.
"""Perform a local search and return the context data and response via a generator.
Parameters
----------
Expand Down
12 changes: 10 additions & 2 deletions graphrag/query/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def run_streaming_search():

return asyncio.run(run_streaming_search())
# not streaming
return asyncio.run(
response, context_data = asyncio.run(
api.global_search(
config=config,
nodes=final_nodes,
Expand All @@ -89,6 +89,10 @@ async def run_streaming_search():
query=query,
)
)
reporter.success(f"Global Search Response:\n{response}")
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data


def run_local_search(
Expand Down Expand Up @@ -156,7 +160,7 @@ async def run_streaming_search():

return asyncio.run(run_streaming_search())
# not streaming
return asyncio.run(
response, context_data = asyncio.run(
api.local_search(
config=config,
nodes=final_nodes,
Expand All @@ -170,6 +174,10 @@ async def run_streaming_search():
query=query,
)
)
reporter.success(f"Local Search Response:\n{response}")
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data


def _configure_paths_and_settings(
Expand Down

0 comments on commit 4b9fdc0

Please sign in to comment.