diff --git a/.semversioner/next-release/patch-20241126205530514149.json b/.semversioner/next-release/patch-20241126205530514149.json
new file mode 100644
index 000000000..23198b082
--- /dev/null
+++ b/.semversioner/next-release/patch-20241126205530514149.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix Global Search with dynamic Community selection bug"
+}
diff --git a/graphrag/query/context_builder/dynamic_community_selection.py b/graphrag/query/context_builder/dynamic_community_selection.py
index a17cfdc27..6680cfe6d 100644
--- a/graphrag/query/context_builder/dynamic_community_selection.py
+++ b/graphrag/query/context_builder/dynamic_community_selection.py
@@ -73,11 +73,12 @@ def __init__(
         }
         # mapping from level to communities
         self.levels: dict[str, list[str]] = {}
+
         for community in communities:
             if community.level not in self.levels:
                 self.levels[community.level] = []
-            if community.id in self.reports:
-                self.levels[community.level].append(community.id)
+            if community.short_id in self.reports:
+                self.levels[community.level].append(community.short_id)
 
         # start from root communities (level 0)
         self.starting_communities = self.levels["0"]
@@ -100,6 +101,7 @@ async def select(self, query: str) -> tuple[list[CommunityReport], dict[str, Any
             "output_tokens": 0,
         }
         relevant_communities = set()
+
         while queue:
             gather_results = await asyncio.gather(*[
                 rate_relevancy(
diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py
index b854569b9..ae69e2b01 100644
--- a/graphrag/query/factories.py
+++ b/graphrag/query/factories.py
@@ -105,12 +105,12 @@ def get_global_search_engine(
 
     dynamic_community_selection_kwargs = {}
     if dynamic_community_selection:
         gs_config = config.global_search
-        _config = deepcopy(config)
-        _config.llm.model = _config.llm.deployment_name = gs_config.dynamic_search_llm
+        # TODO: Allow for another llm definition only for Global Search to leverage -mini models
+
         dynamic_community_selection_kwargs.update({
-            "llm": get_llm(_config),
-            "token_encoder": tiktoken.encoding_for_model(gs_config.dynamic_search_llm),
+            "llm": get_llm(config),
+            "token_encoder": tiktoken.encoding_for_model(config.llm.model),
             "keep_parent": gs_config.dynamic_search_keep_parent,
             "num_repeats": gs_config.dynamic_search_num_repeats,
             "use_summary": gs_config.dynamic_search_use_summary,