Improved JSON parsing in global search
- Added a system prompt instruction to answer in JSON format
jaigouk authored and tachkoma2501 committed Jul 16, 2024
1 parent 0aa1c31 commit bfe8d34
Showing 3 changed files with 257 additions and 8 deletions.
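
The gist of the parsing change below is extracting a JSON object that may be embedded in surrounding prose before decoding it. A minimal standalone sketch of that idea (illustrative only, not the module code; the sample reply string is invented):

import json
import re

# Invented example of an LLM reply that wraps the JSON payload in extra prose.
raw = 'Sure, here are the key points:\n{"points": [{"description": "Point 1", "score": 90}]}\nHope this helps!'

# re.DOTALL lets "." span newlines; the greedy match runs from the first "{" to the last "}".
match = re.search(r"\{.*\}", raw, re.DOTALL)
payload = json.loads(match.group()) if match else json.loads(raw)
print(payload["points"][0]["description"])  # -> Point 1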
45 changes: 37 additions & 8 deletions graphrag/query/structured_search/global_search/search.py
@@ -6,6 +6,7 @@
import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass
from typing import Any
@@ -92,6 +93,7 @@ def __init__(

        self.map_llm_params = map_llm_params
        self.reduce_llm_params = reduce_llm_params
        self.json_mode = json_mode
        if json_mode:
            self.map_llm_params["response_format"] = {"type": "json_object"}
        else:
@@ -174,6 +176,9 @@ async def _map_response_single_batch(
search_prompt = ""
try:
search_prompt = self.map_system_prompt.format(context_data=context_data)
if self.json_mode:
search_prompt += "\nYour response should be in JSON format."
self._last_search_prompt = search_prompt
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
@@ -228,15 +233,37 @@ def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
        -------
        list[dict[str, Any]]
            A list of key points, each key point is a dictionary with "answer" and "score" keys

        Raises
        ------
        ValueError
            If the response cannot be parsed as JSON or doesn't contain the expected structure
        """
        parsed_elements = json.loads(search_response)["points"]
        return [
            {
                "answer": element["description"],
                "score": int(element["score"]),
            }
            for element in parsed_elements
        ]
        # Try to extract JSON from the response if it's embedded in text
        json_match = re.search(r"\{.*\}", search_response, re.DOTALL)
        json_str = json_match.group() if json_match else search_response

        try:
            parsed_data = json.loads(json_str)
        except json.JSONDecodeError as err:
            error_msg = "Failed to parse response as JSON"
            raise ValueError(error_msg) from err

        if "points" not in parsed_data or not isinstance(parsed_data["points"], list):
            error_msg = "Response JSON does not contain a 'points' list"
            raise ValueError(error_msg)

        try:
            return [
                {
                    "answer": element["description"],
                    "score": int(element["score"]),
                }
                for element in parsed_data["points"]
            ]
        except (KeyError, ValueError) as e:
            error_msg = f"Error processing points: {e!s}"
            raise ValueError(error_msg) from e

    async def _reduce_response(
        self,
@@ -316,6 +343,8 @@ async def _reduce_response(
            )
            if self.allow_general_knowledge:
                search_prompt += "\n" + self.general_knowledge_inclusion_prompt
            if self.json_mode:
                search_prompt += "\nYour response should be in JSON format."
            search_messages = [
                {"role": "system", "content": search_prompt},
                {"role": "user", "content": query},
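With the changes above, parse_search_response tolerates replies where the JSON object is wrapped in extra text and raises ValueError with a specific message otherwise. A hypothetical usage sketch (the llm and context_builder instances are placeholders configured elsewhere, not part of this commit):

from graphrag.query.structured_search.global_search.search import GlobalSearch

# llm and context_builder are assumed, pre-configured instances.
engine = GlobalSearch(llm=llm, context_builder=context_builder)

reply = 'Here is the result:\n{"points": [{"description": "Key finding", "score": 85}]}'
print(engine.parse_search_response(reply))
# -> [{"answer": "Key finding", "score": 85}]

# A reply without a "points" list now raises a descriptive ValueError instead of a bare KeyError:
# engine.parse_search_response('{"data": "no points"}')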
2 changes: 2 additions & 0 deletions tests/unit/query/structured_search/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
218 changes: 218 additions & 0 deletions tests/unit/query/structured_search/test_search.py
@@ -0,0 +1,218 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from unittest.mock import AsyncMock, MagicMock

import pytest

from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.llm.base import BaseLLM
from graphrag.query.structured_search.base import SearchResult
from graphrag.query.structured_search.global_search.search import GlobalSearch


class MockLLM(BaseLLM):
    def __init__(self):
        self.call_count = 0
        self.last_messages = None

    def generate(
        self, messages, streaming=False, callbacks=None, max_tokens=None, **kwargs
    ):
        self.call_count += 1
        self.last_messages = messages
        return "mocked response"

    async def agenerate(
        self, messages, streaming=False, callbacks=None, max_tokens=None, **kwargs
    ):
        self.call_count += 1
        self.last_messages = messages
        return "mocked response"


class MockContextBuilder(GlobalContextBuilder):
    def build_context(self, conversation_history=None, **kwargs):
        return ["mocked context"], {}


class TestGlobalSearch(GlobalSearch):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._last_search_prompt = None

    def get_last_search_prompt(self):
        return self._last_search_prompt

    def set_last_search_prompt(self, value):
        self._last_search_prompt = value


@pytest.fixture
def global_search():
    llm = MockLLM()
    context_builder = MockContextBuilder()
    return TestGlobalSearch(llm, context_builder)


@pytest.mark.asyncio
async def test_json_format_instruction_in_search_prompt(global_search):
    global_search.json_mode = True
    query = "Test query"

    global_search.context_builder.build_context = MagicMock(
        return_value=(["mocked context"], {})
    )

    global_search.parse_search_response = MagicMock(
        return_value=[{"answer": "Mocked answer", "score": 100}]
    )

    await global_search.asearch(query)

    assert global_search.get_last_search_prompt() is not None
    assert (
        "Your response should be in JSON format."
        in global_search.get_last_search_prompt()
    )

    global_search.json_mode = False
    global_search.set_last_search_prompt(None)
    await global_search.asearch(query)

    assert global_search.get_last_search_prompt() is not None
    assert (
        "Your response should be in JSON format."
        not in global_search.get_last_search_prompt()
    )


def test_parse_search_response_valid(global_search):
    valid_response = """
    {
        "points": [
            {"description": "Point 1", "score": 90},
            {"description": "Point 2", "score": 80}
        ]
    }
    """
    result = global_search.parse_search_response(valid_response)
    assert len(result) == 2
    assert result[0] == {"answer": "Point 1", "score": 90}
    assert result[1] == {"answer": "Point 2", "score": 80}


def test_parse_search_response_invalid_json(global_search):
    invalid_json = "This is not JSON"
    with pytest.raises(ValueError, match="Failed to parse response as JSON"):
        global_search.parse_search_response(invalid_json)


def test_parse_search_response_missing_points(global_search):
    missing_points = '{"data": "No points here"}'
    with pytest.raises(
        ValueError, match="Response JSON does not contain a 'points' list"
    ):
        global_search.parse_search_response(missing_points)


def test_parse_search_response_invalid_point_format(global_search):
    invalid_point = """
    {
        "points": [
            {"wrong_key": "Point 1", "score": 90}
        ]
    }
    """
    with pytest.raises(ValueError, match="Error processing points"):
        global_search.parse_search_response(invalid_point)


def test_parse_search_response_with_text_prefix(global_search):
    response_with_prefix = """
    Here's the response:
    {
        "points": [
            {"description": "Point 1", "score": 90}
        ]
    }
    """
    result = global_search.parse_search_response(response_with_prefix)
    assert len(result) == 1
    assert result[0] == {"answer": "Point 1", "score": 90}


def test_parse_search_response_non_integer_score(global_search):
    non_integer_score = """
    {
        "points": [
            {"description": "Point 1", "score": "high"}
        ]
    }
    """
    with pytest.raises(ValueError, match="Error processing points"):
        global_search.parse_search_response(non_integer_score)


@pytest.mark.asyncio
async def test_map_response(global_search):
    global_search.map_response = AsyncMock(
        return_value=[
            SearchResult(
                response=[{"answer": "Test answer", "score": 90}],
                context_data="Test context",
                context_text="Test context",
                completion_time=0.1,
                llm_calls=1,
                prompt_tokens=10,
            )
        ]
    )

    context_data = ["Test context"]
    query = "Test query"
    result = await global_search.map_response(context_data, query)

    assert isinstance(result[0], SearchResult)
    assert result[0].context_data == "Test context"
    assert result[0].context_text == "Test context"
    assert result[0].llm_calls == 1


@pytest.mark.asyncio
async def test_reduce_response(global_search):
    global_search.reduce_response = AsyncMock(
        return_value=SearchResult(
            response="Final answer",
            context_data="Combined context",
            context_text="Combined context",
            completion_time=0.2,
            llm_calls=1,
            prompt_tokens=20,
        )
    )

    map_responses = [
        SearchResult(
            response=[{"answer": "Point 1", "score": 90}],
            context_data="",
            context_text="",
            completion_time=0,
            llm_calls=1,
            prompt_tokens=0,
        ),
        SearchResult(
            response=[{"answer": "Point 2", "score": 80}],
            context_data="",
            context_text="",
            completion_time=0,
            llm_calls=1,
            prompt_tokens=0,
        ),
    ]
    query = "Test query"
    result = await global_search.reduce_response(map_responses, query)

    assert isinstance(result, SearchResult)
    assert result.llm_calls == 1
    assert result.response == "Final answer"

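The tests above use @pytest.mark.asyncio, so they need the pytest-asyncio plugin. One way to run just this module from Python (a sketch; invoking pytest from the command line works equally well):

import pytest

# Run only the new unit tests, verbosely; assumes pytest and pytest-asyncio are installed.
pytest.main(["-v", "tests/unit/query/structured_search/test_search.py"])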