From 238f1c2adcf262ab40d6f2d69a5eb2a9f6d04047 Mon Sep 17 00:00:00 2001
From: Josh Bradley
Date: Mon, 12 Aug 2024 17:09:00 -0400
Subject: [PATCH] Implement prompt tuning API (#855)

* initial setup commit
* cleanup API and CLI interfaces
* move datatype definition to types.py
* code cleanup
* add semversioner file
* remove unused import

---------

Co-authored-by: Alonso Guevara
---
 .../minor-20240807023736041951.json           |   4 +
 graphrag/prompt_tune/__main__.py              |  82 +++---
 graphrag/prompt_tune/api.py                   | 173 ++++++++++++
 graphrag/prompt_tune/cli.py                   | 260 +++---------------
 .../generator/entity_extraction_prompt.py     |  10 +-
 .../generator/entity_summarization_prompt.py  |   5 +-
 graphrag/prompt_tune/loader/config.py         |  24 +-
 graphrag/prompt_tune/loader/input.py          |   9 +-
 graphrag/prompt_tune/types.py                 |  19 ++
 9 files changed, 308 insertions(+), 278 deletions(-)
 create mode 100644 .semversioner/next-release/minor-20240807023736041951.json
 create mode 100644 graphrag/prompt_tune/api.py
 create mode 100644 graphrag/prompt_tune/types.py

diff --git a/.semversioner/next-release/minor-20240807023736041951.json b/.semversioner/next-release/minor-20240807023736041951.json
new file mode 100644
index 0000000000..dc4429795d
--- /dev/null
+++ b/.semversioner/next-release/minor-20240807023736041951.json
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Implement auto templating API."
+}
diff --git a/graphrag/prompt_tune/__main__.py b/graphrag/prompt_tune/__main__.py
index e752b05a8f..cbf8dd66c4 100644
--- a/graphrag/prompt_tune/__main__.py
+++ b/graphrag/prompt_tune/__main__.py
@@ -1,37 +1,32 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""The Prompt auto templating package root."""
+"""The auto templating package root."""
 
 import argparse
 import asyncio
 
-from enum import Enum
-
-from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
-from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE
-
+from .api import DocSelectionType
 from .cli import prompt_tune
-
-
-class DocSelectionType(Enum):
-    """The type of document selection to use."""
-
-    ALL = "all"
-    RANDOM = "random"
-    TOP = "top"
-    AUTO = "auto"
-
-    def __str__(self):
-        """Return the string representation of the enum value."""
-        return self.value
-
+from .generator import MAX_TOKEN_COUNT
+from .loader import MIN_CHUNK_SIZE
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(
+        prog="python -m graphrag.prompt_tune",
+        description="The graphrag auto templating module.",
+    )
+
+    parser.add_argument(
+        "--config",
+        help="Configuration yaml file to use when generating prompts",
+        required=True,
+        type=str,
+    )
 
     parser.add_argument(
         "--root",
-        help="The data project root. Including the config yml, json or .env",
+        help="Data project root. Default: current directory",
         required=False,
         type=str,
         default=".",
     )
@@ -39,15 +34,15 @@ def __str__(self):
 
     parser.add_argument(
         "--domain",
-        help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If left empty, the domain will be inferred from the input data.",
+        help="Domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If not defined, the domain will be inferred from the input data.",
         required=False,
        default="",
        type=str,
    )

     parser.add_argument(
-        "--method",
-        help="The method to select documents, one of: all, random, top or auto",
+        "--selection-method",
+        help=f"Chunk selection method. Default: {DocSelectionType.RANDOM}",
         required=False,
         type=DocSelectionType,
         choices=list(DocSelectionType),
@@ -56,7 +51,7 @@ def __str__(self):
 
     parser.add_argument(
         "--n_subset_max",
-        help="The number of text chunks to embed when using auto selection method",
+        help="Number of text chunks to embed when using auto selection method. Default: 300",
         required=False,
         type=int,
         default=300,
@@ -64,7 +59,7 @@ def __str__(self):
 
     parser.add_argument(
         "--k",
-        help="The maximum number of documents to select from each centroid when using auto selection method",
+        help="Maximum number of documents to select from each centroid when using auto selection method. Default: 15",
         required=False,
         type=int,
         default=15,
@@ -72,7 +67,7 @@ def __str__(self):
 
     parser.add_argument(
         "--limit",
-        help="The limit of files to load when doing random or top selection",
+        help="Number of documents to load when doing random or top selection. Default: 15",
         type=int,
         required=False,
         default=15,
@@ -80,7 +75,7 @@ def __str__(self):
 
     parser.add_argument(
         "--max-tokens",
-        help="Max token count for prompt generation",
+        help=f"Max token count for prompt generation. Default: {MAX_TOKEN_COUNT}",
         type=int,
         required=False,
         default=MAX_TOKEN_COUNT,
@@ -88,7 +83,7 @@ def __str__(self):
 
     parser.add_argument(
         "--min-examples-required",
-        help="The minimum number of examples required in entity extraction prompt",
+        help="Minimum number of examples required in the entity extraction prompt. Default: 2",
         type=int,
         required=False,
         default=2,
@@ -96,7 +91,7 @@ def __str__(self):
 
     parser.add_argument(
         "--chunk-size",
-        help="Max token count for prompt generation",
+        help=f"Max token count for prompt generation. Default: {MIN_CHUNK_SIZE}",
         type=int,
         required=False,
         default=MIN_CHUNK_SIZE,
@@ -120,7 +115,7 @@ def __str__(self):
 
     parser.add_argument(
         "--output",
-        help="Folder to save the generated prompts to",
+        help="Directory to save generated prompts to. Default: 'prompts'",
         type=str,
         required=False,
         default="prompts",
@@ -132,17 +127,18 @@ def __str__(self):
 
     loop.run_until_complete(
         prompt_tune(
-            args.root,
-            args.domain,
-            str(args.method),
-            args.limit,
-            args.max_tokens,
-            args.chunk_size,
-            args.language,
-            args.no_entity_types,
-            args.output,
-            args.n_subset_max,
-            args.k,
-            args.min_examples_required,
+            config=args.config,
+            root=args.root,
+            domain=args.domain,
+            selection_method=args.selection_method,
+            limit=args.limit,
+            max_tokens=args.max_tokens,
+            chunk_size=args.chunk_size,
+            language=args.language,
+            skip_entity_types=args.no_entity_types,
+            output=args.output,
+            n_subset_max=args.n_subset_max,
+            k=args.k,
+            min_examples_required=args.min_examples_required,
         )
     )
diff --git a/graphrag/prompt_tune/api.py b/graphrag/prompt_tune/api.py
new file mode 100644
index 0000000000..4bbcb5d7dc
--- /dev/null
+++ b/graphrag/prompt_tune/api.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""
+Auto Templating API.
+
+This API provides access to the auto templating feature of graphrag, allowing external applications
+to hook into graphrag and generate prompts from private data.
+
+WARNING: This API is under development and may undergo changes in future releases.
+Backwards compatibility is not guaranteed at this time.
+""" + +from datashaper import NoopVerbCallbacks +from pydantic import PositiveInt, validate_call + +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.llm import load_llm +from graphrag.index.progress import PrintProgressReporter + +from .cli import DocSelectionType +from .generator import ( + MAX_TOKEN_COUNT, + create_community_summarization_prompt, + create_entity_extraction_prompt, + create_entity_summarization_prompt, + detect_language, + generate_community_report_rating, + generate_community_reporter_role, + generate_domain, + generate_entity_relationship_examples, + generate_entity_types, + generate_persona, +) +from .loader import ( + MIN_CHUNK_SIZE, + load_docs_in_chunks, +) + + +@validate_call +async def generate_indexing_prompts( + config: GraphRagConfig, + root: str, + chunk_size: PositiveInt = MIN_CHUNK_SIZE, + limit: PositiveInt = 15, + selection_method: DocSelectionType = DocSelectionType.RANDOM, + domain: str | None = None, + language: str | None = None, + max_tokens: int = MAX_TOKEN_COUNT, + skip_entity_types: bool = False, + min_examples_required: PositiveInt = 2, + n_subset_max: PositiveInt = 300, + k: PositiveInt = 15, +) -> tuple[str, str, str]: + """Generate indexing prompts. + + Parameters + ---------- + - config: The GraphRag configuration. + - output_path: The path to store the prompts. + - chunk_size: The chunk token size to use for input text units. + - limit: The limit of chunks to load. + - selection_method: The chunk selection method. + - domain: The domain to map the input documents to. + - language: The language to use for the prompts. + - max_tokens: The maximum number of tokens to use on entity extraction prompts + - skip_entity_types: Skip generating entity types. + - min_examples_required: The minimum number of examples required for entity extraction prompts. + - n_subset_max: The number of text chunks to embed when using auto selection method. + - k: The number of documents to select when using auto selection method. 
+
+    Returns
+    -------
+    tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt
+    """
+    reporter = PrintProgressReporter("")
+
+    # Retrieve documents
+    doc_list = await load_docs_in_chunks(
+        root=root,
+        config=config,
+        limit=limit,
+        select_method=selection_method,
+        reporter=reporter,
+        chunk_size=chunk_size,
+        n_subset_max=n_subset_max,
+        k=k,
+    )
+
+    # Create LLM from config
+    llm = load_llm(
+        "prompt_tuning",
+        config.llm.type,
+        NoopVerbCallbacks(),
+        None,
+        config.llm.model_dump(),
+    )
+
+    if not domain:
+        reporter.info("Generating domain...")
+        domain = await generate_domain(llm, doc_list)
+        reporter.info(f"Generated domain: {domain}")
+
+    if not language:
+        reporter.info("Detecting language...")
+        language = await detect_language(llm, doc_list)
+
+    reporter.info("Generating persona...")
+    persona = await generate_persona(llm, domain)
+
+    reporter.info("Generating community report ranking description...")
+    community_report_ranking = await generate_community_report_rating(
+        llm, domain=domain, persona=persona, docs=doc_list
+    )
+
+    entity_types = None
+    if not skip_entity_types:
+        reporter.info("Generating entity types...")
+        entity_types = await generate_entity_types(
+            llm,
+            domain=domain,
+            persona=persona,
+            docs=doc_list,
+            json_mode=config.llm.model_supports_json or False,
+        )
+
+    reporter.info("Generating entity relationship examples...")
+    examples = await generate_entity_relationship_examples(
+        llm,
+        persona=persona,
+        entity_types=entity_types,
+        docs=doc_list,
+        language=language,
+        json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
+    )
+
+    reporter.info("Generating entity extraction prompt...")
+    entity_extraction_prompt = create_entity_extraction_prompt(
+        entity_types=entity_types,
+        docs=doc_list,
+        examples=examples,
+        language=language,
+        json_mode=False,  # config.llm.model_supports_json should be used, but these prompts are used in non-JSON mode by the index engine
+        encoding_model=config.encoding_model,
+        max_token_count=max_tokens,
+        min_examples_required=min_examples_required,
+    )
+
+    reporter.info("Generating entity summarization prompt...")
+    entity_summarization_prompt = create_entity_summarization_prompt(
+        persona=persona,
+        language=language,
+    )
+
+    reporter.info("Generating community reporter role...")
+    community_reporter_role = await generate_community_reporter_role(
+        llm, domain=domain, persona=persona, docs=doc_list
+    )
+
+    reporter.info("Generating community summarization prompt...")
+    community_summarization_prompt = create_community_summarization_prompt(
+        persona=persona,
+        role=community_reporter_role,
+        report_rating_description=community_report_ranking,
+        language=language,
+    )
+
+    return (
+        entity_extraction_prompt,
+        entity_summarization_prompt,
+        community_summarization_prompt,
+    )
diff --git a/graphrag/prompt_tune/cli.py b/graphrag/prompt_tune/cli.py
index 5979a4a6ee..eb8ff6f49f 100644
--- a/graphrag/prompt_tune/cli.py
+++ b/graphrag/prompt_tune/cli.py
@@ -5,37 +5,25 @@
 
 from pathlib import Path
 
-from datashaper import NoopVerbCallbacks
-
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.llm import load_llm
 from graphrag.index.progress import PrintProgressReporter
-from graphrag.index.progress.types import ProgressReporter
-from graphrag.llm.types.llm_types import CompletionLLM
-from graphrag.prompt_tune.generator import (
-    MAX_TOKEN_COUNT,
-    create_community_summarization_prompt,
-    create_entity_extraction_prompt,
-    create_entity_summarization_prompt,
-    detect_language,
-    generate_community_report_rating,
-    generate_community_reporter_role,
-    generate_domain,
-    generate_entity_relationship_examples,
-    generate_entity_types,
-    generate_persona,
-)
+from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.loader import (
     MIN_CHUNK_SIZE,
-    load_docs_in_chunks,
     read_config_parameters,
 )
+from . import api
+from .generator.community_report_summarization import COMMUNITY_SUMMARIZATION_FILENAME
+from .generator.entity_extraction_prompt import ENTITY_EXTRACTION_FILENAME
+from .generator.entity_summarization_prompt import ENTITY_SUMMARIZATION_FILENAME
+from .types import DocSelectionType
+
 
 async def prompt_tune(
+    config: str,
     root: str,
     domain: str,
-    select: str = "random",
+    selection_method: DocSelectionType = DocSelectionType.RANDOM,
     limit: int = 15,
     max_tokens: int = MAX_TOKEN_COUNT,
     chunk_size: int = MIN_CHUNK_SIZE,
@@ -50,223 +38,51 @@
 
     Parameters
     ----------
+    - config: The configuration file.
     - root: The root directory.
     - domain: The domain to map the input documents to.
-    - select: The chunk selection method.
+    - selection_method: The chunk selection method.
     - limit: The limit of chunks to load.
     - max_tokens: The maximum number of tokens to use on entity extraction prompts.
     - chunk_size: The chunk token size to use.
+    - language: The language to use for the prompts.
     - skip_entity_types: Skip generating entity types.
     - output: The output folder to store the prompts.
     - n_subset_max: The number of text chunks to embed when using auto selection method.
     - k: The number of documents to select when using auto selection method.
+    - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
     reporter = PrintProgressReporter("")
-    config = read_config_parameters(root, reporter)
-
-    await prompt_tune_with_config(
-        root,
-        config,
-        domain,
-        select,
-        limit,
-        max_tokens,
-        chunk_size,
-        language,
-        skip_entity_types,
-        output,
-        reporter,
-        n_subset_max,
-        k,
-        min_examples_required,
-    )
-
-
-async def prompt_tune_with_config(
-    root: str,
-    config: GraphRagConfig,
-    domain: str,
-    select: str = "random",
-    limit: int = 15,
-    max_tokens: int = MAX_TOKEN_COUNT,
-    chunk_size: int = MIN_CHUNK_SIZE,
-    language: str | None = None,
-    skip_entity_types: bool = False,
-    output: str = "prompts",
-    reporter: ProgressReporter | None = None,
-    n_subset_max: int = 300,
-    k: int = 15,
-    min_examples_required: int = 2,
-):
-    """Prompt tune the model with a configuration.
+    graph_config = read_config_parameters(root, reporter, config)
 
-    Parameters
-    ----------
-    - root: The root directory.
-    - config: The GraphRag configuration.
-    - domain: The domain to map the input documents to.
-    - select: The chunk selection method.
-    - limit: The limit of chunks to load.
-    - max_tokens: The maximum number of tokens to use on entity extraction prompts.
-    - chunk_size: The chunk token size to use for input text units.
-    - skip_entity_types: Skip generating entity types.
-    - output: The output folder to store the prompts.
-    - reporter: The progress reporter.
-    - n_subset_max: The number of text chunks to embed when using auto selection method.
-    - k: The number of documents to select when using auto selection method.
-
-    Returns
-    -------
-    - None
-    """
-    if not reporter:
-        reporter = PrintProgressReporter("")
-
-    output_path = Path(config.root_dir) / output
-
-    doc_list = await load_docs_in_chunks(
+    prompts = await api.generate_indexing_prompts(
+        config=graph_config,
         root=root,
-        config=config,
-        limit=limit,
-        select_method=select,
-        reporter=reporter,
         chunk_size=chunk_size,
+        limit=limit,
+        selection_method=selection_method,
+        domain=domain,
+        language=language,
+        max_tokens=max_tokens,
+        skip_entity_types=skip_entity_types,
+        min_examples_required=min_examples_required,
         n_subset_max=n_subset_max,
         k=k,
     )
 
-    # Create LLM from config
-    llm = load_llm(
-        "prompt_tuning",
-        config.llm.type,
-        NoopVerbCallbacks(),
-        None,
-        config.llm.model_dump(),
-    )
-
-    await generate_indexing_prompts(
-        llm,
-        config,
-        doc_list,
-        output_path,
-        reporter,
-        domain,
-        language,
-        max_tokens,
-        skip_entity_types,
-        min_examples_required,
-    )
-
-
-async def generate_indexing_prompts(
-    llm: CompletionLLM,
-    config: GraphRagConfig,
-    doc_list: list[str],
-    output_path: Path,
-    reporter: ProgressReporter,
-    domain: str | None = None,
-    language: str | None = None,
-    max_tokens: int = MAX_TOKEN_COUNT,
-    skip_entity_types: bool = False,
-    min_examples_required: int = 2,
-):
-    """Generate indexing prompts.
-
-    Parameters
-    ----------
-    - llm: The LLM model to use.
-    - config: The GraphRag configuration.
-    - doc_list: The list of documents to use.
-    - output_path: The path to store the prompts.
-    - reporter: The progress reporter.
-    - domain: The domain to map the input documents to.
-    - max_tokens: The maximum number of tokens to use on entity extraction prompts
-    - skip_entity_types: Skip generating entity types.
-    - min_examples_required: The minimum number of examples required for entity extraction prompts.
-    """
-    if not domain:
-        reporter.info("Generating domain...")
-        domain = await generate_domain(llm, doc_list)
-        reporter.info(f"Generated domain: {domain}")
-
-    if not language:
-        reporter.info("Detecting language...")
-        language = await detect_language(llm, doc_list)
-        reporter.info(f"Detected language: {language}")
-
-    reporter.info("Generating persona...")
-    persona = await generate_persona(llm, domain)
-    reporter.info(f"Generated persona: {persona}")
-
-    reporter.info("Generating community report ranking description...")
-    community_report_ranking = await generate_community_report_rating(
-        llm, domain=domain, persona=persona, docs=doc_list
-    )
-    reporter.info(
-        f"Generated community report ranking description: {community_report_ranking}"
-    )
-
-    entity_types = None
-    if not skip_entity_types:
-        reporter.info("Generating entity types")
-        entity_types = await generate_entity_types(
-            llm,
-            domain=domain,
-            persona=persona,
-            docs=doc_list,
-            json_mode=config.llm.model_supports_json or False,
+    output_path = Path(output)
+    if output_path:
+        reporter.info(f"Writing prompts to {output_path}")
+        output_path.mkdir(parents=True, exist_ok=True)
+        entity_extraction_prompt_path = output_path / ENTITY_EXTRACTION_FILENAME
+        entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME
+        community_summarization_prompt_path = (
+            output_path / COMMUNITY_SUMMARIZATION_FILENAME
         )
-        reporter.info(f"Generated entity types: {entity_types}")
-
-    reporter.info("Generating entity relationship examples...")
-    examples = await generate_entity_relationship_examples(
-        llm,
-        persona=persona,
-        entity_types=entity_types,
-        docs=doc_list,
-        language=language,
-        json_mode=False,  # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
-    )
-    reporter.info("Done generating entity relationship examples")
-
-    reporter.info("Generating entity extraction prompt...")
-    create_entity_extraction_prompt(
-        entity_types=entity_types,
-        docs=doc_list,
-        examples=examples,
-        language=language,
-        json_mode=False,  # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine
-        output_path=output_path,
-        encoding_model=config.encoding_model,
-        max_token_count=max_tokens,
-        min_examples_required=min_examples_required,
-    )
-    reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
-
-    reporter.info("Generating entity summarization prompt...")
-    create_entity_summarization_prompt(
-        persona=persona,
-        language=language,
-        output_path=output_path,
-    )
-    reporter.info(
-        f"Generated entity summarization prompt, stored in folder {output_path}"
-    )
-
-    reporter.info("Generating community reporter role...")
-    community_reporter_role = await generate_community_reporter_role(
-        llm, domain=domain, persona=persona, docs=doc_list
-    )
-    reporter.info(f"Generated community reporter role: {community_reporter_role}")
-
-    reporter.info("Generating community summarization prompt...")
-    create_community_summarization_prompt(
-        persona=persona,
-        role=community_reporter_role,
-        report_rating_description=community_report_ranking,
-        language=language,
-        output_path=output_path,
-    )
-    reporter.info(
-        f"Generated community summarization prompt, stored in folder {output_path}"
-    )
+        # Write files to output path
+        with entity_extraction_prompt_path.open("wb") as file:
+            file.write(prompts[0].encode(encoding="utf-8", errors="strict"))
+        with entity_summarization_prompt_path.open("wb") as file:
+            file.write(prompts[1].encode(encoding="utf-8", errors="strict"))
+        with community_summarization_prompt_path.open("wb") as file:
+            file.write(prompts[2].encode(encoding="utf-8", errors="strict"))
diff --git a/graphrag/prompt_tune/generator/entity_extraction_prompt.py b/graphrag/prompt_tune/generator/entity_extraction_prompt.py
index 3b17dbab5d..b2192c0705 100644
--- a/graphrag/prompt_tune/generator/entity_extraction_prompt.py
+++ b/graphrag/prompt_tune/generator/entity_extraction_prompt.py
@@ -41,7 +41,7 @@ def create_entity_extraction_prompt(
     - encoding_model (str): The name of the model to use for token counting
     - max_token_count (int): The maximum number of tokens to use for the prompt
     - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
-    - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
+    - output_path (Path | None): The path to write the prompt to. Default is None.
     - min_examples_required (int): The minimum number of examples required. Default is 2.
 
     Returns
@@ -58,8 +58,8 @@ def create_entity_extraction_prompt(
 
     tokens_left = (
         max_token_count
-        - num_tokens_from_string(prompt, model=encoding_model)
-        - num_tokens_from_string(entity_types, model=encoding_model)
+        - num_tokens_from_string(prompt, encoding_name=encoding_model)
+        - num_tokens_from_string(entity_types, encoding_name=encoding_model)
         if entity_types
         else 0
     )
@@ -79,7 +79,9 @@ def create_entity_extraction_prompt(
             )
         )
 
-        example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)
+        example_tokens = num_tokens_from_string(
+            example_formatted, encoding_name=encoding_model
+        )
 
         # Ensure at least three examples are included
         if i >= min_examples_required and example_tokens > tokens_left:
diff --git a/graphrag/prompt_tune/generator/entity_summarization_prompt.py b/graphrag/prompt_tune/generator/entity_summarization_prompt.py
index 4ae5af77ec..736df830d6 100644
--- a/graphrag/prompt_tune/generator/entity_summarization_prompt.py
+++ b/graphrag/prompt_tune/generator/entity_summarization_prompt.py
@@ -15,13 +15,14 @@ def create_entity_summarization_prompt(
     language: str,
     output_path: Path | None = None,
 ) -> str:
-    """Create a prompt for entity summarization. If output_path is provided, write the prompt to a file.
+    """
+    Create a prompt for entity summarization.
 
     Parameters
     ----------
     - persona (str): The persona to use for the entity summarization prompt
     - language (str): The language to use for the entity summarization prompt
-    - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
+    - output_path (Path | None): The path to write the prompt to. Default is None.
 
     """
     prompt = ENTITY_SUMMARIZATION_PROMPT.format(persona=persona, language=language)
diff --git a/graphrag/prompt_tune/loader/config.py b/graphrag/prompt_tune/loader/config.py
index 8994604f92..350feacd79 100644
--- a/graphrag/prompt_tune/loader/config.py
+++ b/graphrag/prompt_tune/loader/config.py
@@ -9,20 +9,38 @@
 from graphrag.index.progress.types import ProgressReporter
 
 
-def read_config_parameters(root: str, reporter: ProgressReporter):
+def read_config_parameters(
+    root: str, reporter: ProgressReporter, config: str | None = None
+):
     """Read the configuration parameters from the settings file or environment variables.
 
     Parameters
     ----------
     - root: The root directory where the parameters are.
     - reporter: The progress reporter.
+ - config: The path to the settings file. """ _root = Path(root) - settings_yaml = _root / "settings.yaml" + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) if not settings_yaml.exists(): settings_yaml = _root / "settings.yml" - settings_json = _root / "settings.json" + if settings_yaml.exists(): + reporter.info(f"Reading settings from {settings_yaml}") + with settings_yaml.open("rb") as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / "settings.json" + ) if settings_yaml.exists(): reporter.info(f"Reading settings from {settings_yaml}") with settings_yaml.open("rb") as file: diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index 86c4a76040..0679990541 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -16,6 +16,7 @@ from graphrag.index.progress.types import ProgressReporter from graphrag.index.verbs import chunk from graphrag.llm.types.llm_types import EmbeddingLLM +from graphrag.prompt_tune.types import DocSelectionType MIN_CHUNK_OVERLAP = 0 MIN_CHUNK_SIZE = 200 @@ -50,7 +51,7 @@ def _sample_chunks_from_embeddings( async def load_docs_in_chunks( root: str, config: GraphRagConfig, - select_method: str, + select_method: DocSelectionType, limit: int, reporter: ProgressReporter, chunk_size: int = MIN_CHUNK_SIZE, @@ -85,11 +86,11 @@ async def load_docs_in_chunks( if limit <= 0 or limit > len(chunks_df): limit = len(chunks_df) - if select_method == "top": + if select_method == DocSelectionType.TOP: chunks_df = chunks_df[:limit] - elif select_method == "random": + elif select_method == DocSelectionType.RANDOM: chunks_df = chunks_df.sample(n=limit) - elif select_method == "auto": + elif select_method == DocSelectionType.AUTO: if k is None or k <= 0: msg = "k must be an integer > 0" raise ValueError(msg) diff --git a/graphrag/prompt_tune/types.py b/graphrag/prompt_tune/types.py new file mode 100644 index 0000000000..1207d18767 --- /dev/null +++ b/graphrag/prompt_tune/types.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Types for prompt tuning.""" + +from enum import Enum + + +class DocSelectionType(Enum): + """The type of document selection to use.""" + + ALL = "all" + RANDOM = "random" + TOP = "top" + AUTO = "auto" + + def __str__(self): + """Return the string representation of the enum value.""" + return self.value
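
Below is a minimal usage sketch of the API surface this patch introduces, showing how an external application might generate the three prompts programmatically instead of going through the CLI. It is illustrative only: the "./ragtest" project root is a placeholder for a directory that already contains a settings.yaml, and the truncated printout at the end is not part of the patch. All imported names come from the modules added or modified above.

import asyncio

from graphrag.index.progress import PrintProgressReporter
from graphrag.prompt_tune.api import generate_indexing_prompts
from graphrag.prompt_tune.loader import read_config_parameters
from graphrag.prompt_tune.types import DocSelectionType


async def main() -> None:
    reporter = PrintProgressReporter("")
    # Load a GraphRagConfig from the project root; "./ragtest" is a
    # hypothetical directory containing a settings.yaml.
    config = read_config_parameters("./ragtest", reporter)

    # Returns a tuple in the order documented by the API:
    # (entity extraction, entity summarization, community summarization).
    # domain and language are omitted here, so both are inferred from the
    # input data.
    prompts = await generate_indexing_prompts(
        config=config,
        root="./ragtest",
        selection_method=DocSelectionType.RANDOM,
        limit=15,
    )

    for name, prompt in zip(
        ("entity_extraction", "entity_summarization", "community_summarization"),
        prompts,
    ):
        print(f"--- {name} ---\n{prompt[:200]}...")


if __name__ == "__main__":
    asyncio.run(main())

Note that, unlike the CLI path in cli.py, the API returns the prompts as strings and leaves persistence to the caller, which is what makes it usable from applications that store prompts somewhere other than the local filesystem.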