diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py
index 3b52ecbd8c..dec0584781 100644
--- a/graphrag/query/structured_search/global_search/search.py
+++ b/graphrag/query/structured_search/global_search/search.py
@@ -6,6 +6,7 @@
 import asyncio
 import json
 import logging
+import re
 import time
 from dataclasses import dataclass
 from typing import Any
@@ -92,6 +93,7 @@ def __init__(
 
         self.map_llm_params = map_llm_params
         self.reduce_llm_params = reduce_llm_params
+        self.json_mode = json_mode
         if json_mode:
             self.map_llm_params["response_format"] = {"type": "json_object"}
         else:
@@ -174,6 +176,9 @@ async def _map_response_single_batch(
         search_prompt = ""
         try:
             search_prompt = self.map_system_prompt.format(context_data=context_data)
+            if self.json_mode:
+                search_prompt += "\nYour response should be in JSON format."
+            self._last_search_prompt = search_prompt
             search_messages = [
                 {"role": "system", "content": search_prompt},
                 {"role": "user", "content": query},
@@ -228,15 +233,43 @@ def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
         -------
         list[dict[str, Any]]
             A list of key points, each key point is a dictionary with "answer" and "score" keys
+
+        Raises
+        ------
+        ValueError
+            If the response cannot be parsed as JSON or doesn't contain the expected structure
         """
-        parsed_elements = json.loads(search_response)["points"]
-        return [
-            {
-                "answer": element["description"],
-                "score": int(element["score"]),
-            }
-            for element in parsed_elements
-        ]
+        # Try to extract JSON from the response if it's embedded in text.
+        # NOTE: the pattern is greedy, so it spans from the first "{" to the
+        # last "}" in the response.
+        json_match = re.search(r"\{.*\}", search_response, re.DOTALL)
+        json_str = json_match.group() if json_match else search_response
+
+        try:
+            parsed_data = json.loads(json_str)
+        except json.JSONDecodeError as err:
+            error_msg = "Failed to parse response as JSON"
+            raise ValueError(error_msg) from err
+
+        # Valid JSON need not be an object (e.g. "null" or a bare number), so
+        # check the container type before looking up "points".
+        points = parsed_data.get("points") if isinstance(parsed_data, dict) else None
+        if not isinstance(points, list):
+            error_msg = "Response JSON does not contain a 'points' list"
+            raise ValueError(error_msg)
+
+        try:
+            return [
+                {
+                    "answer": element["description"],
+                    "score": int(element["score"]),
+                }
+                for element in points
+            ]
+        except (KeyError, TypeError, ValueError) as e:
+            # TypeError covers non-dict points and null scores (int(None)).
+            error_msg = f"Error processing points: {e!s}"
+            raise ValueError(error_msg) from e
 
     async def _reduce_response(
         self,
@@ -316,6 +349,8 @@ async def _reduce_response(
             )
             if self.allow_general_knowledge:
                 search_prompt += "\n" + self.general_knowledge_inclusion_prompt
+            if self.json_mode:
+                search_prompt += "\nYour response should be in JSON format."
             search_messages = [
                 {"role": "system", "content": search_prompt},
                 {"role": "user", "content": query},
diff --git a/tests/unit/query/structured_search/__init__.py b/tests/unit/query/structured_search/__init__.py
new file mode 100644
index 0000000000..0a3e38adfb
--- /dev/null
+++ b/tests/unit/query/structured_search/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
diff --git a/tests/unit/query/structured_search/test_search.py b/tests/unit/query/structured_search/test_search.py
new file mode 100644
index 0000000000..72930dff19
--- /dev/null
+++ b/tests/unit/query/structured_search/test_search.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from graphrag.query.context_builder.builders import GlobalContextBuilder
+from graphrag.query.llm.base import BaseLLM
+from graphrag.query.structured_search.base import SearchResult
+from graphrag.query.structured_search.global_search.search import GlobalSearch
+
+
+class MockLLM(BaseLLM):
+    def __init__(self):
+        self.call_count = 0
+        self.last_messages = None
+
+    def generate(
+        self, messages, streaming=False, callbacks=None, max_tokens=None, **kwargs
+    ):
+        self.call_count += 1
+        self.last_messages = messages
+        return "mocked response"
+
+    async def agenerate(
+        self, messages, streaming=False, callbacks=None, max_tokens=None, **kwargs
+    ):
+        self.call_count += 1
+        self.last_messages = messages
+        return "mocked response"
+
+
+class MockContextBuilder(GlobalContextBuilder):
+    def build_context(self, conversation_history=None, **kwargs):
+        return ["mocked context"], {}
+
+
+class TestGlobalSearch(GlobalSearch):
+    __test__ = False  # helper subclass, not a pytest test class
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._last_search_prompt = None
+
+    def get_last_search_prompt(self):
+        return self._last_search_prompt
+
+    def set_last_search_prompt(self, value):
+        self._last_search_prompt = value
+
+
+@pytest.fixture
+def global_search():
+    llm = MockLLM()
+    context_builder = MockContextBuilder()
+    return TestGlobalSearch(llm, context_builder)
+
+
+@pytest.mark.asyncio
+async def test_json_format_instruction_in_search_prompt(global_search):
+    global_search.json_mode = True
+    query = "Test query"
+
+    global_search.context_builder.build_context = MagicMock(
+        return_value=(["mocked context"], {})
+    )
+
+    global_search.parse_search_response = MagicMock(
+        return_value=[{"answer": "Mocked answer", "score": 100}]
+    )
+
+    await global_search.asearch(query)
+
+    assert global_search.get_last_search_prompt() is not None
+    assert (
+        "Your response should be in JSON format."
+        in global_search.get_last_search_prompt()
+    )
+
+    global_search.json_mode = False
+    global_search.set_last_search_prompt(None)
+    await global_search.asearch(query)
+
+    assert global_search.get_last_search_prompt() is not None
+    assert (
+        "Your response should be in JSON format."
+        not in global_search.get_last_search_prompt()
+    )
+
+
+def test_parse_search_response_valid(global_search):
+    valid_response = """
+    {
+        "points": [
+            {"description": "Point 1", "score": 90},
+            {"description": "Point 2", "score": 80}
+        ]
+    }
+    """
+    result = global_search.parse_search_response(valid_response)
+    assert len(result) == 2
+    assert result[0] == {"answer": "Point 1", "score": 90}
+    assert result[1] == {"answer": "Point 2", "score": 80}
+
+
+def test_parse_search_response_invalid_json(global_search):
+    invalid_json = "This is not JSON"
+    with pytest.raises(ValueError, match="Failed to parse response as JSON"):
+        global_search.parse_search_response(invalid_json)
+
+
+def test_parse_search_response_missing_points(global_search):
+    missing_points = '{"data": "No points here"}'
+    with pytest.raises(
+        ValueError, match="Response JSON does not contain a 'points' list"
+    ):
+        global_search.parse_search_response(missing_points)
+
+
+def test_parse_search_response_invalid_point_format(global_search):
+    invalid_point = """
+    {
+        "points": [
+            {"wrong_key": "Point 1", "score": 90}
+        ]
+    }
+    """
+    with pytest.raises(ValueError, match="Error processing points"):
+        global_search.parse_search_response(invalid_point)
+
+
+def test_parse_search_response_with_text_prefix(global_search):
+    response_with_prefix = """
+    Here's the response:
+    {
+        "points": [
+            {"description": "Point 1", "score": 90}
+        ]
+    }
+    """
+    result = global_search.parse_search_response(response_with_prefix)
+    assert len(result) == 1
+    assert result[0] == {"answer": "Point 1", "score": 90}
+
+
+def test_parse_search_response_non_integer_score(global_search):
+    non_integer_score = """
+    {
+        "points": [
+            {"description": "Point 1", "score": "high"}
+        ]
+    }
+    """
+    with pytest.raises(ValueError, match="Error processing points"):
+        global_search.parse_search_response(non_integer_score)
+
+
+def test_parse_search_response_null_score(global_search):
+    # int(None) raises TypeError, which must surface as ValueError too.
+    null_score = '{"points": [{"description": "Point 1", "score": null}]}'
+    with pytest.raises(ValueError, match="Error processing points"):
+        global_search.parse_search_response(null_score)
+
+
+def test_parse_search_response_non_object_json(global_search):
+    # Valid JSON that is not an object must be rejected, not crash.
+    with pytest.raises(
+        ValueError, match="Response JSON does not contain a 'points' list"
+    ):
+        global_search.parse_search_response("null")
+
+
+@pytest.mark.asyncio
+async def test_map_response(global_search):
+    global_search.map_response = AsyncMock(
+        return_value=[
+            SearchResult(
+                response=[{"answer": "Test answer", "score": 90}],
+                context_data="Test context",
+                context_text="Test context",
+                completion_time=0.1,
+                llm_calls=1,
+                prompt_tokens=10,
+            )
+        ]
+    )
+
+    context_data = ["Test context"]
+    query = "Test query"
+    result = await global_search.map_response(context_data, query)
+
+    assert isinstance(result[0], SearchResult)
+    assert result[0].context_data == "Test context"
+    assert result[0].context_text == "Test context"
+    assert result[0].llm_calls == 1
+
+
+@pytest.mark.asyncio
+async def test_reduce_response(global_search):
+    global_search.reduce_response = AsyncMock(
+        return_value=SearchResult(
+            response="Final answer",
+            context_data="Combined context",
+            context_text="Combined context",
+            completion_time=0.2,
+            llm_calls=1,
+            prompt_tokens=20,
+        )
+    )
+
+    map_responses = [
+        SearchResult(
+            response=[{"answer": "Point 1", "score": 90}],
+            context_data="",
+            context_text="",
+            completion_time=0,
+            llm_calls=1,
+            prompt_tokens=0,
+        ),
+        SearchResult(
+            response=[{"answer": "Point 2", "score": 80}],
+            context_data="",
+            context_text="",
+            completion_time=0,
+            llm_calls=1,
+            prompt_tokens=0,
+        ),
+    ]
+    query = "Test query"
+    result = await global_search.reduce_response(map_responses, query)
+
+    assert isinstance(result, SearchResult)
+    assert result.llm_calls == 1
+    assert result.response == "Final answer"