Move prompts #1404

Merged · 8 commits · Nov 14, 2024
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241114012244853718.json
@@ -0,0 +1,4 @@
{
    "type": "patch",
    "description": "Centralized prompts and export all for easier injection."
}
8 changes: 4 additions & 4 deletions docs/prompt_tuning/manual_prompt_tuning.md
@@ -8,7 +8,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

## Entity/Relationship Extraction

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/graph/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/entity_extraction.py)

### Tokens (values provided by extractor)

@@ -20,7 +20,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

## Summarize Entity/Relationship Descriptions

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/summarize/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/summarize_descriptions.py)

### Tokens (values provided by extractor)

@@ -29,7 +29,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain

## Claim Extraction

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/claims/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/claim_extraction.py)

### Tokens (values provided by extractor)

@@ -47,7 +47,7 @@ See the [configuration documentation](../config/overview.md) for details on how

## Generate Community Reports

-[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/community_reports/prompts.py)
+[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/community_report.py)

### Tokens (values provided by extractor)

6 changes: 3 additions & 3 deletions docs/query/global_search.md
@@ -56,11 +56,11 @@ Below are the key parameters of the [GlobalSearch class](https://github.com/micr

* `llm`: OpenAI model object to be used for response generation
* `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/community_context.py) object to be used for preparing context data from community reports
-* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/map_system_prompt.py)
-* `reduce_system_prompt`: prompt template used in the `reduce` stage, default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py)
+* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_map_system_prompt.py)
+* `reduce_system_prompt`: prompt template used in the `reduce` stage, default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_reduce_system_prompt.py)
* `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`)
* `allow_general_knowledge`: setting this to True will include additional instructions to the `reduce_system_prompt` to prompt the LLM to incorporate relevant real-world knowledge outside of the dataset. Note that this may increase hallucinations, but can be useful for certain scenarios. Default is False
-* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py)
+* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_knowledge_system_prompt.py)
* `max_data_tokens`: token budget for the context data
* `map_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `map` stage
* `reduce_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `reduce` stage
2 changes: 1 addition & 1 deletion docs/query/local_search.md
@@ -50,7 +50,7 @@ Below are the key parameters of the [LocalSearch class](https://github.com/micro

* `llm`: OpenAI model object to be used for response generation
* `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects
-* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/system_prompt.py)
+* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/local_search_system_prompt.py)
* `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`)
* `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call
* `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the search prompt
2 changes: 1 addition & 1 deletion docs/query/question_generation.md
@@ -13,7 +13,7 @@ Below are the key parameters of the [Question Generation class](https://github.c

* `llm`: OpenAI model object to be used for response generation
* `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects, using the same context builder class as in local search
-* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/question_gen/system_prompt.py)
+* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/question_gen_system_prompt.py)
* `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call
* `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the question generation prompt
* `callbacks`: optional callback functions, can be used to provide custom event handlers for LLM's completion streaming events
43 changes: 42 additions & 1 deletion graphrag/api/query.py
@@ -98,13 +98,24 @@ async def global_search(
        dynamic_community_selection=dynamic_community_selection,
    )
    _entities = read_indexer_entities(nodes, entities, community_level=community_level)
    map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
    reduce_prompt = _load_search_prompt(
        config.root_dir, config.global_search.reduce_prompt
    )
    knowledge_prompt = _load_search_prompt(
        config.root_dir, config.global_search.knowledge_prompt
    )

    search_engine = get_global_search_engine(
        config,
        reports=reports,
        entities=_entities,
        communities=_communities,
        response_type=response_type,
        dynamic_community_selection=dynamic_community_selection,
        map_system_prompt=map_prompt,
        reduce_system_prompt=reduce_prompt,
        general_knowledge_inclusion_prompt=knowledge_prompt,
    )
    result: SearchResult = await search_engine.asearch(query=query)
    response = result.response
@@ -156,13 +167,24 @@ async def global_search_streaming(
        dynamic_community_selection=dynamic_community_selection,
    )
    _entities = read_indexer_entities(nodes, entities, community_level=community_level)
    map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
    reduce_prompt = _load_search_prompt(
        config.root_dir, config.global_search.reduce_prompt
    )
    knowledge_prompt = _load_search_prompt(
        config.root_dir, config.global_search.knowledge_prompt
    )

    search_engine = get_global_search_engine(
        config,
        reports=reports,
        entities=_entities,
        communities=_communities,
        response_type=response_type,
        dynamic_community_selection=dynamic_community_selection,
        map_system_prompt=map_prompt,
        reduce_system_prompt=reduce_prompt,
        general_knowledge_inclusion_prompt=knowledge_prompt,
    )
    search_result = search_engine.astream_search(query=query)

@@ -238,6 +260,7 @@ async def local_search(

    _entities = read_indexer_entities(nodes, entities, community_level)
    _covariates = read_indexer_covariates(covariates) if covariates is not None else []
    prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

    search_engine = get_local_search_engine(
        config=config,
@@ -248,6 +271,7 @@
        covariates={"claims": _covariates},
        description_embedding_store=description_embedding_store,  # type: ignore
        response_type=response_type,
        system_prompt=prompt,
    )

    result: SearchResult = await search_engine.asearch(query=query)
@@ -312,6 +336,7 @@ async def local_search_streaming(

    _entities = read_indexer_entities(nodes, entities, community_level)
    _covariates = read_indexer_covariates(covariates) if covariates is not None else []
    prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

    search_engine = get_local_search_engine(
        config=config,
@@ -322,6 +347,7 @@
        covariates={"claims": _covariates},
        description_embedding_store=description_embedding_store,  # type: ignore
        response_type=response_type,
        system_prompt=prompt,
    )
    search_result = search_engine.astream_search(query=query)

@@ -401,14 +427,15 @@ async def drift_search(
    _entities = read_indexer_entities(nodes, entities, community_level)
    _reports = read_indexer_reports(community_reports, nodes, community_level)
    read_indexer_report_embeddings(_reports, full_content_embedding_store)

    prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
    search_engine = get_drift_search_engine(
        config=config,
        reports=_reports,
        text_units=read_indexer_text_units(text_units),
        entities=_entities,
        relationships=read_indexer_relationships(relationships),
        description_embedding_store=description_embedding_store,  # type: ignore
        local_system_prompt=prompt,
    )

    result: SearchResult = await search_engine.asearch(query=query)
@@ -551,3 +578,17 @@ def _reformat_context_data(context_data: dict) -> dict:
            continue
        final_format[key] = records
    return final_format


def _load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
    """Load the search prompt from disk if configured.

    If not, leave it empty - the search functions will load their defaults.
    """
    if prompt_config:
        prompt_file = Path(root_dir) / prompt_config
        if prompt_file.exists():
            return prompt_file.read_bytes().decode(encoding="utf-8")
    return None
63 changes: 33 additions & 30 deletions graphrag/cli/initialize.py
@@ -5,14 +5,24 @@

from pathlib import Path

from graphrag.index.graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
from graphrag.index.graph.extractors.community_reports.prompts import (
    COMMUNITY_REPORT_PROMPT,
)
from graphrag.index.graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT
from graphrag.index.graph.extractors.summarize.prompts import SUMMARIZE_PROMPT
from graphrag.index.init_content import INIT_DOTENV, INIT_YAML
from graphrag.logging import ReporterType, create_progress_reporter
from graphrag.prompts.index.claim_extraction import CLAIM_EXTRACTION_PROMPT
from graphrag.prompts.index.community_report import (
    COMMUNITY_REPORT_PROMPT,
)
from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
from graphrag.prompts.query.global_search_knowledge_system_prompt import (
    GENERAL_KNOWLEDGE_INSTRUCTION,
)
from graphrag.prompts.query.global_search_map_system_prompt import MAP_SYSTEM_PROMPT
from graphrag.prompts.query.global_search_reduce_system_prompt import (
    REDUCE_SYSTEM_PROMPT,
)
from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT
from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT


def initialize_project_at(path: Path) -> None:
@@ -40,28 +50,21 @@ def initialize_project_at(path: Path) -> None:
    if not prompts_dir.exists():
        prompts_dir.mkdir(parents=True, exist_ok=True)

    entity_extraction = prompts_dir / "entity_extraction.txt"
    if not entity_extraction.exists():
        with entity_extraction.open("wb") as file:
            file.write(
                GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
            )

    summarize_descriptions = prompts_dir / "summarize_descriptions.txt"
    if not summarize_descriptions.exists():
        with summarize_descriptions.open("wb") as file:
            file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict"))

    claim_extraction = prompts_dir / "claim_extraction.txt"
    if not claim_extraction.exists():
        with claim_extraction.open("wb") as file:
            file.write(
                CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict")
            )
    prompts = {
        "entity_extraction": GRAPH_EXTRACTION_PROMPT,
        "summarize_descriptions": SUMMARIZE_PROMPT,
        "claim_extraction": CLAIM_EXTRACTION_PROMPT,
        "community_report": COMMUNITY_REPORT_PROMPT,
        "drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
        "global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
        "global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
        "global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,
        "local_search_system_prompt": LOCAL_SEARCH_SYSTEM_PROMPT,
        "question_gen_system_prompt": QUESTION_SYSTEM_PROMPT,
    }

    community_report = prompts_dir / "community_report.txt"
    if not community_report.exists():
        with community_report.open("wb") as file:
            file.write(
                COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
            )
    for name, content in prompts.items():
        prompt_file = prompts_dir / f"{name}.txt"
        if not prompt_file.exists():
            with prompt_file.open("wb") as file:
                file.write(content.encode(encoding="utf-8", errors="strict"))
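The `initialize.py` change above replaces four copy-pasted write blocks with one table-driven loop over a `{name: content}` dict. The same idiom, reduced to a self-contained sketch (the function name `write_defaults` and the sample contents are placeholders for illustration):

```python
from pathlib import Path


def write_defaults(prompts_dir: Path, prompts: dict[str, str]) -> list[str]:
    """Write each default prompt to <name>.txt, skipping files that already exist.

    Returns the names that were actually written, so callers can report progress.
    """
    prompts_dir.mkdir(parents=True, exist_ok=True)
    written: list[str] = []
    for name, content in prompts.items():
        prompt_file = prompts_dir / f"{name}.txt"
        # Never clobber a user-edited prompt file.
        if not prompt_file.exists():
            prompt_file.write_bytes(content.encode("utf-8", errors="strict"))
            written.append(name)
    return written
```

The existence check makes the operation idempotent: rerunning initialization leaves any user-customized prompt files untouched, which is why adding six new query prompts here is safe for existing projects.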
2 changes: 2 additions & 0 deletions graphrag/config/__init__.py
@@ -51,6 +51,7 @@
    ClaimExtractionConfig,
    ClusterGraphConfig,
    CommunityReportsConfig,
    DRIFTSearchConfig,
    EmbedGraphConfig,
    EntityExtractionConfig,
    GlobalSearchConfig,
@@ -85,6 +86,7 @@
    "ClusterGraphConfigInput",
    "CommunityReportsConfig",
    "CommunityReportsConfigInput",
    "DRIFTSearchConfig",
    "EmbedGraphConfig",
    "EmbedGraphConfigInput",
    "EntityExtractionConfig",