added system prompt to answer in JSON format
jaigouk authored and tachkoma2501 committed Jul 16, 2024
1 parent 7869857 commit 1e42f57
Showing 3 changed files with 169 additions and 51 deletions.
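For orientation, here is a minimal sketch of the behavior this commit adds (names are taken from the diff below; the snippet is illustrative, not the shipped code): when json_mode is enabled, GlobalSearch requests a JSON object from the map LLM at the API level and appends an explicit JSON instruction to the map and reduce system prompts.

    # Sketch, assuming json_mode and map_llm_params are passed to __init__ as in the diff.
    map_llm_params: dict = {}
    json_mode = True

    if json_mode:
        # Ask the model for a JSON object at the API level...
        map_llm_params["response_format"] = {"type": "json_object"}
    else:
        # ...or drop any stale response_format key when JSON mode is off.
        map_llm_params.pop("response_format", None)

    search_prompt = "...formatted system prompt..."
    if json_mode:
        # ...and reinforce the requirement in the prompt text itself.
        search_prompt += "\nYour response should be in JSON format."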
49 changes: 31 additions & 18 deletions graphrag/query/structured_search/global_search/search.py
@@ -6,8 +6,8 @@
import asyncio
import json
import logging
import time
import re
import time
from dataclasses import dataclass
from typing import Any

@@ -16,12 +16,18 @@

from graphrag.index.utils.json import clean_up_json
from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.context_builder.conversation_history import ConversationHistory
from graphrag.query.context_builder.conversation_history import (
ConversationHistory,
)
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.global_search.callbacks import GlobalSearchLLMCallback
from graphrag.query.structured_search.global_search.map_system_prompt import MAP_SYSTEM_PROMPT
from graphrag.query.structured_search.global_search.callbacks import (
GlobalSearchLLMCallback,
)
from graphrag.query.structured_search.global_search.map_system_prompt import (
MAP_SYSTEM_PROMPT,
)
from graphrag.query.structured_search.global_search.reduce_system_prompt import (
GENERAL_KNOWLEDGE_INSTRUCTION,
NO_DATA_ANSWER,
@@ -40,6 +46,7 @@

log = logging.getLogger(__name__)


@dataclass
class GlobalSearchResult(SearchResult):
"""A GlobalSearch result."""
@@ -48,6 +55,7 @@ class GlobalSearchResult(SearchResult):
reduce_context_data: str | list[pd.DataFrame] | dict[str, pd.DataFrame]
reduce_context_text: str | list[str] | dict[str, str]


class GlobalSearch(BaseSearch):
"""Search orchestration for global search mode."""

@@ -85,9 +93,11 @@ def __init__(

self.map_llm_params = map_llm_params
self.reduce_llm_params = reduce_llm_params
self.json_mode = json_mode
if json_mode:
self.map_llm_params["response_format"] = {"type": "json_object"}
else:
# remove response_format key if json_mode is False
self.map_llm_params.pop("response_format", None)

self.semaphore = asyncio.Semaphore(concurrent_coroutines)
@@ -114,7 +124,7 @@ async def asearch(

if self.callbacks:
for callback in self.callbacks:
callback.on_map_response_start(context_chunks) # type: ignore
callback.on_map_response_start(context_chunks) # type: ignore
map_responses = await asyncio.gather(*[
self._map_response_single_batch(
context_data=data, query=query, **self.map_llm_params
@@ -166,6 +176,9 @@ async def _map_response_single_batch(
search_prompt = ""
try:
search_prompt = self.map_system_prompt.format(context_data=context_data)
if self.json_mode:
search_prompt += "\nYour response should be in JSON format."
self._last_search_prompt = search_prompt
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
@@ -208,7 +221,6 @@ async def _map_response_single_batch(
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
)


def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
"""Parse the search response json and return a list of key points.
@@ -228,19 +240,18 @@ def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
If the response cannot be parsed as JSON or doesn't contain the expected structure
"""
# Try to extract JSON from the response if it's embedded in text
json_match = re.search(r'\{.*\}', search_response, re.DOTALL)
if json_match:
json_str = json_match.group()
else:
json_str = search_response
json_match = re.search(r"\{.*\}", search_response, re.DOTALL)
json_str = json_match.group() if json_match else search_response

try:
parsed_data = json.loads(json_str)
except json.JSONDecodeError:
raise ValueError("Failed to parse response as JSON")
except json.JSONDecodeError as err:
error_msg = "Failed to parse response as JSON"
raise ValueError(error_msg) from err

if 'points' not in parsed_data or not isinstance(parsed_data['points'], list):
raise ValueError("Response JSON does not contain a 'points' list")
if "points" not in parsed_data or not isinstance(parsed_data["points"], list):
error_msg = "Response JSON does not contain a 'points' list"
raise ValueError(error_msg)

try:
return [
@@ -251,8 +262,8 @@ def parse_search_response(self, search_response: str) -> list[dict[str, Any]]:
for element in parsed_data["points"]
]
except (KeyError, ValueError) as e:
raise ValueError(f"Error processing points: {str(e)}")

error_msg = f"Error processing points: {e!s}"
raise ValueError(error_msg) from e

async def _reduce_response(
self,
@@ -302,7 +313,7 @@ async def _reduce_response(
filtered_key_points = sorted(
filtered_key_points,
key=lambda x: x["score"], # type: ignore
reverse=True, # type: ignore
reverse=True, # type: ignore
)

data = []
@@ -332,6 +343,8 @@
)
if self.allow_general_knowledge:
search_prompt += "\n" + self.general_knowledge_inclusion_prompt
if self.json_mode:
search_prompt += "\nYour response should be in JSON format."
search_messages = [
{"role": "system", "content": search_prompt},
{"role": "user", "content": query},
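The reworked parse_search_response tolerates JSON embedded in surrounding prose. A standalone sketch of the same logic, for reference (parse_points is a hypothetical helper mirroring the method; the point-mapping step is reconstructed from the tests below):

    import json
    import re
    from typing import Any

    def parse_points(search_response: str) -> list[dict[str, Any]]:
        # Grab the outermost {...} span in case the model wrapped JSON in text.
        match = re.search(r"\{.*\}", search_response, re.DOTALL)
        json_str = match.group() if match else search_response
        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError as err:
            raise ValueError("Failed to parse response as JSON") from err
        if "points" not in parsed or not isinstance(parsed["points"], list):
            raise ValueError("Response JSON does not contain a 'points' list")
        try:
            return [
                {"answer": element["description"], "score": int(element["score"])}
                for element in parsed["points"]
            ]
        except (KeyError, ValueError) as e:
            raise ValueError(f"Error processing points: {e!s}") from e

    parse_points('Here is the response: {"points": [{"description": "Point 1", "score": 90}]}')
    # -> [{"answer": "Point 1", "score": 90}]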
2 changes: 2 additions & 0 deletions tests/unit/query/structured_search/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
169 changes: 136 additions & 33 deletions tests/unit/query/structured_search/test_search.py
@@ -1,115 +1,218 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from unittest.mock import AsyncMock, MagicMock

import pytest
from graphrag.query.structured_search.global_search.search import GlobalSearch, GlobalSearchResult
from graphrag.query.structured_search.base import SearchResult
from graphrag.query.llm.base import BaseLLM

from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.llm.base import BaseLLM
from graphrag.query.structured_search.base import SearchResult
from graphrag.query.structured_search.global_search.search import GlobalSearch


class MockLLM(BaseLLM):
def __init__(self):
self.call_count = 0
self.last_messages = None

def generate(self, messages, streaming=False, **kwargs):
def generate(
self, messages, streaming=False, callbacks=None, max_tokens=None, **kwargs
):
self.call_count += 1
self.last_messages = messages
return "mocked response"

async def agenerate(self, messages, streaming=False, **kwargs):
async def agenerate(
self, messages, streaming=False, callbacks=None, max_tokens=None, **kwargs
):
self.call_count += 1
self.last_messages = messages
return "mocked response"


class MockContextBuilder(GlobalContextBuilder):
def build_context(self, conversation_history=None, **kwargs):
return ["mocked context"], {}


class TestGlobalSearch(GlobalSearch):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._last_search_prompt = None

def get_last_search_prompt(self):
return self._last_search_prompt

def set_last_search_prompt(self, value):
self._last_search_prompt = value


@pytest.fixture
def global_search():
llm = MockLLM()
context_builder = MockContextBuilder()
return GlobalSearch(llm, context_builder)
return TestGlobalSearch(llm, context_builder)


@pytest.mark.asyncio
async def test_json_format_instruction_in_search_prompt(global_search):
global_search.json_mode = True
query = "Test query"

global_search.context_builder.build_context = MagicMock(
return_value=(["mocked context"], {})
)

global_search.parse_search_response = MagicMock(
return_value=[{"answer": "Mocked answer", "score": 100}]
)

await global_search.asearch(query)

assert global_search.get_last_search_prompt() is not None
assert (
"Your response should be in JSON format."
in global_search.get_last_search_prompt()
)

global_search.json_mode = False
global_search.set_last_search_prompt(None)
await global_search.asearch(query)

assert global_search.get_last_search_prompt() is not None
assert (
"Your response should be in JSON format."
not in global_search.get_last_search_prompt()
)


def test_parse_search_response_valid(global_search):
valid_response = '''
valid_response = """
{
"points": [
{"description": "Point 1", "score": 90},
{"description": "Point 2", "score": 80}
]
}
'''
"""
result = global_search.parse_search_response(valid_response)
assert len(result) == 2
assert result[0] == {"answer": "Point 1", "score": 90}
assert result[1] == {"answer": "Point 2", "score": 80}


def test_parse_search_response_invalid_json(global_search):
invalid_json = "This is not JSON"
with pytest.raises(ValueError, match="Failed to parse response as JSON"):
global_search.parse_search_response(invalid_json)


def test_parse_search_response_missing_points(global_search):
missing_points = '{"data": "No points here"}'
with pytest.raises(ValueError, match="Response JSON does not contain a 'points' list"):
with pytest.raises(
ValueError, match="Response JSON does not contain a 'points' list"
):
global_search.parse_search_response(missing_points)


def test_parse_search_response_invalid_point_format(global_search):
invalid_point = '''
invalid_point = """
{
"points": [
{"wrong_key": "Point 1", "score": 90}
]
}
'''
"""
with pytest.raises(ValueError, match="Error processing points"):
global_search.parse_search_response(invalid_point)


def test_parse_search_response_with_text_prefix(global_search):
response_with_prefix = '''
response_with_prefix = """
Here's the response:
{
"points": [
{"description": "Point 1", "score": 90}
]
}
'''
"""
result = global_search.parse_search_response(response_with_prefix)
assert len(result) == 1
assert result[0] == {"answer": "Point 1", "score": 90}


def test_parse_search_response_non_integer_score(global_search):
non_integer_score = '''
non_integer_score = """
{
"points": [
{"description": "Point 1", "score": "high"}
]
}
'''
"""
with pytest.raises(ValueError, match="Error processing points"):
global_search.parse_search_response(non_integer_score)


@pytest.mark.asyncio
async def test_map_response_single_batch(global_search):
context_data = "Test context"
async def test_map_response(global_search):
global_search.map_response = AsyncMock(
return_value=[
SearchResult(
response=[{"answer": "Test answer", "score": 90}],
context_data="Test context",
context_text="Test context",
completion_time=0.1,
llm_calls=1,
prompt_tokens=10,
)
]
)

context_data = ["Test context"]
query = "Test query"
result = await global_search._map_response_single_batch(context_data, query)
assert isinstance(result, SearchResult)
assert result.context_data == context_data
assert result.context_text == context_data
assert result.llm_calls == 1
result = await global_search.map_response(context_data, query)

assert isinstance(result[0], SearchResult)
assert result[0].context_data == "Test context"
assert result[0].context_text == "Test context"
assert result[0].llm_calls == 1


@pytest.mark.asyncio
async def test_reduce_response(global_search):
global_search.reduce_response = AsyncMock(
return_value=SearchResult(
response="Final answer",
context_data="Combined context",
context_text="Combined context",
completion_time=0.2,
llm_calls=1,
prompt_tokens=20,
)
)

map_responses = [
SearchResult(response=[{"answer": "Point 1", "score": 90}], context_data="", context_text="", completion_time=0, llm_calls=1, prompt_tokens=0),
SearchResult(response=[{"answer": "Point 2", "score": 80}], context_data="", context_text="", completion_time=0, llm_calls=1, prompt_tokens=0),
SearchResult(
response=[{"answer": "Point 1", "score": 90}],
context_data="",
context_text="",
completion_time=0,
llm_calls=1,
prompt_tokens=0,
),
SearchResult(
response=[{"answer": "Point 2", "score": 80}],
context_data="",
context_text="",
completion_time=0,
llm_calls=1,
prompt_tokens=0,
),
]
query = "Test query"
result = await global_search._reduce_response(map_responses, query)
result = await global_search.reduce_response(map_responses, query)

assert isinstance(result, SearchResult)
assert result.llm_calls == 1

@pytest.mark.asyncio
async def test_asearch(global_search):
query = "Test query"
result = await global_search.asearch(query)
assert isinstance(result, GlobalSearchResult)
assert result.llm_calls > 0
assert global_search.llm.call_count > 0 # Access the mock LLM through the GlobalSearch instance
assert result.response == "Final answer"
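As a quick interactive check tying the tests back to the feature (a sketch using the test doubles defined above; json_mode is assumed to be a keyword argument with the semantics shown in the diff):

    search = TestGlobalSearch(MockLLM(), MockContextBuilder(), json_mode=True)
    # The diff sets this key whenever json_mode is truthy...
    assert search.map_llm_params.get("response_format") == {"type": "json_object"}

    search = TestGlobalSearch(MockLLM(), MockContextBuilder(), json_mode=False)
    # ...and removes it when JSON mode is off.
    assert "response_format" not in search.map_llm_params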
