Improve search type hint (#1031)
* update get_local_search_engine and get_global_search_engine return annotation

* add semversioner file

* reorder imports

* fix pyright errors

* revert change and ignore previous pyright error

---------

Co-authored-by: wanhua.gu <[email protected]>
Co-authored-by: longyunfeigu <[email protected]>
Co-authored-by: Alonso Guevara <[email protected]>
4 people authored Aug 26, 2024
1 parent 4c2f537 commit a90d210
Showing 4 changed files with 10 additions and 7 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240826152927762829.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Update query type hints."
+}
4 changes: 1 addition & 3 deletions graphrag/index/verbs/graph/merge/merge_graphs.py
@@ -130,9 +130,7 @@ def merge_edges(
             target_graph.add_edge(source, target, **(edge_data or {}))
         else:
             merge_attributes(
-                target_graph.edges[
-                    (source, target)  # noqa: RUF031 Parenthesis needed, false positive
-                ],
+                target_graph.edges[(source, target)],  # noqa
                 edge_data,
                 edge_ops,
             )
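
Ruff's RUF031 rule flags tuples wrapped in parentheses inside a subscript, which is why the original multi-line edge lookup carried the "false positive" comment; the commit collapses the lookup back to one line and suppresses the rule with a bare # noqa. A minimal sketch of the pattern, using networkx directly (the graph and attribute names here are illustrative, not taken from the original module):

import networkx as nx

g = nx.Graph()
g.add_edge("a", "b", weight=1)

# Indexing the edge view with an explicit (source, target) tuple is what
# RUF031 reports as redundant parentheses; the project treats it as a
# false positive and suppresses the rule, as in the diff above.
edge_attrs = g.edges[("a", "b")]  # noqa: RUF031
edge_attrs["weight"] += 1
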
4 changes: 2 additions & 2 deletions graphrag/query/api.py
@@ -140,7 +140,7 @@ async def global_search_streaming(
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
-            context_data = _reformat_context_data(stream_chunk)
+            context_data = _reformat_context_data(stream_chunk)  # type: ignore
             yield context_data
             get_context_data = False
         else:
@@ -301,7 +301,7 @@ async def local_search_streaming(
     get_context_data = True
     async for stream_chunk in search_result:
         if get_context_data:
-            context_data = _reformat_context_data(stream_chunk)
+            context_data = _reformat_context_data(stream_chunk)  # type: ignore
             yield context_data
             get_context_data = False
         else:
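
Both streaming endpoints yield the reformatted context data as the first chunk and plain response-text deltas afterwards, so the stream's element type is a union and pyright cannot prove that the first chunk has the mapping shape _reformat_context_data expects; the # type: ignore keeps the runtime behavior and silences that report. A minimal sketch of the same pattern, with an illustrative _reformat helper and stream (the names and types below are stand-ins, not the project's actual signatures):

from collections.abc import AsyncIterator
from typing import Any


def _reformat(context: dict[str, Any]) -> dict[str, Any]:
    """Stand-in for _reformat_context_data: expects the dict-shaped first chunk."""
    return dict(context)


async def _stream() -> AsyncIterator[dict[str, Any] | str]:
    yield {"reports": []}   # first chunk: context data
    yield "partial answer"  # later chunks: response text


async def consume() -> None:
    get_context_data = True
    async for stream_chunk in _stream():
        if get_context_data:
            # The first chunk is a dict at runtime, but the union element
            # type makes the checker flag this call, hence `# type: ignore`.
            context_data = _reformat(stream_chunk)  # type: ignore
            get_context_data = False
            print(context_data)
        else:
            print(stream_chunk)
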
5 changes: 3 additions & 2 deletions graphrag/query/factories.py
@@ -21,6 +21,7 @@
 from graphrag.query.llm.oai.chat_openai import ChatOpenAI
 from graphrag.query.llm.oai.embedding import OpenAIEmbedding
 from graphrag.query.llm.oai.typing import OpenaiApiType
+from graphrag.query.structured_search.base import BaseSearch
 from graphrag.query.structured_search.global_search.community_context import (
     GlobalCommunityContext,
 )
@@ -108,7 +109,7 @@ def get_local_search_engine(
     covariates: dict[str, list[Covariate]],
     response_type: str,
     description_embedding_store: BaseVectorStore,
-) -> LocalSearch:
+) -> BaseSearch:
     """Create a local search engine based on data + configuration."""
     llm = get_llm(config)
     text_embedder = get_text_embedder(config)
@@ -159,7 +160,7 @@ def get_global_search_engine(
     reports: list[CommunityReport],
     entities: list[Entity],
     response_type: str,
-):
+) -> BaseSearch:
     """Create a global search engine based on data + configuration."""
     token_encoder = tiktoken.get_encoding(config.encoding_model)
     gs_config = config.global_search
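
With both factories annotated as returning BaseSearch, callers can be written once against the shared search interface and accept either engine. A hedged usage sketch, assuming BaseSearch exposes an asearch coroutine whose result carries a response attribute (as in graphrag's structured_search package at this point); the helper name is illustrative:

from graphrag.query.structured_search.base import BaseSearch


async def answer(engine: BaseSearch, query: str) -> str:
    # Works the same for an engine built by get_local_search_engine or
    # get_global_search_engine, since both are now typed as BaseSearch.
    result = await engine.asearch(query)
    return str(result.response)
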
