diff --git a/.semversioner/next-release/patch-20241114012244853718.json b/.semversioner/next-release/patch-20241114012244853718.json new file mode 100644 index 0000000000..9b16cc63c8 --- /dev/null +++ b/.semversioner/next-release/patch-20241114012244853718.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Centralized prompts and export all for easier injection." +} diff --git a/docs/prompt_tuning/manual_prompt_tuning.md b/docs/prompt_tuning/manual_prompt_tuning.md index 4618c28745..7f10fc8e79 100644 --- a/docs/prompt_tuning/manual_prompt_tuning.md +++ b/docs/prompt_tuning/manual_prompt_tuning.md @@ -8,7 +8,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain ## Entity/Relationship Extraction -[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/graph/prompts.py) +[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/entity_extraction.py) ### Tokens (values provided by extractor) @@ -20,7 +20,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain ## Summarize Entity/Relationship Descriptions -[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/summarize/prompts.py) +[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/summarize_descriptions.py) ### Tokens (values provided by extractor) @@ -29,7 +29,7 @@ Each of these prompts may be overridden by writing a custom prompt file in plain ## Claim Extraction -[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/claims/prompts.py) +[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/claim_extraction.py) ### Tokens (values provided by extractor) @@ -47,7 +47,7 @@ See the [configuration documentation](../config/overview.md) for details on how ## Generate Community Reports -[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/community_reports/prompts.py) +[Prompt Source](http://github.com/microsoft/graphrag/blob/main/graphrag/prompts/community_report.py) ### Tokens (values provided by extractor) diff --git a/docs/query/global_search.md b/docs/query/global_search.md index b7b6bfc8d2..a9685d70be 100644 --- a/docs/query/global_search.md +++ b/docs/query/global_search.md @@ -56,11 +56,11 @@ Below are the key parameters of the [GlobalSearch class](https://github.com/micr * `llm`: OpenAI model object to be used for response generation * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/community_context.py) object to be used for preparing context data from community reports -* `map_system_prompt`: prompt template used in the `map` stage. Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/map_system_prompt.py) -* `reduce_system_prompt`: prompt template used in the `reduce` stage, default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py) +* `map_system_prompt`: prompt template used in the `map` stage. 
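The defaults referenced in these bullets are plain module-level strings in the new `graphrag.prompts.query` package, so an override reduces to read-from-disk-or-fall-back. A minimal sketch, assuming only the relocated module path introduced in this diff (the custom file name is hypothetical):

```python
from pathlib import Path

from graphrag.prompts.query.global_search_map_system_prompt import MAP_SYSTEM_PROMPT

# Prefer a custom map-stage prompt when one exists on disk; otherwise fall
# back to the packaged default, as described above. The path is hypothetical.
custom = Path("prompts/global_search_map_system_prompt.txt")
map_system_prompt = (
    custom.read_text(encoding="utf-8") if custom.exists() else MAP_SYSTEM_PROMPT
)
```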
Default template can be found at [map_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_map_system_prompt.py) +* `reduce_system_prompt`: prompt template used in the `reduce` stage, default template can be found at [reduce_system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_reduce_system_prompt.py) * `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`) * `allow_general_knowledge`: setting this to True will include additional instructions to the `reduce_system_prompt` to prompt the LLM to incorporate relevant real-world knowledge outside of the dataset. Note that this may increase hallucinations, but can be useful for certain scenarios. Default is False -*`general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/global_search/reduce_system_prompt.py) +* `general_knowledge_inclusion_prompt`: instruction to add to the `reduce_system_prompt` if `allow_general_knowledge` is enabled. Default instruction can be found at [general_knowledge_instruction](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/global_search_knowledge_system_prompt.py) * `max_data_tokens`: token budget for the context data * `map_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `map` stage * `reduce_llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call at the `reduce` stage diff --git a/docs/query/local_search.md b/docs/query/local_search.md index 7169392cc2..bf0f43e3ce 100644 --- a/docs/query/local_search.md +++ b/docs/query/local_search.md @@ -50,7 +50,7 @@ Below are the key parameters of the [LocalSearch class](https://github.com/micro * `llm`: OpenAI model object to be used for response generation * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects -* `system_prompt`: prompt template used to generate the search response. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/system_prompt.py) +* `system_prompt`: prompt template used to generate the search response. 
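`LocalSearch` now follows the same `None`-means-default convention (`system_prompt: str | None = None`, with the packaged template substituted in `search.py` further down this diff). A hedged sketch of that fallback, with a hypothetical override path:

```python
from pathlib import Path

from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT


def resolve_system_prompt(override: Path | None) -> str:
    """Return the custom prompt text when configured, else the packaged default."""
    if override is not None and override.exists():
        return override.read_text(encoding="utf-8")
    return LOCAL_SEARCH_SYSTEM_PROMPT
```

Passing the result as `LocalSearch(..., system_prompt=resolve_system_prompt(Path("prompts/local_search_system_prompt.txt")))` exercises the same code path the API layer now uses.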
Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/local_search_system_prompt.py) * `response_type`: free-form text describing the desired response type and format (e.g., `Multiple Paragraphs`, `Multi-Page Report`) * `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call * `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the search prompt diff --git a/docs/query/question_generation.md b/docs/query/question_generation.md index 43643e3727..525a465499 100644 --- a/docs/query/question_generation.md +++ b/docs/query/question_generation.md @@ -13,7 +13,7 @@ Below are the key parameters of the [Question Generation class](https://github.c * `llm`: OpenAI model object to be used for response generation * `context_builder`: [context builder](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object to be used for preparing context data from collections of knowledge model objects, using the same context builder class as in local search -* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/query/question_gen/system_prompt.py) +* `system_prompt`: prompt template used to generate candidate questions. Default template can be found at [system_prompt](https://github.com/microsoft/graphrag/blob/main//graphrag/prompts/query/question_gen_system_prompt.py) * `llm_params`: a dictionary of additional parameters (e.g., temperature, max_tokens) to be passed to the LLM call * `context_builder_params`: a dictionary of additional parameters to be passed to the [`context_builder`](https://github.com/microsoft/graphrag/blob/main//graphrag/query/structured_search/local_search/mixed_context.py) object when building context for the question generation prompt * `callbacks`: optional callback functions, can be used to provide custom event handlers for LLM's completion streaming events diff --git a/graphrag/api/query.py b/graphrag/api/query.py index 4c7941b5f1..21648a12a8 100644 --- a/graphrag/api/query.py +++ b/graphrag/api/query.py @@ -98,6 +98,14 @@ async def global_search( dynamic_community_selection=dynamic_community_selection, ) _entities = read_indexer_entities(nodes, entities, community_level=community_level) + map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt) + reduce_prompt = _load_search_prompt( + config.root_dir, config.global_search.reduce_prompt + ) + knowledge_prompt = _load_search_prompt( + config.root_dir, config.global_search.knowledge_prompt + ) + search_engine = get_global_search_engine( config, reports=reports, @@ -105,6 +113,9 @@ async def global_search( communities=_communities, response_type=response_type, dynamic_community_selection=dynamic_community_selection, + map_system_prompt=map_prompt, + reduce_system_prompt=reduce_prompt, + general_knowledge_inclusion_prompt=knowledge_prompt, ) result: SearchResult = await search_engine.asearch(query=query) response = result.response @@ -156,6 +167,14 @@ async def global_search_streaming( dynamic_community_selection=dynamic_community_selection, ) _entities = read_indexer_entities(nodes, entities, 
community_level=community_level) + map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt) + reduce_prompt = _load_search_prompt( + config.root_dir, config.global_search.reduce_prompt + ) + knowledge_prompt = _load_search_prompt( + config.root_dir, config.global_search.knowledge_prompt + ) + search_engine = get_global_search_engine( config, reports=reports, @@ -163,6 +182,9 @@ async def global_search_streaming( communities=_communities, response_type=response_type, dynamic_community_selection=dynamic_community_selection, + map_system_prompt=map_prompt, + reduce_system_prompt=reduce_prompt, + general_knowledge_inclusion_prompt=knowledge_prompt, ) search_result = search_engine.astream_search(query=query) @@ -238,6 +260,7 @@ async def local_search( _entities = read_indexer_entities(nodes, entities, community_level) _covariates = read_indexer_covariates(covariates) if covariates is not None else [] + prompt = _load_search_prompt(config.root_dir, config.local_search.prompt) search_engine = get_local_search_engine( config=config, @@ -248,6 +271,7 @@ async def local_search( covariates={"claims": _covariates}, description_embedding_store=description_embedding_store, # type: ignore response_type=response_type, + system_prompt=prompt, ) result: SearchResult = await search_engine.asearch(query=query) @@ -312,6 +336,7 @@ async def local_search_streaming( _entities = read_indexer_entities(nodes, entities, community_level) _covariates = read_indexer_covariates(covariates) if covariates is not None else [] + prompt = _load_search_prompt(config.root_dir, config.local_search.prompt) search_engine = get_local_search_engine( config=config, @@ -322,6 +347,7 @@ async def local_search_streaming( covariates={"claims": _covariates}, description_embedding_store=description_embedding_store, # type: ignore response_type=response_type, + system_prompt=prompt, ) search_result = search_engine.astream_search(query=query) @@ -401,7 +427,7 @@ async def drift_search( _entities = read_indexer_entities(nodes, entities, community_level) _reports = read_indexer_reports(community_reports, nodes, community_level) read_indexer_report_embeddings(_reports, full_content_embedding_store) - + prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt) search_engine = get_drift_search_engine( config=config, reports=_reports, @@ -409,6 +435,7 @@ async def drift_search( entities=_entities, relationships=read_indexer_relationships(relationships), description_embedding_store=description_embedding_store, # type: ignore + local_system_prompt=prompt, ) result: SearchResult = await search_engine.asearch(query=query) @@ -551,3 +578,17 @@ def _reformat_context_data(context_data: dict) -> dict: continue final_format[key] = records return final_format + + +def _load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None: + """ + Load the search prompt from disk if configured. + + If not, leave it empty - the search functions will load their defaults. 
+ + """ + if prompt_config: + prompt_file = Path(root_dir) / prompt_config + if prompt_file.exists(): + return prompt_file.read_bytes().decode(encoding="utf-8") + return None diff --git a/graphrag/cli/initialize.py b/graphrag/cli/initialize.py index b861132597..c807397799 100644 --- a/graphrag/cli/initialize.py +++ b/graphrag/cli/initialize.py @@ -5,14 +5,24 @@ from pathlib import Path -from graphrag.index.graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT -from graphrag.index.graph.extractors.community_reports.prompts import ( - COMMUNITY_REPORT_PROMPT, -) -from graphrag.index.graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT -from graphrag.index.graph.extractors.summarize.prompts import SUMMARIZE_PROMPT from graphrag.index.init_content import INIT_DOTENV, INIT_YAML from graphrag.logging import ReporterType, create_progress_reporter +from graphrag.prompts.index.claim_extraction import CLAIM_EXTRACTION_PROMPT +from graphrag.prompts.index.community_report import ( + COMMUNITY_REPORT_PROMPT, +) +from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT +from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT +from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT +from graphrag.prompts.query.global_search_knowledge_system_prompt import ( + GENERAL_KNOWLEDGE_INSTRUCTION, +) +from graphrag.prompts.query.global_search_map_system_prompt import MAP_SYSTEM_PROMPT +from graphrag.prompts.query.global_search_reduce_system_prompt import ( + REDUCE_SYSTEM_PROMPT, +) +from graphrag.prompts.query.local_search_system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT +from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT def initialize_project_at(path: Path) -> None: @@ -40,28 +50,21 @@ def initialize_project_at(path: Path) -> None: if not prompts_dir.exists(): prompts_dir.mkdir(parents=True, exist_ok=True) - entity_extraction = prompts_dir / "entity_extraction.txt" - if not entity_extraction.exists(): - with entity_extraction.open("wb") as file: - file.write( - GRAPH_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict") - ) - - summarize_descriptions = prompts_dir / "summarize_descriptions.txt" - if not summarize_descriptions.exists(): - with summarize_descriptions.open("wb") as file: - file.write(SUMMARIZE_PROMPT.encode(encoding="utf-8", errors="strict")) - - claim_extraction = prompts_dir / "claim_extraction.txt" - if not claim_extraction.exists(): - with claim_extraction.open("wb") as file: - file.write( - CLAIM_EXTRACTION_PROMPT.encode(encoding="utf-8", errors="strict") - ) + prompts = { + "entity_extraction": GRAPH_EXTRACTION_PROMPT, + "summarize_descriptions": SUMMARIZE_PROMPT, + "claim_extraction": CLAIM_EXTRACTION_PROMPT, + "community_report": COMMUNITY_REPORT_PROMPT, + "drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT, + "global_search_map_system_prompt": MAP_SYSTEM_PROMPT, + "global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT, + "global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION, + "local_search_system_prompt": LOCAL_SEARCH_SYSTEM_PROMPT, + "question_gen_system_prompt": QUESTION_SYSTEM_PROMPT, + } - community_report = prompts_dir / "community_report.txt" - if not community_report.exists(): - with community_report.open("wb") as file: - file.write( - COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict") - ) + for name, content in prompts.items(): + prompt_file = prompts_dir / f"{name}.txt" + if not prompt_file.exists(): + with 
prompt_file.open("wb") as file: + file.write(content.encode(encoding="utf-8", errors="strict")) diff --git a/graphrag/config/__init__.py b/graphrag/config/__init__.py index 62708f5d9e..3354f2ccd8 100644 --- a/graphrag/config/__init__.py +++ b/graphrag/config/__init__.py @@ -51,6 +51,7 @@ ClaimExtractionConfig, ClusterGraphConfig, CommunityReportsConfig, + DRIFTSearchConfig, EmbedGraphConfig, EntityExtractionConfig, GlobalSearchConfig, @@ -85,6 +86,7 @@ "ClusterGraphConfigInput", "CommunityReportsConfig", "CommunityReportsConfigInput", + "DRIFTSearchConfig", "EmbedGraphConfig", "EmbedGraphConfigInput", "EntityExtractionConfig", diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index a750ae840e..42b250088e 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -39,6 +39,7 @@ ClaimExtractionConfig, ClusterGraphConfig, CommunityReportsConfig, + DRIFTSearchConfig, EmbedGraphConfig, EntityExtractionConfig, GlobalSearchConfig, @@ -514,6 +515,7 @@ def hydrate_parallelization_params( reader.envvar_prefix(Section.local_search), ): local_search_model = LocalSearchConfig( + prompt=reader.str("prompt") or None, text_unit_prop=reader.float("text_unit_prop") or defs.LOCAL_SEARCH_TEXT_UNIT_PROP, community_prop=reader.float("community_prop") @@ -541,6 +543,9 @@ def hydrate_parallelization_params( reader.envvar_prefix(Section.global_search), ): global_search_model = GlobalSearchConfig( + map_prompt=reader.str("map_prompt") or None, + reduce_prompt=reader.str("reduce_prompt") or None, + knowledge_prompt=reader.str("knowledge_prompt") or None, temperature=reader.float("llm_temperature") or defs.GLOBAL_SEARCH_LLM_TEMPERATURE, top_p=reader.float("llm_top_p") or defs.GLOBAL_SEARCH_LLM_TOP_P, @@ -556,6 +561,54 @@ def hydrate_parallelization_params( concurrency=reader.int("concurrency") or defs.GLOBAL_SEARCH_CONCURRENCY, ) + with ( + reader.use(values.get("drift_search")), + reader.envvar_prefix(Section.drift_search), + ): + drift_search_model = DRIFTSearchConfig( + prompt=reader.str("prompt") or None, + temperature=reader.float("llm_temperature") + or defs.DRIFT_SEARCH_LLM_TEMPERATURE, + top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P, + n=reader.int("llm_n") or defs.DRIFT_SEARCH_LLM_N, + max_tokens=reader.int(Fragment.max_tokens) + or defs.DRIFT_SEARCH_MAX_TOKENS, + data_max_tokens=reader.int("data_max_tokens") + or defs.DRIFT_SEARCH_DATA_MAX_TOKENS, + concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY, + drift_k_followups=reader.int("drift_k_followups") + or defs.DRIFT_SEARCH_K_FOLLOW_UPS, + primer_folds=reader.int("primer_folds") + or defs.DRIFT_SEARCH_PRIMER_FOLDS, + primer_llm_max_tokens=reader.int("primer_llm_max_tokens") + or defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS, + n_depth=reader.int("n_depth") or defs.DRIFT_N_DEPTH, + local_search_text_unit_prop=reader.float("local_search_text_unit_prop") + or defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP, + local_search_community_prop=reader.float("local_search_community_prop") + or defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP, + local_search_top_k_mapped_entities=reader.int( + "local_search_top_k_mapped_entities" + ) + or defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES, + local_search_top_k_relationships=reader.int( + "local_search_top_k_relationships" + ) + or defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS, + local_search_max_data_tokens=reader.int("local_search_max_data_tokens") + or defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS, + 
local_search_temperature=reader.float("local_search_temperature") + or defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE, + local_search_top_p=reader.float("local_search_top_p") + or defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P, + local_search_n=reader.int("local_search_n") + or defs.DRIFT_LOCAL_SEARCH_LLM_N, + local_search_llm_max_gen_tokens=reader.int( + "local_search_llm_max_gen_tokens" + ) + or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS, + ) + encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL skip_workflows = reader.list("skip_workflows") or [] @@ -583,6 +636,7 @@ def hydrate_parallelization_params( skip_workflows=skip_workflows, local_search=local_search_model, global_search=global_search_model, + drift_search=drift_search_model, ) @@ -649,6 +703,7 @@ class Section(str, Enum): update_index_storage = "UPDATE_INDEX_STORAGE" local_search = "LOCAL_SEARCH" global_search = "GLOBAL_SEARCH" + drift_search = "DRIFT_SEARCH" def _is_azure(llm_type: LLMType | None) -> bool: diff --git a/graphrag/config/models/drift_config.py b/graphrag/config/models/drift_config.py index 03b80150a2..7860f60ae6 100644 --- a/graphrag/config/models/drift_config.py +++ b/graphrag/config/models/drift_config.py @@ -11,6 +11,9 @@ class DRIFTSearchConfig(BaseModel): """The default configuration section for Cache.""" + prompt: str | None = Field( + description="The drift search prompt to use.", default=None + ) temperature: float = Field( description="The temperature to use for token generation.", default=defs.DRIFT_SEARCH_LLM_TEMPERATURE, diff --git a/graphrag/config/models/global_search_config.py b/graphrag/config/models/global_search_config.py index b528d9a0e0..7f2573d402 100644 --- a/graphrag/config/models/global_search_config.py +++ b/graphrag/config/models/global_search_config.py @@ -11,6 +11,15 @@ class GlobalSearchConfig(BaseModel): """The default configuration section for Cache.""" + map_prompt: str | None = Field( + description="The global search mapper prompt to use.", default=None + ) + reduce_prompt: str | None = Field( + description="The global search reducer prompt to use.", default=None + ) + knowledge_prompt: str | None = Field( + description="The global search general knowledge prompt to use.", default=None + ) temperature: float | None = Field( description="The temperature to use for token generation.", default=defs.GLOBAL_SEARCH_LLM_TEMPERATURE, diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index 4eaf8c284a..35912e12f5 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -13,6 +13,7 @@ from .claim_extraction_config import ClaimExtractionConfig from .cluster_graph_config import ClusterGraphConfig from .community_reports_config import CommunityReportsConfig +from .drift_config import DRIFTSearchConfig from .embed_graph_config import EmbedGraphConfig from .entity_extraction_config import EntityExtractionConfig from .global_search_config import GlobalSearchConfig @@ -141,6 +142,11 @@ def __str__(self): ) """The global search configuration.""" + drift_search: DRIFTSearchConfig = Field( + description="The drift search configuration.", default=DRIFTSearchConfig() + ) + """The drift search configuration.""" + encoding_model: str = Field( description="The encoding model to use.", default=defs.ENCODING_MODEL ) diff --git a/graphrag/config/models/local_search_config.py b/graphrag/config/models/local_search_config.py index c41344daef..c3d8e822e4 100644 --- a/graphrag/config/models/local_search_config.py +++ 
b/graphrag/config/models/local_search_config.py @@ -11,6 +11,9 @@ class LocalSearchConfig(BaseModel): """The default configuration section for Cache.""" + prompt: str | None = Field( + description="The local search prompt to use.", default=None + ) text_unit_prop: float = Field( description="The text unit proportion.", default=defs.LOCAL_SEARCH_TEXT_UNIT_PROP, diff --git a/graphrag/index/graph/extractors/__init__.py b/graphrag/index/graph/extractors/__init__.py index 9168d5e207..511695aea6 100644 --- a/graphrag/index/graph/extractors/__init__.py +++ b/graphrag/index/graph/extractors/__init__.py @@ -3,16 +3,13 @@ """The Indexing Engine graph extractors package root.""" -from .claims import CLAIM_EXTRACTION_PROMPT, ClaimExtractor +from .claims import ClaimExtractor from .community_reports import ( - COMMUNITY_REPORT_PROMPT, CommunityReportsExtractor, ) from .graph import GraphExtractionResult, GraphExtractor __all__ = [ - "CLAIM_EXTRACTION_PROMPT", - "COMMUNITY_REPORT_PROMPT", "ClaimExtractor", "CommunityReportsExtractor", "GraphExtractionResult", diff --git a/graphrag/index/graph/extractors/claims/__init__.py b/graphrag/index/graph/extractors/claims/__init__.py index 3977c8ff83..3a5a22fdb1 100644 --- a/graphrag/index/graph/extractors/claims/__init__.py +++ b/graphrag/index/graph/extractors/claims/__init__.py @@ -4,6 +4,5 @@ """The Indexing Engine graph extractors claims package root.""" from .claim_extractor import ClaimExtractor -from .prompts import CLAIM_EXTRACTION_PROMPT -__all__ = ["CLAIM_EXTRACTION_PROMPT", "ClaimExtractor"] +__all__ = ["ClaimExtractor"] diff --git a/graphrag/index/graph/extractors/claims/claim_extractor.py b/graphrag/index/graph/extractors/claims/claim_extractor.py index a1881abc86..2842ad7e1a 100644 --- a/graphrag/index/graph/extractors/claims/claim_extractor.py +++ b/graphrag/index/graph/extractors/claims/claim_extractor.py @@ -13,8 +13,7 @@ import graphrag.config.defaults as defs from graphrag.index.typing import ErrorHandlerFn from graphrag.llm import CompletionLLM - -from .prompts import ( +from graphrag.prompts.index.claim_extraction import ( CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT, diff --git a/graphrag/index/graph/extractors/community_reports/__init__.py b/graphrag/index/graph/extractors/community_reports/__init__.py index 599f56d60f..da3bf8396a 100644 --- a/graphrag/index/graph/extractors/community_reports/__init__.py +++ b/graphrag/index/graph/extractors/community_reports/__init__.py @@ -8,7 +8,6 @@ from .build_mixed_context import build_mixed_context from .community_reports_extractor import CommunityReportsExtractor from .prep_community_report_context import prep_community_report_context -from .prompts import COMMUNITY_REPORT_PROMPT from .sort_context import sort_context from .utils import ( filter_claims_to_nodes, @@ -20,7 +19,6 @@ ) __all__ = [ - "COMMUNITY_REPORT_PROMPT", "CommunityReportsExtractor", "build_mixed_context", "filter_claims_to_nodes", diff --git a/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py b/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py index b54f613144..291e61af69 100644 --- a/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py +++ b/graphrag/index/graph/extractors/community_reports/community_reports_extractor.py @@ -11,8 +11,7 @@ from graphrag.index.typing import ErrorHandlerFn from graphrag.index.utils import dict_has_keys_with_types from graphrag.llm import CompletionLLM - -from .prompts import COMMUNITY_REPORT_PROMPT 
+from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT log = logging.getLogger(__name__) diff --git a/graphrag/index/graph/extractors/graph/__init__.py b/graphrag/index/graph/extractors/graph/__init__.py index 94e03ab9f7..7f8d19c9ca 100644 --- a/graphrag/index/graph/extractors/graph/__init__.py +++ b/graphrag/index/graph/extractors/graph/__init__.py @@ -8,11 +8,9 @@ GraphExtractionResult, GraphExtractor, ) -from .prompts import GRAPH_EXTRACTION_PROMPT __all__ = [ "DEFAULT_ENTITY_TYPES", - "GRAPH_EXTRACTION_PROMPT", "GraphExtractionResult", "GraphExtractor", ] diff --git a/graphrag/index/graph/extractors/graph/graph_extractor.py b/graphrag/index/graph/extractors/graph/graph_extractor.py index 87bd8a4aa6..b669cfa004 100644 --- a/graphrag/index/graph/extractors/graph/graph_extractor.py +++ b/graphrag/index/graph/extractors/graph/graph_extractor.py @@ -17,8 +17,11 @@ from graphrag.index.typing import ErrorHandlerFn from graphrag.index.utils import clean_str from graphrag.llm import CompletionLLM - -from .prompts import CONTINUE_PROMPT, GRAPH_EXTRACTION_PROMPT, LOOP_PROMPT +from graphrag.prompts.index.entity_extraction import ( + CONTINUE_PROMPT, + GRAPH_EXTRACTION_PROMPT, + LOOP_PROMPT, +) DEFAULT_TUPLE_DELIMITER = "<|>" DEFAULT_RECORD_DELIMITER = "##" diff --git a/graphrag/index/graph/extractors/summarize/__init__.py b/graphrag/index/graph/extractors/summarize/__init__.py index b4bfe5be87..17fe5095aa 100644 --- a/graphrag/index/graph/extractors/summarize/__init__.py +++ b/graphrag/index/graph/extractors/summarize/__init__.py @@ -7,6 +7,5 @@ SummarizationResult, SummarizeExtractor, ) -from .prompts import SUMMARIZE_PROMPT -__all__ = ["SUMMARIZE_PROMPT", "SummarizationResult", "SummarizeExtractor"] +__all__ = ["SummarizationResult", "SummarizeExtractor"] diff --git a/graphrag/index/graph/extractors/summarize/description_summary_extractor.py b/graphrag/index/graph/extractors/summarize/description_summary_extractor.py index 9319dba401..8d845d658e 100644 --- a/graphrag/index/graph/extractors/summarize/description_summary_extractor.py +++ b/graphrag/index/graph/extractors/summarize/description_summary_extractor.py @@ -9,8 +9,7 @@ from graphrag.index.typing import ErrorHandlerFn from graphrag.index.utils.tokens import num_tokens_from_string from graphrag.llm import CompletionLLM - -from .prompts import SUMMARIZE_PROMPT +from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT # Max token size for input prompts DEFAULT_MAX_INPUT_TOKENS = 4_000 diff --git a/graphrag/index/init_content.py b/graphrag/index/init_content.py index 6db1843dc4..93807672f2 100644 --- a/graphrag/index/init_content.py +++ b/graphrag/index/init_content.py @@ -158,6 +158,7 @@ transient: false local_search: + prompt: "prompts/local_search_system_prompt.txt" # text_unit_prop: {defs.LOCAL_SEARCH_TEXT_UNIT_PROP} # community_prop: {defs.LOCAL_SEARCH_COMMUNITY_PROP} # conversation_history_max_turns: {defs.LOCAL_SEARCH_CONVERSATION_HISTORY_MAX_TURNS} @@ -169,6 +170,9 @@ # max_tokens: {defs.LOCAL_SEARCH_MAX_TOKENS} global_search: + map_prompt: "prompts/global_search_map_system_prompt.txt" + reduce_prompt: "prompts/global_search_reduce_system_prompt.txt" + knowledge_prompt: "prompts/global_search_knowledge_system_prompt.txt" # llm_temperature: {defs.GLOBAL_SEARCH_LLM_TEMPERATURE} # temperature for sampling # llm_top_p: {defs.GLOBAL_SEARCH_LLM_TOP_P} # top-p sampling # llm_n: {defs.GLOBAL_SEARCH_LLM_N} # Number of completions to generate @@ -177,6 +181,28 @@ # map_max_tokens: 
{defs.GLOBAL_SEARCH_MAP_MAX_TOKENS} # reduce_max_tokens: {defs.GLOBAL_SEARCH_REDUCE_MAX_TOKENS} # concurrency: {defs.GLOBAL_SEARCH_CONCURRENCY} + +drift_search: + prompt: "prompts/drift_search_system_prompt.txt" + # temperature: {defs.DRIFT_SEARCH_LLM_TEMPERATURE} + # top_p: {defs.DRIFT_SEARCH_LLM_TOP_P} + # n: {defs.DRIFT_SEARCH_LLM_N} + # max_tokens: {defs.DRIFT_SEARCH_MAX_TOKENS} + # data_max_tokens: {defs.DRIFT_SEARCH_DATA_MAX_TOKENS} + # concurrency: {defs.DRIFT_SEARCH_CONCURRENCY} + # drift_k_followups: {defs.DRIFT_SEARCH_K_FOLLOW_UPS} + # primer_folds: {defs.DRIFT_SEARCH_PRIMER_FOLDS} + # primer_llm_max_tokens: {defs.DRIFT_SEARCH_PRIMER_MAX_TOKENS} + # n_depth: {defs.DRIFT_N_DEPTH} + # local_search_text_unit_prop: {defs.DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP} + # local_search_community_prop: {defs.DRIFT_LOCAL_SEARCH_COMMUNITY_PROP} + # local_search_top_k_mapped_entities: {defs.DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES} + # local_search_top_k_relationships: {defs.DRIFT_LOCAL_SEARCH_TOP_K_RELATIONSHIPS} + # local_search_max_data_tokens: {defs.DRIFT_LOCAL_SEARCH_MAX_TOKENS} + # local_search_temperature: {defs.DRIFT_LOCAL_SEARCH_LLM_TEMPERATURE} + # local_search_top_p: {defs.DRIFT_LOCAL_SEARCH_LLM_TOP_P} + # local_search_n: {defs.DRIFT_LOCAL_SEARCH_LLM_N} + # local_search_llm_max_gen_tokens: {defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS} """ INIT_DOTENV = """\ diff --git a/graphrag/prompts/__init__.py b/graphrag/prompts/__init__.py new file mode 100644 index 0000000000..6a6123b805 --- /dev/null +++ b/graphrag/prompts/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""All prompts for the system.""" diff --git a/graphrag/prompts/index/__init__.py b/graphrag/prompts/index/__init__.py new file mode 100644 index 0000000000..6377063bb6 --- /dev/null +++ b/graphrag/prompts/index/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""All prompts for indexing.""" diff --git a/graphrag/index/graph/extractors/claims/prompts.py b/graphrag/prompts/index/claim_extraction.py similarity index 100% rename from graphrag/index/graph/extractors/claims/prompts.py rename to graphrag/prompts/index/claim_extraction.py diff --git a/graphrag/index/graph/extractors/community_reports/prompts.py b/graphrag/prompts/index/community_report.py similarity index 100% rename from graphrag/index/graph/extractors/community_reports/prompts.py rename to graphrag/prompts/index/community_report.py diff --git a/graphrag/index/graph/extractors/graph/prompts.py b/graphrag/prompts/index/entity_extraction.py similarity index 100% rename from graphrag/index/graph/extractors/graph/prompts.py rename to graphrag/prompts/index/entity_extraction.py diff --git a/graphrag/index/graph/extractors/summarize/prompts.py b/graphrag/prompts/index/summarize_descriptions.py similarity index 100% rename from graphrag/index/graph/extractors/summarize/prompts.py rename to graphrag/prompts/index/summarize_descriptions.py diff --git a/graphrag/prompts/query/__init__.py b/graphrag/prompts/query/__init__.py new file mode 100644 index 0000000000..79ebce00d8 --- /dev/null +++ b/graphrag/prompts/query/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""All prompts for query.""" diff --git a/graphrag/query/structured_search/drift_search/system_prompt.py b/graphrag/prompts/query/drift_search_system_prompt.py similarity index 100% rename from graphrag/query/structured_search/drift_search/system_prompt.py rename to graphrag/prompts/query/drift_search_system_prompt.py diff --git a/graphrag/prompts/query/global_search_knowledge_system_prompt.py b/graphrag/prompts/query/global_search_knowledge_system_prompt.py new file mode 100644 index 0000000000..9125ef31f2 --- /dev/null +++ b/graphrag/prompts/query/global_search_knowledge_system_prompt.py @@ -0,0 +1,9 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Global Search system prompts.""" + +GENERAL_KNOWLEDGE_INSTRUCTION = """ +The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example: +"This is an example sentence supported by real-world knowledge [LLM: verify]." +""" diff --git a/graphrag/query/structured_search/global_search/map_system_prompt.py b/graphrag/prompts/query/global_search_map_system_prompt.py similarity index 100% rename from graphrag/query/structured_search/global_search/map_system_prompt.py rename to graphrag/prompts/query/global_search_map_system_prompt.py diff --git a/graphrag/query/structured_search/global_search/reduce_system_prompt.py b/graphrag/prompts/query/global_search_reduce_system_prompt.py similarity index 93% rename from graphrag/query/structured_search/global_search/reduce_system_prompt.py rename to graphrag/prompts/query/global_search_reduce_system_prompt.py index 701717817c..c9dbb9188d 100644 --- a/graphrag/query/structured_search/global_search/reduce_system_prompt.py +++ b/graphrag/prompts/query/global_search_reduce_system_prompt.py @@ -81,8 +81,3 @@ NO_DATA_ANSWER = ( "I am sorry but I am unable to answer this question given the provided data." ) - -GENERAL_KNOWLEDGE_INSTRUCTION = """ -The response may also include relevant real-world knowledge outside the dataset, but it must be explicitly annotated with a verification tag [LLM: verify]. For example: -"This is an example sentence supported by real-world knowledge [LLM: verify]." 
-""" diff --git a/graphrag/query/structured_search/local_search/system_prompt.py b/graphrag/prompts/query/local_search_system_prompt.py similarity index 100% rename from graphrag/query/structured_search/local_search/system_prompt.py rename to graphrag/prompts/query/local_search_system_prompt.py diff --git a/graphrag/query/question_gen/system_prompt.py b/graphrag/prompts/query/question_gen_system_prompt.py similarity index 100% rename from graphrag/query/question_gen/system_prompt.py rename to graphrag/prompts/query/question_gen_system_prompt.py diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index 3ba411f1ee..2e8a3afd4d 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -42,6 +42,7 @@ def get_local_search_engine( covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, + system_prompt: str | None = None, ) -> LocalSearch: """Create a local search engine based on data + configuration.""" llm = get_llm(config) @@ -52,6 +53,7 @@ def get_local_search_engine( return LocalSearch( llm=llm, + system_prompt=system_prompt, context_builder=LocalSearchMixedContext( community_reports=reports, text_units=text_units, @@ -95,6 +97,9 @@ def get_global_search_engine( communities: list[Community], response_type: str, dynamic_community_selection: bool = False, + map_system_prompt: str | None = None, + reduce_system_prompt: str | None = None, + general_knowledge_inclusion_prompt: str | None = None, ) -> GlobalSearch: """Create a global search engine based on data + configuration.""" token_encoder = tiktoken.get_encoding(config.encoding_model) @@ -118,6 +123,9 @@ def get_global_search_engine( return GlobalSearch( llm=get_llm(config), + map_system_prompt=map_system_prompt, + reduce_system_prompt=reduce_system_prompt, + general_knowledge_inclusion_prompt=general_knowledge_inclusion_prompt, context_builder=GlobalCommunityContext( community_reports=reports, communities=communities, @@ -166,6 +174,7 @@ def get_drift_search_engine( entities: list[Entity], relationships: list[Relationship], description_embedding_store: BaseVectorStore, + local_system_prompt: str | None = None, ) -> DRIFTSearch: """Create a local search engine based on data + configuration.""" llm = get_llm(config) @@ -182,6 +191,8 @@ def get_drift_search_engine( reports=reports, entity_text_embeddings=description_embedding_store, text_units=text_units, + local_system_prompt=local_system_prompt, + config=config.drift_search, ), token_encoder=token_encoder, ) diff --git a/graphrag/query/question_gen/local_gen.py b/graphrag/query/question_gen/local_gen.py index ca703a66e3..5372d4cf43 100644 --- a/graphrag/query/question_gen/local_gen.py +++ b/graphrag/query/question_gen/local_gen.py @@ -9,6 +9,7 @@ import tiktoken +from graphrag.prompts.query.question_gen_system_prompt import QUESTION_SYSTEM_PROMPT from graphrag.query.context_builder.builders import LocalContextBuilder from graphrag.query.context_builder.conversation_history import ( ConversationHistory, @@ -16,7 +17,6 @@ from graphrag.query.llm.base import BaseLLM, BaseLLMCallback from graphrag.query.llm.text_utils import num_tokens from graphrag.query.question_gen.base import BaseQuestionGen, QuestionResult -from graphrag.query.question_gen.system_prompt import QUESTION_SYSTEM_PROMPT log = logging.getLogger(__name__) diff --git a/graphrag/query/structured_search/drift_search/drift_context.py b/graphrag/query/structured_search/drift_search/drift_context.py index 64422e3760..d2a271bdac 100644 --- 
a/graphrag/query/structured_search/drift_search/drift_context.py +++ b/graphrag/query/structured_search/drift_search/drift_context.py @@ -19,14 +19,14 @@ Relationship, TextUnit, ) +from graphrag.prompts.query.drift_search_system_prompt import ( + DRIFT_LOCAL_SYSTEM_PROMPT, +) from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.oai.chat_openai import ChatOpenAI from graphrag.query.structured_search.base import DRIFTContextBuilder from graphrag.query.structured_search.drift_search.primer import PrimerQueryProcessor -from graphrag.query.structured_search.drift_search.system_prompt import ( - DRIFT_LOCAL_SYSTEM_PROMPT, -) from graphrag.query.structured_search.local_search.mixed_context import ( LocalSearchMixedContext, ) @@ -51,7 +51,7 @@ def __init__( token_encoder: tiktoken.Encoding | None = None, embedding_vectorstore_key: str = EntityVectorStoreKey.ID, config: DRIFTSearchConfig | None = None, - local_system_prompt: str = DRIFT_LOCAL_SYSTEM_PROMPT, + local_system_prompt: str | None = None, local_mixed_context: LocalSearchMixedContext | None = None, ): """Initialize the DRIFT search context builder with necessary components.""" @@ -59,7 +59,7 @@ def __init__( self.chat_llm = chat_llm self.text_embedder = text_embedder self.token_encoder = token_encoder - self.local_system_prompt = local_system_prompt + self.local_system_prompt = local_system_prompt or DRIFT_LOCAL_SYSTEM_PROMPT self.entities = entities self.entity_text_embeddings = entity_text_embeddings diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py index 47c8b1e2fe..b3d2b26891 100644 --- a/graphrag/query/structured_search/drift_search/primer.py +++ b/graphrag/query/structured_search/drift_search/primer.py @@ -15,13 +15,13 @@ from graphrag.config.models.drift_config import DRIFTSearchConfig from graphrag.model import CommunityReport +from graphrag.prompts.query.drift_search_system_prompt import ( + DRIFT_PRIMER_PROMPT, +) from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.oai.chat_openai import ChatOpenAI from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import SearchResult -from graphrag.query.structured_search.drift_search.system_prompt import ( - DRIFT_PRIMER_PROMPT, -) log = logging.getLogger(__name__) diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index 530ea39d6b..6202603dfd 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -16,6 +16,16 @@ from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback from graphrag.llm.openai.utils import try_parse_json_object +from graphrag.prompts.query.global_search_knowledge_system_prompt import ( + GENERAL_KNOWLEDGE_INSTRUCTION, +) +from graphrag.prompts.query.global_search_map_system_prompt import ( + MAP_SYSTEM_PROMPT, +) +from graphrag.prompts.query.global_search_reduce_system_prompt import ( + NO_DATA_ANSWER, + REDUCE_SYSTEM_PROMPT, +) from graphrag.query.context_builder.builders import GlobalContextBuilder from graphrag.query.context_builder.conversation_history import ( ConversationHistory, @@ -23,14 +33,6 @@ from graphrag.query.llm.base import BaseLLM from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import 
BaseSearch, SearchResult -from graphrag.query.structured_search.global_search.map_system_prompt import ( - MAP_SYSTEM_PROMPT, -) -from graphrag.query.structured_search.global_search.reduce_system_prompt import ( - GENERAL_KNOWLEDGE_INSTRUCTION, - NO_DATA_ANSWER, - REDUCE_SYSTEM_PROMPT, -) DEFAULT_MAP_LLM_PARAMS = { "max_tokens": 1000, @@ -62,11 +64,11 @@ def __init__( llm: BaseLLM, context_builder: GlobalContextBuilder, token_encoder: tiktoken.Encoding | None = None, - map_system_prompt: str = MAP_SYSTEM_PROMPT, - reduce_system_prompt: str = REDUCE_SYSTEM_PROMPT, + map_system_prompt: str | None = None, + reduce_system_prompt: str | None = None, response_type: str = "multiple paragraphs", allow_general_knowledge: bool = False, - general_knowledge_inclusion_prompt: str = GENERAL_KNOWLEDGE_INSTRUCTION, + general_knowledge_inclusion_prompt: str | None = None, json_mode: bool = True, callbacks: list[GlobalSearchLLMCallback] | None = None, max_data_tokens: int = 8000, @@ -81,11 +83,13 @@ def __init__( token_encoder=token_encoder, context_builder_params=context_builder_params, ) - self.map_system_prompt = map_system_prompt - self.reduce_system_prompt = reduce_system_prompt + self.map_system_prompt = map_system_prompt or MAP_SYSTEM_PROMPT + self.reduce_system_prompt = reduce_system_prompt or REDUCE_SYSTEM_PROMPT self.response_type = response_type self.allow_general_knowledge = allow_general_knowledge - self.general_knowledge_inclusion_prompt = general_knowledge_inclusion_prompt + self.general_knowledge_inclusion_prompt = ( + general_knowledge_inclusion_prompt or GENERAL_KNOWLEDGE_INSTRUCTION + ) self.callbacks = callbacks self.max_data_tokens = max_data_tokens diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py index 1ce736027e..1e29bde759 100644 --- a/graphrag/query/structured_search/local_search/search.py +++ b/graphrag/query/structured_search/local_search/search.py @@ -10,6 +10,9 @@ import tiktoken +from graphrag.prompts.query.local_search_system_prompt import ( + LOCAL_SEARCH_SYSTEM_PROMPT, +) from graphrag.query.context_builder.builders import LocalContextBuilder from graphrag.query.context_builder.conversation_history import ( ConversationHistory, @@ -17,9 +20,6 @@ from graphrag.query.llm.base import BaseLLM, BaseLLMCallback from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import BaseSearch, SearchResult -from graphrag.query.structured_search.local_search.system_prompt import ( - LOCAL_SEARCH_SYSTEM_PROMPT, -) DEFAULT_LLM_PARAMS = { "max_tokens": 1500, @@ -37,7 +37,7 @@ def __init__( llm: BaseLLM, context_builder: LocalContextBuilder, token_encoder: tiktoken.Encoding | None = None, - system_prompt: str = LOCAL_SEARCH_SYSTEM_PROMPT, + system_prompt: str | None = None, response_type: str = "multiple paragraphs", callbacks: list[BaseLLMCallback] | None = None, llm_params: dict[str, Any] = DEFAULT_LLM_PARAMS, @@ -50,7 +50,7 @@ def __init__( llm_params=llm_params, context_builder_params=context_builder_params or {}, ) - self.system_prompt = system_prompt + self.system_prompt = system_prompt or LOCAL_SEARCH_SYSTEM_PROMPT self.callbacks = callbacks self.response_type = response_type diff --git a/tests/unit/config/test_default_config.py b/tests/unit/config/test_default_config.py index 7b5ca7db22..6e57ce3bb8 100644 --- a/tests/unit/config/test_default_config.py +++ b/tests/unit/config/test_default_config.py @@ -28,6 +28,7 @@ ClusterGraphConfigInput, CommunityReportsConfig, 
CommunityReportsConfigInput, + DRIFTSearchConfig, EmbedGraphConfig, EmbedGraphConfigInput, EntityExtractionConfig, @@ -202,6 +203,7 @@ def test_clear_warnings(self): assert ClaimExtractionConfig is not None assert ClusterGraphConfig is not None assert CommunityReportsConfig is not None + assert DRIFTSearchConfig is not None assert EmbedGraphConfig is not None assert EntityExtractionConfig is not None assert GlobalSearchConfig is not None
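Taken together, the override flow is: `graphrag init` writes the prompt files plus the settings keys shown in `init_content.py`, `create_graphrag_config` reads them into the new `prompt`/`map_prompt`/`reduce_prompt`/`knowledge_prompt` fields, and `_load_search_prompt` in `graphrag/api/query.py` resolves them against `root_dir`. A sketch in the spirit of the unit tests above; the `(values, root_dir)` signature of `create_graphrag_config` is assumed from the existing config API:

```python
from graphrag.config import create_graphrag_config

# Minimal settings fragment; section and key names follow init_content.py above.
settings = {
    "local_search": {"prompt": "prompts/local_search_system_prompt.txt"},
    "global_search": {
        "map_prompt": "prompts/global_search_map_system_prompt.txt",
        "reduce_prompt": "prompts/global_search_reduce_system_prompt.txt",
        "knowledge_prompt": "prompts/global_search_knowledge_system_prompt.txt",
    },
    "drift_search": {"prompt": "prompts/drift_search_system_prompt.txt"},
}

config = create_graphrag_config(settings, root_dir=".")
assert config.local_search.prompt == "prompts/local_search_system_prompt.txt"
assert config.drift_search.prompt == "prompts/drift_search_system_prompt.txt"
```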