From 6b4de3d841f7400d91e53a1ffc0214b291d2e1b7 Mon Sep 17 00:00:00 2001 From: Derek Worthen Date: Tue, 20 Aug 2024 14:42:20 -0700 Subject: [PATCH] Index API (#953) * Initial Index API - Implement main API entry point: build_index - Rely on GraphRagConfig instead of PipelineConfig - This unifies the API signature with the prompt_tune and query API entry points - Derive cache settings, config, and resuming from the config and other arguments to simplify/reduce arguments to build_index - Add preflight config file validations - Add semver change * fix smoke tests * fix smoke tests * Use asyncio * Add e2e artifacts in GH actions * Remove unnecessary E2E test, and add skip_validations flag to cli * Nicer imports * Reorganize API functions. * Add license headers and module docstrings * Fix ignored ruff rule --------- Co-authored-by: Alonso Guevara --- .github/workflows/python-smoke-tests.yml | 5 - .../minor-20240819154736579383.json | 4 + graphrag/config/config_file_loader.py | 184 +++++++++++ graphrag/config/logging.py | 65 ++++ graphrag/config/resolve_timestamp_path.py | 115 +++++++ graphrag/index/__main__.py | 7 +- graphrag/index/api.py | 79 +++++ graphrag/index/cli.py | 302 +++++++----------- .../index/progress/load_progress_reporter.py | 30 ++ scripts/e2e-test.sh | 4 - 10 files changed, 593 insertions(+), 202 deletions(-) create mode 100644 .semversioner/next-release/minor-20240819154736579383.json create mode 100644 graphrag/config/config_file_loader.py create mode 100644 graphrag/config/logging.py create mode 100644 graphrag/config/resolve_timestamp_path.py create mode 100644 graphrag/index/api.py create mode 100644 graphrag/index/progress/load_progress_reporter.py delete mode 100755 scripts/e2e-test.sh diff --git a/.github/workflows/python-smoke-tests.yml b/.github/workflows/python-smoke-tests.yml index ac643e1c7c..47b975dbe9 100644 --- a/.github/workflows/python-smoke-tests.yml +++ b/.github/workflows/python-smoke-tests.yml @@ -102,8 +102,3 @@ jobs: with: name: 
smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }} path: tests/fixtures/*/output - - - name: E2E Test - if: steps.changes.outputs.python == 'true' - run: | - ./scripts/e2e-test.sh diff --git a/.semversioner/next-release/minor-20240819154736579383.json b/.semversioner/next-release/minor-20240819154736579383.json new file mode 100644 index 0000000000..e41cf13bc6 --- /dev/null +++ b/.semversioner/next-release/minor-20240819154736579383.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Implement Index API" +} diff --git a/graphrag/config/config_file_loader.py b/graphrag/config/config_file_loader.py new file mode 100644 index 0000000000..3f045cdc41 --- /dev/null +++ b/graphrag/config/config_file_loader.py @@ -0,0 +1,184 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Load a GraphRagConfiguration from a file.""" + +import json +from abc import ABC, abstractmethod +from pathlib import Path + +import yaml + +from . import create_graphrag_config +from .models.graph_rag_config import GraphRagConfig + +_default_config_files = ["settings.yaml", "settings.yml", "settings.json"] + + +def resolve_config_path_with_root(root: str | Path) -> Path: + """Resolve the config path from the given root directory. + + Parameters + ---------- + root : str | Path + The path to the root directory containing the config file. + Searches for a default config file (settings.{yaml,yml,json}). + + Returns + ------- + Path + The resolved config file path. + + Raises + ------ + FileNotFoundError + If the config file is not found or cannot be resolved for the directory. 
+ """ + root = Path(root) + + if not root.is_dir(): + msg = f"Invalid config path: {root} is not a directory" + raise FileNotFoundError(msg) + + for file in _default_config_files: + if (root / file).is_file(): + return root / file + + msg = f"Unable to resolve config file for parent directory: {root}" + raise FileNotFoundError(msg) + + +class ConfigFileLoader(ABC): + """Base class for loading a configuration from a file.""" + + @abstractmethod + def load_config(self, config_path: str | Path) -> GraphRagConfig: + """Load configuration from a file.""" + raise NotImplementedError + + +class ConfigYamlLoader(ConfigFileLoader): + """Load a configuration from a yaml file.""" + + def load_config(self, config_path: str | Path) -> GraphRagConfig: + """Load a configuration from a yaml file. + + Parameters + ---------- + config_path : str | Path + The path to the yaml file to load. + + Returns + ------- + GraphRagConfig + The loaded configuration. + + Raises + ------ + ValueError + If the file extension is not .yaml or .yml. + FileNotFoundError + If the config file is not found. + """ + config_path = Path(config_path) + if config_path.suffix not in [".yaml", ".yml"]: + msg = f"Invalid file extension for loading yaml config from: {config_path!s}. Expected .yaml or .yml" + raise ValueError(msg) + root_dir = str(config_path.parent) + if not config_path.is_file(): + msg = f"Config file not found: {config_path}" + raise FileNotFoundError(msg) + with config_path.open("rb") as file: + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root_dir) + + +class ConfigJsonLoader(ConfigFileLoader): + """Load a configuration from a json file.""" + + def load_config(self, config_path: str | Path) -> GraphRagConfig: + """Load a configuration from a json file. + + Parameters + ---------- + config_path : str | Path + The path to the json file to load. + + Returns + ------- + GraphRagConfig + The loaded configuration. 
+ + Raises + ------ + ValueError + If the file extension is not .json. + FileNotFoundError + If the config file is not found. + """ + config_path = Path(config_path) + root_dir = str(config_path.parent) + if config_path.suffix != ".json": + msg = f"Invalid file extension for loading json config from: {config_path!s}. Expected .json" + raise ValueError(msg) + if not config_path.is_file(): + msg = f"Config file not found: {config_path}" + raise FileNotFoundError(msg) + with config_path.open("rb") as file: + data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root_dir) + + +def get_config_file_loader(config_path: str | Path) -> ConfigFileLoader: + """Config File Loader Factory. + + Parameters + ---------- + config_path : str | Path + The path to the config file. + + Returns + ------- + ConfigFileLoader + The config file loader for the provided config file. + + Raises + ------ + ValueError + If the config file extension is not supported. + """ + config_path = Path(config_path) + ext = config_path.suffix + match ext: + case ".yaml" | ".yml": + return ConfigYamlLoader() + case ".json": + return ConfigJsonLoader() + case _: + msg = f"Unsupported config file extension: {ext}" + raise ValueError(msg) + + +def load_config_from_file(config_path: str | Path) -> GraphRagConfig: + """Load a configuration from a file. + + Parameters + ---------- + config_path : str | Path + The path to the configuration file. + Supports .yaml, .yml, and .json config files. + + Returns + ------- + GraphRagConfig + The loaded configuration. + + Raises + ------ + ValueError + If the file extension is not supported. + FileNotFoundError + If the config file is not found. 
+ """ + loader = get_config_file_loader(config_path) + return loader.load_config(config_path) diff --git a/graphrag/config/logging.py b/graphrag/config/logging.py new file mode 100644 index 0000000000..84d7369955 --- /dev/null +++ b/graphrag/config/logging.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Logging utilities. A unified way for enabling logging.""" + +import logging +from pathlib import Path + +from .enums import ReportingType +from .models.graph_rag_config import GraphRagConfig +from .resolve_timestamp_path import resolve_timestamp_path + + +def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None: + """Enable logging to a file. + + Parameters + ---------- + log_filepath : str | Path + The path to the log file. + verbose : bool, default=False + Whether to log debug messages. + """ + log_filepath = Path(log_filepath) + log_filepath.parent.mkdir(parents=True, exist_ok=True) + log_filepath.touch(exist_ok=True) + + logging.basicConfig( + filename=log_filepath, + filemode="a", + format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s", + datefmt="%H:%M:%S", + level=logging.DEBUG if verbose else logging.INFO, + ) + + +def enable_logging_with_config( + config: GraphRagConfig, timestamp_value: str, verbose: bool = False +) -> tuple[bool, str]: + """Enable logging to a file based on the config. + + Parameters + ---------- + config : GraphRagConfig + The configuration. + timestamp_value : str + The timestamp value representing the directory to place the log files. + verbose : bool, default=False + Whether to log debug messages. + + Returns + ------- + tuple[bool, str] + A tuple of a boolean indicating if logging was enabled and the path to the log file. + (False, "") if logging was not enabled. + (True, str) if logging was enabled. 
+ """ + if config.reporting.type == ReportingType.file: + log_path = resolve_timestamp_path( + Path(config.root_dir) / config.reporting.base_dir / "indexing-engine.log", + timestamp_value, + ) + enable_logging(log_path, verbose) + return (True, str(log_path)) + return (False, "") diff --git a/graphrag/config/resolve_timestamp_path.py b/graphrag/config/resolve_timestamp_path.py new file mode 100644 index 0000000000..492f620158 --- /dev/null +++ b/graphrag/config/resolve_timestamp_path.py @@ -0,0 +1,115 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Resolve timestamp variables in a path.""" + +import re +from pathlib import Path +from string import Template + + +def _resolve_timestamp_path_with_value(path: str | Path, timestamp_value: str) -> Path: + """Resolve the timestamp in the path with the given timestamp value. + + Parameters + ---------- + path : str | Path + The path containing ${timestamp} variables to resolve. + timestamp_value : str + The timestamp value used to resolve the path. + + Returns + ------- + Path + The path with ${timestamp} variables resolved to the provided timestamp value. + """ + template = Template(str(path)) + resolved_path = template.substitute(timestamp=timestamp_value) + return Path(resolved_path) + + +def _resolve_timestamp_path_with_dir( + path: str | Path, pattern: re.Pattern[str] +) -> Path: + """Resolve the timestamp in the path with the latest available timestamp directory value. + + Parameters + ---------- + path : str | Path + The path containing ${timestamp} variables to resolve. + pattern : re.Pattern[str] + The pattern to use to match the timestamp directories. + + Returns + ------- + Path + The path with ${timestamp} variables resolved to the latest available timestamp directory value. + + Raises + ------ + ValueError + If the parent directory expecting to contain timestamp directories does not exist or is not a directory. 
+ Or if no timestamp directories are found in the parent directory that match the pattern. + """ + path = Path(path) + path_parts = path.parts + parent_dir = Path(path_parts[0]) + found_timestamp_pattern = False + for part in path_parts[1:]: + if part.lower() == "${timestamp}": + found_timestamp_pattern = True + break + parent_dir = parent_dir / part + + # Path not using timestamp layout. + if not found_timestamp_pattern: + return path + + if not parent_dir.exists() or not parent_dir.is_dir(): + msg = f"Parent directory {parent_dir} does not exist or is not a directory." + raise ValueError(msg) + + timestamp_dirs = [ + d for d in parent_dir.iterdir() if d.is_dir() and pattern.match(d.name) + ] + timestamp_dirs.sort(key=lambda d: d.name, reverse=True) + if not timestamp_dirs: + msg = f"No timestamp directories found in {parent_dir} that match {pattern.pattern}." + raise ValueError(msg) + return _resolve_timestamp_path_with_value(path, timestamp_dirs[0].name) + + +def resolve_timestamp_path( + path: str | Path, + pattern_or_timestamp_value: re.Pattern[str] | str = re.compile(r"^\d{8}-\d{6}$"), +) -> Path: + r"""Timestamp path resolver. + + Resolve the timestamp in the path with the given timestamp value or + with the latest available timestamp directory matching the given pattern. + + Parameters + ---------- + path : str | Path + The path containing ${timestamp} variables to resolve. + pattern_or_timestamp_value : re.Pattern[str] | str, default=re.compile(r"^\d{8}-\d{6}$") + The pattern to use to match the timestamp directories or the timestamp value to use. + If a string is provided, the path will be resolved with the given string value. + Otherwise, the path will be resolved with the latest available timestamp directory + that matches the given pattern. + + Returns + ------- + Path + The path with ${timestamp} variables resolved to the provided timestamp value or + the latest available timestamp directory. 
+ + Raises + ------ + ValueError + If the parent directory expecting to contain timestamp directories does not exist or is not a directory. + Or if no timestamp directories are found in the parent directory that match the pattern. + """ + if isinstance(pattern_or_timestamp_value, str): + return _resolve_timestamp_path_with_value(path, pattern_or_timestamp_value) + return _resolve_timestamp_path_with_dir(path, pattern_or_timestamp_value) diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py index 578ffc9c33..0530290a63 100644 --- a/graphrag/index/__main__.py +++ b/graphrag/index/__main__.py @@ -68,6 +68,11 @@ help="Overlay default configuration values on a provided configuration file (--config).", action="store_true", ) + parser.add_argument( + "--skip-validations", + help="Skip any preflight validation. Useful when running no LLM steps.", + action="store_true", + ) args = parser.parse_args() if args.overlay_defaults and not args.config: @@ -85,5 +90,5 @@ dryrun=args.dryrun or False, init=args.init or False, overlay_defaults=args.overlay_defaults or False, - cli=True, + skip_validations=args.skip_validations or False, ) diff --git a/graphrag/index/api.py b/graphrag/index/api.py new file mode 100644 index 0000000000..a58e832c9b --- /dev/null +++ b/graphrag/index/api.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +""" +Indexing API for GraphRAG. + +WARNING: This API is under development and may undergo changes in future releases. +Backwards compatibility is not guaranteed at this time. 
+""" + +from graphrag.config.enums import CacheType +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.resolve_timestamp_path import resolve_timestamp_path + +from .cache.noop_pipeline_cache import NoopPipelineCache +from .create_pipeline_config import create_pipeline_config +from .emit.types import TableEmitterType +from .progress import ( + ProgressReporter, +) +from .run import run_pipeline_with_config +from .typing import PipelineRunResult + + +async def build_index( + config: GraphRagConfig, + run_id: str, + memory_profile: bool, + progress_reporter: ProgressReporter | None = None, + emit: list[str] | None = None, +) -> list[PipelineRunResult]: + """Run the pipeline with the given configuration. + + Parameters + ---------- + config : GraphRagConfig + The configuration. + run_id : str + The run id. Creates an output directory with this name. + memory_profile : bool + Whether to enable memory profiling. + progress_reporter : ProgressReporter | None default=None + The progress reporter. + emit : list[str] | None default=None + The list of emitter types to emit. + Accepted values {"parquet", "csv"}. 
+ + Returns + ------- + list[PipelineRunResult] + The list of pipeline run results + """ + try: + resolve_timestamp_path(config.storage.base_dir, run_id) + resume = True + except ValueError as _: + resume = False + pipeline_config = create_pipeline_config(config) + pipeline_cache = ( + NoopPipelineCache() if config.cache.type == CacheType.none else None + ) + outputs: list[PipelineRunResult] = [] + async for output in run_pipeline_with_config( + pipeline_config, + run_id=run_id, + memory_profile=memory_profile, + cache=pipeline_cache, + progress_reporter=progress_reporter, + emit=([TableEmitterType(e) for e in emit] if emit is not None else None), + is_resume_run=resume, + ): + outputs.append(output) + if progress_reporter: + if output.errors and len(output.errors) > 0: + progress_reporter.error(output.workflow) + else: + progress_reporter.success(output.workflow) + progress_reporter.info(str(output.result)) + return outputs diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 8d09951953..6dda401ca8 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -6,32 +6,28 @@ import asyncio import json import logging -import platform import sys import time import warnings from pathlib import Path -from graphrag.config import ( - create_graphrag_config, +from graphrag.config import create_graphrag_config +from graphrag.config.config_file_loader import ( + load_config_from_file, + resolve_config_path_with_root, ) -from graphrag.index import PipelineConfig, create_pipeline_config -from graphrag.index.cache import NoopPipelineCache -from graphrag.index.progress import ( - NullProgressReporter, - PrintProgressReporter, - ProgressReporter, -) -from graphrag.index.progress.rich import RichProgressReporter -from graphrag.index.run import run_pipeline_with_config -from graphrag.index.validate_config import validate_config_names +from graphrag.config.enums import CacheType +from graphrag.config.logging import enable_logging_with_config -from .emit 
import TableEmitterType +from .api import build_index from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT from .graph.extractors.community_reports.prompts import COMMUNITY_REPORT_PROMPT from .graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT from .graph.extractors.summarize.prompts import SUMMARIZE_PROMPT from .init_content import INIT_DOTENV, INIT_YAML +from .progress import ProgressReporter +from .progress.load_progress_reporter import load_progress_reporter +from .validate_config import validate_config_names # Ignore warnings from numba warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*") @@ -39,7 +35,7 @@ log = logging.getLogger(__name__) -def redact(input: dict) -> str: +def _redact(input: dict) -> str: """Sanitize the config json.""" # Redact any sensitive configuration @@ -56,7 +52,7 @@ def redact_dict(input: dict) -> dict: "organization", }: if value is not None: - result[key] = f"REDACTED, length {len(value)}" + result[key] = "==== REDACTED ====" elif isinstance(value, dict): result[key] = redact_dict(value) elif isinstance(value, list): @@ -69,6 +65,43 @@ def redact_dict(input: dict) -> dict: return json.dumps(redacted_dict, indent=4) +def _logger(reporter: ProgressReporter): + def info(msg: str, verbose: bool = False): + log.info(msg) + if verbose: + reporter.info(msg) + + def error(msg: str, verbose: bool = False): + log.error(msg) + if verbose: + reporter.error(msg) + + def success(msg: str, verbose: bool = False): + log.info(msg) + if verbose: + reporter.success(msg) + + return info, error, success + + +def _register_signal_handlers(reporter: ProgressReporter): + import signal + + def handle_signal(signum, _): + # Handle the signal here + reporter.info(f"Received signal {signum}, exiting...") + reporter.dispose() + for task in asyncio.all_tasks(): + task.cancel() + reporter.info("All tasks cancelled. 
Exiting...") + + # Register signal handlers for SIGINT and SIGHUP + signal.signal(signal.SIGINT, handle_signal) + + if sys.platform != "win32": + signal.signal(signal.SIGHUP, handle_signal) + + def index_cli( root: str, init: bool, @@ -81,99 +114,82 @@ def index_cli( emit: str | None, dryrun: bool, overlay_defaults: bool, - cli: bool = False, + skip_validations: bool, ): """Run the pipeline with the given config.""" + progress_reporter = load_progress_reporter(reporter or "rich") + info, error, success = _logger(progress_reporter) run_id = resume or time.strftime("%Y%m%d-%H%M%S") - _enable_logging(root, run_id, verbose) - progress_reporter = _get_progress_reporter(reporter) + if init: _initialize_project_at(root, progress_reporter) sys.exit(0) - if overlay_defaults: - pipeline_config: str | PipelineConfig = _create_default_config( - root, config, verbose, dryrun or False, progress_reporter + + if overlay_defaults or config: + config_path = ( + Path(root) / config if config else resolve_config_path_with_root(root) ) + default_config = load_config_from_file(config_path) else: - pipeline_config: str | PipelineConfig = config or _create_default_config( - root, None, verbose, dryrun or False, progress_reporter + try: + config_path = resolve_config_path_with_root(root) + default_config = load_config_from_file(config_path) + except FileNotFoundError: + default_config = create_graphrag_config(root_dir=root) + + if nocache: + default_config.cache.type = CacheType.none + + enabled_logging, log_path = enable_logging_with_config( + default_config, run_id, verbose + ) + if enabled_logging: + info(f"Logging enabled at {log_path}", True) + else: + info( + f"Logging not enabled for config {_redact(default_config.model_dump())}", + True, ) - cache = NoopPipelineCache() if nocache else None + + if not skip_validations: + validate_config_names(progress_reporter, default_config) + + info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose) + info( + f"Using default configuration: 
{_redact(default_config.model_dump())}", + verbose, + ) + + if dryrun: + info("Dry run complete, exiting...", True) + sys.exit(0) + pipeline_emit = emit.split(",") if emit else None - encountered_errors = False - - # Run pre-flight validation on config model values - parameters = _read_config_parameters(root, config, progress_reporter) - validate_config_names(progress_reporter, parameters) - - def _run_workflow_async() -> None: - import signal - - def handle_signal(signum, _): - # Handle the signal here - progress_reporter.info(f"Received signal {signum}, exiting...") - progress_reporter.dispose() - for task in asyncio.all_tasks(): - task.cancel() - progress_reporter.info("All tasks cancelled. Exiting...") - - # Register signal handlers for SIGINT and SIGHUP - signal.signal(signal.SIGINT, handle_signal) - - if sys.platform != "win32": - signal.signal(signal.SIGHUP, handle_signal) - - async def execute(): - nonlocal encountered_errors - async for output in run_pipeline_with_config( - pipeline_config, - run_id=run_id, - memory_profile=memprofile, - cache=cache, - progress_reporter=progress_reporter, - emit=( - [TableEmitterType(e) for e in pipeline_emit] - if pipeline_emit - else None - ), - is_resume_run=bool(resume), - ): - if output.errors and len(output.errors) > 0: - encountered_errors = True - progress_reporter.error(output.workflow) - else: - progress_reporter.success(output.workflow) - - progress_reporter.info(str(output.result)) - - if platform.system() == "Windows": - import nest_asyncio # type: ignore Ignoring because out of windows this will cause an error - - nest_asyncio.apply() - loop = asyncio.get_event_loop() - loop.run_until_complete(execute()) - elif sys.version_info >= (3, 11): - import uvloop # type: ignore Ignoring because on windows this will cause an error - - with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: # type: ignore Ignoring because minor versions this will throw an error - runner.run(execute()) - else: - import uvloop 
# type: ignore Ignoring because on windows this will cause an error - - uvloop.install() - asyncio.run(execute()) - - _run_workflow_async() + + _register_signal_handlers(progress_reporter) + + outputs = asyncio.run( + build_index( + default_config, + run_id, + memprofile, + progress_reporter, + pipeline_emit, + ) + ) + encountered_errors = any( + output.errors and len(output.errors) > 0 for output in outputs + ) + progress_reporter.stop() if encountered_errors: - progress_reporter.error( - "Errors occurred during the pipeline run, see logs for more details." + error( + "Errors occurred during the pipeline run, see logs for more details.", True ) else: - progress_reporter.success("All workflows completed successfully.") + success("All workflows completed successfully.", True) - if cli: - sys.exit(1 if encountered_errors else 0) + sys.exit(1 if encountered_errors else 0) def _initialize_project_at(path: str, reporter: ProgressReporter) -> None: @@ -225,101 +241,3 @@ def _initialize_project_at(path: str, reporter: ProgressReporter) -> None: file.write( COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict") ) - - -def _create_default_config( - root: str, - config: str | None, - verbose: bool, - dryrun: bool, - reporter: ProgressReporter, -) -> PipelineConfig: - """Overlay default values on an existing config or create a default config if none is provided.""" - if config and not Path(config).exists(): - msg = f"Configuration file {config} does not exist" - raise ValueError - - if not Path(root).exists(): - msg = f"Root directory {root} does not exist" - raise ValueError(msg) - - parameters = _read_config_parameters(root, config, reporter) - log.info( - "using default configuration: %s", - redact(parameters.model_dump()), - ) - - if verbose or dryrun: - reporter.info(f"Using default configuration: {redact(parameters.model_dump())}") - result = create_pipeline_config(parameters, verbose) - if verbose or dryrun: - reporter.info(f"Final Config: 
{redact(result.model_dump())}") - - if dryrun: - reporter.info("dry run complete, exiting...") - sys.exit(0) - return result - - -def _read_config_parameters(root: str, config: str | None, reporter: ProgressReporter): - _root = Path(root) - settings_yaml = ( - Path(config) - if config and Path(config).suffix in [".yaml", ".yml"] - else _root / "settings.yaml" - ) - if not settings_yaml.exists(): - settings_yaml = _root / "settings.yml" - settings_json = ( - Path(config) - if config and Path(config).suffix == ".json" - else _root / "settings.json" - ) - - if settings_yaml.exists(): - reporter.success(f"Reading settings from {settings_yaml}") - with settings_yaml.open("rb") as file: - import yaml - - data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) - return create_graphrag_config(data, root) - - if settings_json.exists(): - reporter.success(f"Reading settings from {settings_json}") - with settings_json.open("rb") as file: - import json - - data = json.loads(file.read().decode(encoding="utf-8", errors="strict")) - return create_graphrag_config(data, root) - - reporter.success("Reading settings from environment variables") - return create_graphrag_config(root_dir=root) - - -def _get_progress_reporter(reporter_type: str | None) -> ProgressReporter: - if reporter_type is None or reporter_type == "rich": - return RichProgressReporter("GraphRAG Indexer ") - if reporter_type == "print": - return PrintProgressReporter("GraphRAG Indexer ") - if reporter_type == "none": - return NullProgressReporter() - - msg = f"Invalid progress reporter type: {reporter_type}" - raise ValueError(msg) - - -def _enable_logging(root_dir: str, run_id: str, verbose: bool) -> None: - logging_file = ( - Path(root_dir) / "output" / run_id / "reports" / "indexing-engine.log" - ) - logging_file.parent.mkdir(parents=True, exist_ok=True) - - logging_file.touch(exist_ok=True) - - logging.basicConfig( - filename=str(logging_file), - filemode="a", - format="%(asctime)s,%(msecs)d 
%(name)s %(levelname)s %(message)s", - datefmt="%H:%M:%S", - level=logging.DEBUG if verbose else logging.INFO, - ) diff --git a/graphrag/index/progress/load_progress_reporter.py b/graphrag/index/progress/load_progress_reporter.py new file mode 100644 index 0000000000..73e02d0848 --- /dev/null +++ b/graphrag/index/progress/load_progress_reporter.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Load a progress reporter.""" + +from .rich import RichProgressReporter +from .types import NullProgressReporter, PrintProgressReporter, ProgressReporter + + +def load_progress_reporter(reporter_type: str = "none") -> ProgressReporter: + """Load a progress reporter. + + Parameters + ---------- + reporter_type : {"rich", "print", "none"}, default="none" + The type of progress reporter to load. + + Returns + ------- + ProgressReporter + """ + if reporter_type == "rich": + return RichProgressReporter("GraphRAG Indexer ") + if reporter_type == "print": + return PrintProgressReporter("GraphRAG Indexer ") + if reporter_type == "none": + return NullProgressReporter() + + msg = f"Invalid progress reporter type: {reporter_type}" + raise ValueError(msg) diff --git a/scripts/e2e-test.sh b/scripts/e2e-test.sh deleted file mode 100755 index 5c260a148b..0000000000 --- a/scripts/e2e-test.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -# Use CLI Form -poetry run python -m graphrag.index --config ./examples/single_verb/pipeline.yml \ No newline at end of file