added system prompt to answer in JSON format
1 parent 7869857 · commit 1e42f57
Showing 3 changed files with 169 additions and 51 deletions.
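
For context on the change itself: the new test below asserts that, when json_mode is enabled, the rendered search prompt contains the exact sentence "Your response should be in JSON format." The following is a minimal sketch of that behavior inferred from the tests in this diff, not the actual graphrag implementation; the helper name build_search_prompt and its signature are assumptions.

# Hypothetical sketch (not graphrag's real code): append the JSON-format
# instruction that the new test asserts on whenever json_mode is enabled.
JSON_FORMAT_INSTRUCTION = "Your response should be in JSON format."


def build_search_prompt(system_prompt: str, context_data: str, json_mode: bool) -> str:
    """Render the map-step system prompt, optionally requesting JSON output."""
    prompt = system_prompt.format(context_data=context_data)
    if json_mode:
        prompt += "\n" + JSON_FORMAT_INSTRUCTION
    return prompt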
@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
@@ -1,115 +1,218 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

 from unittest.mock import AsyncMock, MagicMock

 import pytest
-from graphrag.query.structured_search.global_search.search import GlobalSearch, GlobalSearchResult
-from graphrag.query.structured_search.base import SearchResult
-from graphrag.query.llm.base import BaseLLM
+from graphrag.query.context_builder.builders import GlobalContextBuilder
+from graphrag.query.llm.base import BaseLLM
+from graphrag.query.structured_search.base import SearchResult
+from graphrag.query.structured_search.global_search.search import GlobalSearch


 class MockLLM(BaseLLM):
     def __init__(self):
         self.call_count = 0
         self.last_messages = None

-    def generate(self, messages, streaming=False, **kwargs):
+    def generate(
+        self, messages, streaming=False, callbacks=None, max_tokens=None, **kwargs
+    ):
         self.call_count += 1
         self.last_messages = messages
         return "mocked response"

-    async def agenerate(self, messages, streaming=False, **kwargs):
+    async def agenerate(
+        self, messages, streaming=False, callbacks=None, max_tokens=None, **kwargs
+    ):
         self.call_count += 1
         self.last_messages = messages
         return "mocked response"


 class MockContextBuilder(GlobalContextBuilder):
     def build_context(self, conversation_history=None, **kwargs):
         return ["mocked context"], {}


+class TestGlobalSearch(GlobalSearch):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._last_search_prompt = None
+
+    def get_last_search_prompt(self):
+        return self._last_search_prompt
+
+    def set_last_search_prompt(self, value):
+        self._last_search_prompt = value
+
+
 @pytest.fixture
 def global_search():
     llm = MockLLM()
     context_builder = MockContextBuilder()
-    return GlobalSearch(llm, context_builder)
+    return TestGlobalSearch(llm, context_builder)


+@pytest.mark.asyncio
+async def test_json_format_instruction_in_search_prompt(global_search):
+    global_search.json_mode = True
+    query = "Test query"
+
+    global_search.context_builder.build_context = MagicMock(
+        return_value=(["mocked context"], {})
+    )
+
+    global_search.parse_search_response = MagicMock(
+        return_value=[{"answer": "Mocked answer", "score": 100}]
+    )
+
+    await global_search.asearch(query)
+
+    assert global_search.get_last_search_prompt() is not None
+    assert (
+        "Your response should be in JSON format."
+        in global_search.get_last_search_prompt()
+    )
+
+    global_search.json_mode = False
+    global_search.set_last_search_prompt(None)
+    await global_search.asearch(query)
+
+    assert global_search.get_last_search_prompt() is not None
+    assert (
+        "Your response should be in JSON format."
+        not in global_search.get_last_search_prompt()
+    )
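
A quick way to sanity-check the hypothetical build_search_prompt sketch from the top of this page against the assertions in this test (same assumed names):

# Mirrors test_json_format_instruction_in_search_prompt using the sketch above.
prompt_on = build_search_prompt("Context: {context_data}", "mocked context", True)
assert "Your response should be in JSON format." in prompt_on

prompt_off = build_search_prompt("Context: {context_data}", "mocked context", False)
assert "Your response should be in JSON format." not in prompt_off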


 def test_parse_search_response_valid(global_search):
-    valid_response = '''
+    valid_response = """
     {
         "points": [
             {"description": "Point 1", "score": 90},
             {"description": "Point 2", "score": 80}
         ]
     }
-    '''
+    """
     result = global_search.parse_search_response(valid_response)
     assert len(result) == 2
     assert result[0] == {"answer": "Point 1", "score": 90}
     assert result[1] == {"answer": "Point 2", "score": 80}


 def test_parse_search_response_invalid_json(global_search):
     invalid_json = "This is not JSON"
     with pytest.raises(ValueError, match="Failed to parse response as JSON"):
         global_search.parse_search_response(invalid_json)


 def test_parse_search_response_missing_points(global_search):
     missing_points = '{"data": "No points here"}'
-    with pytest.raises(ValueError, match="Response JSON does not contain a 'points' list"):
+    with pytest.raises(
+        ValueError, match="Response JSON does not contain a 'points' list"
+    ):
         global_search.parse_search_response(missing_points)


 def test_parse_search_response_invalid_point_format(global_search):
-    invalid_point = '''
+    invalid_point = """
     {
         "points": [
             {"wrong_key": "Point 1", "score": 90}
         ]
     }
-    '''
+    """
     with pytest.raises(ValueError, match="Error processing points"):
         global_search.parse_search_response(invalid_point)


 def test_parse_search_response_with_text_prefix(global_search):
-    response_with_prefix = '''
+    response_with_prefix = """
     Here's the response:
     {
         "points": [
             {"description": "Point 1", "score": 90}
         ]
     }
-    '''
+    """
     result = global_search.parse_search_response(response_with_prefix)
     assert len(result) == 1
     assert result[0] == {"answer": "Point 1", "score": 90}


 def test_parse_search_response_non_integer_score(global_search):
-    non_integer_score = '''
+    non_integer_score = """
     {
         "points": [
             {"description": "Point 1", "score": "high"}
         ]
     }
-    '''
+    """
     with pytest.raises(ValueError, match="Error processing points"):
         global_search.parse_search_response(non_integer_score)
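
Taken together, the six parse tests pin down a clear contract: tolerate a text prefix before the JSON, require a top-level "points" list, rename description to answer, and coerce score to int, raising ValueError with a specific message at each failure point. A minimal parser consistent with those tests might look like the sketch below; the parser actually shipped in graphrag may differ in detail.

import json
import re


def parse_search_response(search_response: str) -> list[dict]:
    """Sketch of the parser described by the tests above (assumed, not verbatim)."""
    # Tolerate a text prefix such as "Here's the response:" before the JSON.
    match = re.search(r"\{.*\}", search_response, re.DOTALL)
    if match is None:
        raise ValueError("Failed to parse response as JSON")
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError as err:
        raise ValueError("Failed to parse response as JSON") from err
    if not isinstance(parsed.get("points"), list):
        raise ValueError("Response JSON does not contain a 'points' list")
    try:
        # Normalize each point to the {"answer": ..., "score": int} shape
        # that the valid-response test expects.
        return [
            {"answer": point["description"], "score": int(point["score"])}
            for point in parsed["points"]
        ]
    except (KeyError, TypeError, ValueError) as err:
        raise ValueError("Error processing points") from err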


 @pytest.mark.asyncio
-async def test_map_response_single_batch(global_search):
-    context_data = "Test context"
+async def test_map_response(global_search):
+    global_search.map_response = AsyncMock(
+        return_value=[
+            SearchResult(
+                response=[{"answer": "Test answer", "score": 90}],
+                context_data="Test context",
+                context_text="Test context",
+                completion_time=0.1,
+                llm_calls=1,
+                prompt_tokens=10,
+            )
+        ]
+    )
+
+    context_data = ["Test context"]
     query = "Test query"
-    result = await global_search._map_response_single_batch(context_data, query)
-    assert isinstance(result, SearchResult)
-    assert result.context_data == context_data
-    assert result.context_text == context_data
-    assert result.llm_calls == 1
+    result = await global_search.map_response(context_data, query)
+
+    assert isinstance(result[0], SearchResult)
+    assert result[0].context_data == "Test context"
+    assert result[0].context_text == "Test context"
+    assert result[0].llm_calls == 1


 @pytest.mark.asyncio
 async def test_reduce_response(global_search):
+    global_search.reduce_response = AsyncMock(
+        return_value=SearchResult(
+            response="Final answer",
+            context_data="Combined context",
+            context_text="Combined context",
+            completion_time=0.2,
+            llm_calls=1,
+            prompt_tokens=20,
+        )
+    )
+
     map_responses = [
-        SearchResult(response=[{"answer": "Point 1", "score": 90}], context_data="", context_text="", completion_time=0, llm_calls=1, prompt_tokens=0),
-        SearchResult(response=[{"answer": "Point 2", "score": 80}], context_data="", context_text="", completion_time=0, llm_calls=1, prompt_tokens=0),
+        SearchResult(
+            response=[{"answer": "Point 1", "score": 90}],
+            context_data="",
+            context_text="",
+            completion_time=0,
+            llm_calls=1,
+            prompt_tokens=0,
+        ),
+        SearchResult(
+            response=[{"answer": "Point 2", "score": 80}],
+            context_data="",
+            context_text="",
+            completion_time=0,
+            llm_calls=1,
+            prompt_tokens=0,
+        ),
     ]
     query = "Test query"
-    result = await global_search._reduce_response(map_responses, query)
+    result = await global_search.reduce_response(map_responses, query)
+
     assert isinstance(result, SearchResult)
     assert result.llm_calls == 1
-
-
-@pytest.mark.asyncio
-async def test_asearch(global_search):
-    query = "Test query"
-    result = await global_search.asearch(query)
-    assert isinstance(result, GlobalSearchResult)
-    assert result.llm_calls > 0
-    assert global_search.llm.call_count > 0  # Access the mock LLM through the GlobalSearch instance
+    assert result.response == "Final answer"
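
One caveat worth noting about the two tests above: they replace map_response and reduce_response with AsyncMock before awaiting them, so they verify the mocking plumbing rather than the real map/reduce logic. The pattern they rely on is just this:

import asyncio
from unittest.mock import AsyncMock


async def main() -> None:
    # An AsyncMock returns its return_value when awaited, which is why the
    # mocked reduce_response yields "Final answer" without any LLM call.
    reduce_response = AsyncMock(return_value="Final answer")
    result = await reduce_response([], "Test query")
    assert result == "Final answer"


asyncio.run(main())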