From 5b4bb6d303fe4362fb488f0ba5368dc518e6f9b1 Mon Sep 17 00:00:00 2001 From: Chris Trevino Date: Thu, 18 Apr 2024 10:02:33 -0700 Subject: [PATCH] add configuration tests (#162) add some configuration tests --- graphrag/index/__init__.py | 16 +- graphrag/index/default_config/__init__.py | 6 + .../default_config/parameters/__init__.py | 8 + .../index/default_config/parameters/errors.py | 39 +++ .../default_config/parameters/factories.py | 35 +- pyproject.toml | 2 +- tests/smoke/test_fixtures.py | 2 + .../unit/indexing/default_config/prompt-a.txt | 1 + .../unit/indexing/default_config/prompt-b.txt | 1 + .../unit/indexing/default_config/prompt-c.txt | 1 + .../unit/indexing/default_config/prompt-d.txt | 1 + .../default_config/test_default_config.py | 324 +++++++++++++++--- 12 files changed, 357 insertions(+), 79 deletions(-) create mode 100644 graphrag/index/default_config/parameters/errors.py create mode 100644 tests/unit/indexing/default_config/prompt-a.txt create mode 100644 tests/unit/indexing/default_config/prompt-b.txt create mode 100644 tests/unit/indexing/default_config/prompt-c.txt create mode 100644 tests/unit/indexing/default_config/prompt-d.txt diff --git a/graphrag/index/__init__.py b/graphrag/index/__init__.py index e096618d98..9b61d53fff 100644 --- a/graphrag/index/__init__.py +++ b/graphrag/index/__init__.py @@ -34,6 +34,9 @@ PipelineWorkflowStep, ) from .default_config import ( + ApiKeyMissingError, + AzureApiBaseMissingError, + AzureDeploymentNameMissingError, CacheConfigModel, ChunkingConfigModel, ClaimExtractionConfigModel, @@ -56,10 +59,18 @@ default_config_parameters, default_config_parameters_from_env_vars, ) +from .errors import ( + NoWorkflowsDefinedError, + UndefinedWorkflowError, + UnknownWorkflowError, +) from .run import run_pipeline, run_pipeline_with_config from .storage import PipelineStorage __all__ = [ + "ApiKeyMissingError", + "AzureApiBaseMissingError", + "AzureDeploymentNameMissingError", "CacheConfigModel", "ChunkingConfigModel", "ClaimExtractionConfigModel", @@ -71,6 +82,7 @@ "InputConfigModel", "LLMConfigModel", "LLMParametersModel", + "NoWorkflowsDefinedError", "ParallelizationParametersModel", "PipelineBlobCacheConfig", "PipelineBlobCacheConfig", @@ -80,7 +92,6 @@ "PipelineCache", "PipelineCacheConfig", "PipelineCacheConfigTypes", - # Deep Config Stack "PipelineCacheType", "PipelineConfig", "PipelineConsoleReportingConfig", @@ -110,7 +121,8 @@ "SummarizeDescriptionsConfigModel", "TextEmbeddingConfigModel", "UmapConfigModel", - # Default Config Stack + "UndefinedWorkflowError", + "UnknownWorkflowError", "default_config", "default_config_parameters", "default_config_parameters_from_env_vars", diff --git a/graphrag/index/default_config/__init__.py b/graphrag/index/default_config/__init__.py index c6ca05b82f..00b81a4ca0 100644 --- a/graphrag/index/default_config/__init__.py +++ b/graphrag/index/default_config/__init__.py @@ -5,6 +5,9 @@ from .default_config import default_config from .load import load_pipeline_config from .parameters import ( + ApiKeyMissingError, + AzureApiBaseMissingError, + AzureDeploymentNameMissingError, CacheConfigModel, ChunkingConfigModel, ClaimExtractionConfigModel, @@ -28,6 +31,9 @@ ) __all__ = [ + "ApiKeyMissingError", + "AzureApiBaseMissingError", + "AzureDeploymentNameMissingError", "CacheConfigModel", "ChunkingConfigModel", "ClaimExtractionConfigModel", diff --git a/graphrag/index/default_config/parameters/__init__.py b/graphrag/index/default_config/parameters/__init__.py index e79d731ea5..1f0fd6b28e 100644 --- a/graphrag/index/default_config/parameters/__init__.py +++ b/graphrag/index/default_config/parameters/__init__.py @@ -3,6 +3,11 @@ """Configuration parameterization settings for the indexing pipeline.""" from .default_config_parameters_model import DefaultConfigParametersModel +from .errors import ( + ApiKeyMissingError, + AzureApiBaseMissingError, + AzureDeploymentNameMissingError, +) from .factories import ( default_config_parameters, default_config_parameters_from_env_vars, @@ -29,6 +34,9 @@ from .read_dotenv import read_dotenv __all__ = [ + "ApiKeyMissingError", + "AzureApiBaseMissingError", + "AzureDeploymentNameMissingError", "CacheConfigModel", "ChunkingConfigModel", "ClaimExtractionConfigModel", diff --git a/graphrag/index/default_config/parameters/errors.py b/graphrag/index/default_config/parameters/errors.py new file mode 100644 index 0000000000..021ac1b3ac --- /dev/null +++ b/graphrag/index/default_config/parameters/errors.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024 Microsoft Corporation. All rights reserved. +"""Errors for the default configuration.""" + + +class ApiKeyMissingError(ValueError): + """LLM Key missing error.""" + + def __init__(self, embedding: bool = False) -> None: + """Init method definition.""" + api_type = "Embedding" if embedding else "Completion" + api_key = "GRAPHRAG_EMBEDDING_API_KEY" if embedding else "GRAPHRAG_LLM_API_KEY" + msg = f"API Key is required for {api_type} API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or {api_key} environment variable." + super().__init__(msg) + + +class AzureApiBaseMissingError(ValueError): + """Azure API Base missing error.""" + + def __init__(self, embedding: bool = False) -> None: + """Init method definition.""" + api_type = "Embedding" if embedding else "Completion" + api_base = "GRAPHRAG_EMBEDDING_API_BASE" if embedding else "GRAPHRAG_API_BASE" + msg = f"API Base is required for {api_type} API. Please set either the OPENAI_API_BASE, GRAPHRAG_API_BASE or {api_base} environment variable." + super().__init__(msg) + + +class AzureDeploymentNameMissingError(ValueError): + """Azure Deployment Name missing error.""" + + def __init__(self, embedding: bool = False) -> None: + """Init method definition.""" + api_type = "Embedding" if embedding else "Completion" + api_base = ( + "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME" + if embedding + else "GRAPHRAG_LLM_DEPLOYMENT_NAME" + ) + msg = f"Deployment Name is required for {api_type} API. Please set either the OPENAI_DEPLOYMENT_NAME, GRAPHRAG_LLM_DEPLOYMENT_NAME or {api_base} environment variable." + super().__init__(msg) diff --git a/graphrag/index/default_config/parameters/factories.py b/graphrag/index/default_config/parameters/factories.py index d99accb773..de6199e4a7 100644 --- a/graphrag/index/default_config/parameters/factories.py +++ b/graphrag/index/default_config/parameters/factories.py @@ -21,6 +21,11 @@ from .default_config_parameters import DefaultConfigParametersDict from .default_config_parameters_model import DefaultConfigParametersModel +from .errors import ( + ApiKeyMissingError, + AzureApiBaseMissingError, + AzureDeploymentNameMissingError, +) from .models import ( CacheConfigModel, ChunkingConfigModel, @@ -40,19 +45,6 @@ UmapConfigModel, ) -LLM_KEY_REQUIRED = "API Key is required for Completion API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_LLM_API_KEY environment variable." -EMBEDDING_KEY_REQUIRED = "API Key is required for Embedding API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_EMBEDDING_API_KEY environment variable." -AZURE_LLM_DEPLOYMENT_NAME_REQUIRED = ( - "GRAPHRAG_LLM_MODEL or GRAPHRAG_LLM_DEPLOYMENT_NAME is required for Azure OpenAI." -) -AZURE_LLM_API_BASE_REQUIRED = ( - "GRAPHRAG_API_BASE or GRAPHRAG_LLM_API_BASE is required for Azure OpenAI." -) -AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED = "GRAPHRAG_EMBEDDING_MODEL or GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME is required for Azure OpenAI." -AZURE_EMBEDDING_API_BASE_REQUIRED = ( - "GRAPHRAG_API_BASE or GRAPHRAG_EMBEDDING_API_BASE is required for Azure OpenAI." -) - def default_config_parameters( values: DefaultConfigParametersModel, root_dir: str | None @@ -103,7 +95,7 @@ def section(key: Section): with section(Section.llm): api_key = _str(Fragment.api_key, _api_key or fallback_oai_key) if api_key is None: - raise ValueError(LLM_KEY_REQUIRED) + raise ApiKeyMissingError llm_type = _str(Fragment.type) llm_type = LLMType(llm_type) if llm_type else None deployment_name = _str(Fragment.deployment_name) @@ -112,9 +104,9 @@ def section(key: Section): is_azure = _is_azure(llm_type) api_base = _str(Fragment.api_base, _api_base) if is_azure and deployment_name is None and model is None: - raise ValueError(AZURE_LLM_DEPLOYMENT_NAME_REQUIRED) + raise AzureDeploymentNameMissingError if is_azure and api_base is None: - raise ValueError(AZURE_LLM_API_BASE_REQUIRED) + raise AzureApiBaseMissingError llm_parameters = LLMParametersModel( api_key=api_key, @@ -143,7 +135,7 @@ def section(key: Section): with section(Section.embedding): api_key = _str(Fragment.api_key, _api_key) if api_key is None: - raise ValueError(EMBEDDING_KEY_REQUIRED) + raise ApiKeyMissingError(embedding=True) embedding_target = _str("TARGET") embedding_target = ( @@ -159,9 +151,9 @@ def section(key: Section): api_base = _str(Fragment.api_base, _api_base) if is_azure and deployment_name is None and model is None: - raise ValueError(AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED) + raise AzureDeploymentNameMissingError(embedding=True) if is_azure and api_base is None: - raise ValueError(AZURE_EMBEDDING_API_BASE_REQUIRED) + raise AzureApiBaseMissingError(embedding=True) text_embeddings = TextEmbeddingConfigModel( parallelization=ParallelizationParametersModel( @@ -172,7 +164,7 @@ def section(key: Section): target=embedding_target, batch_size=_int("BATCH_SIZE"), batch_max_tokens=_int("BATCH_MAX_TOKENS"), - skip=_array_string("SKIP"), + skip=_array_string(_str("SKIP")), llm=LLMParametersModel( api_key=_str(Fragment.api_key, _api_key), type=llm_type, @@ -244,6 +236,8 @@ def section(key: Section): storage_type=storage_type, file_encoding=_str(Fragment.encoding), base_dir=_str(Fragment.base_dir), + connection_string=_str(Fragment.conn_string), + container_name=_str(Fragment.container_name), file_pattern=_str("FILE_PATTERN"), source_column=_str("SOURCE_COLUMN"), timestamp_column=_str("TIMESTAMP_COLUMN"), @@ -294,6 +288,7 @@ def section(key: Section): enabled=_bool(Fragment.enabled), ) + async_mode = _str(Fragment.async_mode) async_mode_enum = AsyncType(async_mode) if async_mode else None return DefaultConfigParametersDict( DefaultConfigParametersModel( diff --git a/pyproject.toml b/pyproject.toml index 86c8e62767..d44edd40ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ _pyright = "pyright" coverage_report = 'coverage report --omit "**/tests/**" --show-missing' check_format = 'ruff format . --check --preview' fix = "ruff --preview check --fix ." -fix_unsafe = "ruff --preview check --fix --unsafe-fixes ." +fix_unsafe = "ruff check --preview --fix --unsafe-fixes ." _test_all = "coverage run -m pytest ./tests" test_unit = "pytest ./tests/unit" diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index 3ab8015b73..01ea130f6d 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -12,6 +12,7 @@ from unittest import mock import pandas as pd +import pytest from graphrag.index.storage.blob_pipeline_storage import BlobPipelineStorage @@ -247,6 +248,7 @@ def __run_query(self, root: Path, query_config: dict[str, str]): }, clear=True, ) + @pytest.mark.timeout(600) # Extend the timeout to 600 seconds (10 minutes) def test_fixture( self, input_path: str, diff --git a/tests/unit/indexing/default_config/prompt-a.txt b/tests/unit/indexing/default_config/prompt-a.txt new file mode 100644 index 0000000000..af744d0eb0 --- /dev/null +++ b/tests/unit/indexing/default_config/prompt-a.txt @@ -0,0 +1 @@ +Hello, World! A \ No newline at end of file diff --git a/tests/unit/indexing/default_config/prompt-b.txt b/tests/unit/indexing/default_config/prompt-b.txt new file mode 100644 index 0000000000..2e12b140ad --- /dev/null +++ b/tests/unit/indexing/default_config/prompt-b.txt @@ -0,0 +1 @@ +Hello, World! B \ No newline at end of file diff --git a/tests/unit/indexing/default_config/prompt-c.txt b/tests/unit/indexing/default_config/prompt-c.txt new file mode 100644 index 0000000000..f55e9771a0 --- /dev/null +++ b/tests/unit/indexing/default_config/prompt-c.txt @@ -0,0 +1 @@ +Hello, World! C \ No newline at end of file diff --git a/tests/unit/indexing/default_config/prompt-d.txt b/tests/unit/indexing/default_config/prompt-d.txt new file mode 100644 index 0000000000..bd6438515b --- /dev/null +++ b/tests/unit/indexing/default_config/prompt-d.txt @@ -0,0 +1 @@ +Hello, World! D \ No newline at end of file diff --git a/tests/unit/indexing/default_config/test_default_config.py b/tests/unit/indexing/default_config/test_default_config.py index fa31f71d49..b5b8f349dd 100644 --- a/tests/unit/indexing/default_config/test_default_config.py +++ b/tests/unit/indexing/default_config/test_default_config.py @@ -1,11 +1,18 @@ # Copyright (c) 2024 Microsoft Corporation. All rights reserved. import os +import re import unittest +from pathlib import Path from unittest import mock +import pytest import yaml +from datashaper import AsyncType from graphrag.index import ( + ApiKeyMissingError, + AzureApiBaseMissingError, + AzureDeploymentNameMissingError, CacheConfigModel, ChunkingConfigModel, ClaimExtractionConfigModel, @@ -37,8 +44,156 @@ current_dir = os.path.dirname(__file__) +ALL_ENV_VARS = { + "GRAPHRAG_API_BASE": "http://some/base", + "GRAPHRAG_API_KEY": "test", + "GRAPHRAG_API_ORGANIZATION": "test_org", + "GRAPHRAG_API_PROXY": "http://some/proxy", + "GRAPHRAG_API_VERSION": "v1234", + "GRAPHRAG_ASYNC_MODE": "asyncio", + "GRAPHRAG_CACHE_BASE_DIR": "/some/cache/dir", + "GRAPHRAG_CACHE_CONNECTION_STRING": "test_cs1", + "GRAPHRAG_CACHE_CONTAINER_NAME": "test_cn1", + "GRAPHRAG_CACHE_TYPE": "blob", + "GRAPHRAG_CHUNK_BY_COLUMNS": "a,b", + "GRAPHRAG_CHUNK_OVERLAP": "12", + "GRAPHRAG_CHUNK_SIZE": "500", + "GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123", + "GRAPHRAG_CLAIM_EXTRACTION_MAX_GLEANINGS": "5000", + "GRAPHRAG_CLAIM_EXTRACTION_PROMPT_FILE": "tests/unit/indexing/default_config/prompt-a.txt", + "GRAPHRAG_COMMUNITY_REPORT_MAX_LENGTH": "23456", + "GRAPHRAG_COMMUNITY_REPORT_PROMPT_FILE": "tests/unit/indexing/default_config/prompt-b.txt", + "GRAPHRAG_EMBEDDING_BATCH_MAX_TOKENS": "17", + "GRAPHRAG_EMBEDDING_BATCH_SIZE": "1000000", + "GRAPHRAG_EMBEDDING_CONCURRENT_REQUESTS": "12", + "GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME": "model-deployment-name", + "GRAPHRAG_EMBEDDING_MAX_RETRIES": "3", + "GRAPHRAG_EMBEDDING_MAX_RETRY_WAIT": "0.1123", + "GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2", + "GRAPHRAG_EMBEDDING_RPM": "500", + "GRAPHRAG_EMBEDDING_SKIP": "a1,b1,c1", + "GRAPHRAG_EMBEDDING_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False", + "GRAPHRAG_EMBEDDING_TARGET": "all", + "GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345", + "GRAPHRAG_EMBEDDING_THREAD_STAGGER": "0.456", + "GRAPHRAG_EMBEDDING_TPM": "7000", + "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding", + "GRAPHRAG_ENCODING_MODEL": "test123", + "GRAPHRAG_ENTITY_EXTRACTION_ENTITY_TYPES": "cat,dog,elephant", + "GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112", + "GRAPHRAG_ENTITY_EXTRACTION_PROMPT_FILE": "tests/unit/indexing/default_config/prompt-c.txt", + "GRAPHRAG_INPUT_BASE_DIR": "/some/input/dir", + "GRAPHRAG_INPUT_CONNECTION_STRING": "input_cs", + "GRAPHRAG_INPUT_CONTAINER_NAME": "input_cn", + "GRAPHRAG_INPUT_DOCUMENT_ATTRIBUTE_COLUMNS": "test1,test2", + "GRAPHRAG_INPUT_ENCODING": "utf-16", + "GRAPHRAG_INPUT_FILE_PATTERN": ".*\\test\\.txt$", + "GRAPHRAG_INPUT_SOURCE_COLUMN": "test_source", + "GRAPHRAG_INPUT_STORAGE_TYPE": "blob", + "GRAPHRAG_INPUT_TEXT_COLUMN": "test_text", + "GRAPHRAG_INPUT_TIMESTAMP_COLUMN": "test_timestamp", + "GRAPHRAG_INPUT_TIMESTAMP_FORMAT": "test_format", + "GRAPHRAG_INPUT_TITLE_COLUMN": "test_title", + "GRAPHRAG_INPUT_TYPE": "text", + "GRAPHRAG_LLM_CONCURRENT_REQUESTS": "12", + "GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x", + "GRAPHRAG_LLM_MAX_RETRIES": "312", + "GRAPHRAG_LLM_MAX_RETRY_WAIT": "0.1122", + "GRAPHRAG_LLM_MAX_TOKENS": "15000", + "GRAPHRAG_LLM_MODEL_SUPPORTS_JSON": "true", + "GRAPHRAG_LLM_MODEL": "test-llm", + "GRAPHRAG_LLM_REQUEST_TIMEOUT": "12.7", + "GRAPHRAG_LLM_RPM": "900", + "GRAPHRAG_LLM_SLEEP_ON_RATE_LIMIT_RECOMMENDATION": "False", + "GRAPHRAG_LLM_THREAD_COUNT": "987", + "GRAPHRAG_LLM_THREAD_STAGGER": "0.123", + "GRAPHRAG_LLM_TPM": "8000", + "GRAPHRAG_LLM_TYPE": "azure_openai_chat", + "GRAPHRAG_MAX_CLUSTER_SIZE": "123", + "GRAPHRAG_NODE2VEC_ENABLED": "true", + "GRAPHRAG_NODE2VEC_ITERATIONS": "878787", + "GRAPHRAG_NODE2VEC_NUM_WALKS": "5000000", + "GRAPHRAG_NODE2VEC_RANDOM_SEED": "010101", + "GRAPHRAG_NODE2VEC_WALK_LENGTH": "555111", + "GRAPHRAG_NODE2VEC_WINDOW_SIZE": "12345", + "GRAPHRAG_REPORTING_BASE_DIR": "/some/reporting/dir", + "GRAPHRAG_REPORTING_CONNECTION_STRING": "test_cs2", + "GRAPHRAG_REPORTING_CONTAINER_NAME": "test_cn2", + "GRAPHRAG_REPORTING_TYPE": "blob", + "GRAPHRAG_SKIP_WORKFLOWS": "a,b,c", + "GRAPHRAG_SNAPSHOT_GRAPHML": "true", + "GRAPHRAG_SNAPSHOT_RAW_ENTITIES": "true", + "GRAPHRAG_SNAPSHOT_TOP_LEVEL_NODES": "true", + "GRAPHRAG_STORAGE_BASE_DIR": "/some/storage/dir", + "GRAPHRAG_STORAGE_CONNECTION_STRING": "test_cs", + "GRAPHRAG_STORAGE_CONTAINER_NAME": "test_cn", + "GRAPHRAG_STORAGE_TYPE": "blob", + "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345", + "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_PROMPT_FILE": "tests/unit/indexing/default_config/prompt-d.txt", + "GRAPHRAG_UMAP_ENABLED": "true", +} + class TestDefaultConfig(unittest.TestCase): + @mock.patch.dict(os.environ, {}, clear=True) + def test_default_config_with_no_env_vars_throws(self): + with pytest.raises(ApiKeyMissingError): + # This should throw an error because the API key is missing + default_config(default_config_parameters_from_env_vars(".")) + + @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) + def test_default_config_with_api_key_passes(self): + # doesn't throw + config = default_config(default_config_parameters_from_env_vars(".")) + assert config is not None + + @mock.patch.dict(os.environ, {"OPENAI_API_KEY": "test"}, clear=True) + def test_default_config_with_oai_key_passes(self): + # doesn't throw + config = default_config(default_config_parameters_from_env_vars(".")) + assert config is not None + + @mock.patch.dict( + os.environ, + { + "GRAPHRAG_API_KEY": "test", + "GRAPHRAG_LLM_TYPE": "azure_openai_chat", + "GRAPHRAG_LLM_DEPLOYMENT_NAME": "x", + }, + clear=True, + ) + def test_throws_if_azure_is_used_without_api_base(self): + with pytest.raises(AzureApiBaseMissingError): + default_config_parameters_from_env_vars(".") + + @mock.patch.dict( + os.environ, + { + "GRAPHRAG_API_KEY": "test", + "GRAPHRAG_LLM_TYPE": "azure_openai_chat", + "GRAPHRAG_LLM_BASE": "http://some/base", + }, + clear=True, + ) + def test_throws_if_azure_is_used_without_llm_deployment_name(self): + with pytest.raises(AzureDeploymentNameMissingError): + default_config_parameters_from_env_vars(".") + + @mock.patch.dict( + os.environ, + { + "GRAPHRAG_API_KEY": "test", + "GRAPHRAG_LLM_TYPE": "azure_openai_chat", + "GRAPHRAG_API_BASE": "http://some/base", + "GRAPHRAG_LLM_DEPLOYMENT_NAME": "model-deployment-name-x", + "GRAPHRAG_EMBEDDING_TYPE": "azure_openai_embedding", + }, + clear=True, + ) + def test_throws_if_azure_is_used_without_embedding_deployment_name(self): + with pytest.raises(AzureDeploymentNameMissingError): + default_config_parameters_from_env_vars(".") + @mock.patch.dict(os.environ, {"GRAPHRAG_API_KEY": "test"}, clear=True) def test_csv_input_returns_correct_config(self): config = default_config(default_config_parameters_from_env_vars("/some/root")) @@ -61,75 +216,132 @@ def test_text_input_returns_correct_config(self): assert config.input is not None assert (config.input.file_pattern or "") == ".*\\.txt$" # type: ignore + def test_all_env_vars_is_accurate(self): + env_var_docs_path = Path("docsite/posts/config/env_vars.md") + env_var_docs = env_var_docs_path.read_text(encoding="utf-8") + + def find_envvar_names(text): + pattern = r"`(GRAPHRAG_[^`]+)`" + return re.findall(pattern, text) + + env_var_docs_path = Path("docsite/posts/config/env_vars.md") + env_var_docs = env_var_docs_path.read_text(encoding="utf-8") + graphrag_strings = find_envvar_names(env_var_docs) + + missing = { + s for s in graphrag_strings if s not in ALL_ENV_VARS and not s.endswith("_") + } + # Remove configs covered by the base LLM connection configs + missing = missing - { + "GRAPHRAG_LLM_API_KEY", + "GRAPHRAG_LLM_API_BASE", + "GRAPHRAG_LLM_API_VERSION", + "GRAPHRAG_LLM_API_ORGANIZATION", + "GRAPHRAG_LLM_API_PROXY", + "GRAPHRAG_EMBEDDING_API_KEY", + "GRAPHRAG_EMBEDDING_API_BASE", + "GRAPHRAG_EMBEDDING_API_VERSION", + "GRAPHRAG_EMBEDDING_API_ORGANIZATION", + "GRAPHRAG_EMBEDDING_API_PROXY", + } + if missing: + msg = f"{len(missing)} missing env vars: {missing}" + print(msg) + raise ValueError(msg) + @mock.patch.dict( os.environ, - { - "GRAPHRAG_API_KEY": "test", - "GRAPHRAG_LLM_MODEL": "test-llm", - "GRAPHRAG_LLM_THREAD_COUNT": "987", - "GRAPHRAG_STORAGE_TYPE": "blob", - "GRAPHRAG_STORAGE_CONNECTION_STRING": "test_cs", - "GRAPHRAG_STORAGE_CONTAINER_NAME": "test_cn", - "GRAPHRAG_CACHE_TYPE": "blob", - "GRAPHRAG_CACHE_CONNECTION_STRING": "test_cs1", - "GRAPHRAG_CACHE_CONTAINER_NAME": "test_cn1", - "GRAPHRAG_REPORTING_TYPE": "blob", - "GRAPHRAG_REPORTING_CONNECTION_STRING": "test_cs2", - "GRAPHRAG_REPORTING_CONTAINER_NAME": "test_cn2", - "GRAPHRAG_INPUT_TYPE": "text", - "GRAPHRAG_INPUT_ENCODING": "utf-16", - "GRAPHRAG_INPUT_DOCUMENT_ATTRIBUTE_COLUMNS": "test1,test2", - "GRAPHRAG_NODE2VEC_NUM_WALKS": "5000000", - "GRAPHRAG_EMBEDDING_MODEL": "text-embedding-2", - "GRAPHRAG_EMBEDDING_BATCH_SIZE": "1000000", - "GRAPHRAG_EMBEDDING_THREAD_COUNT": "2345", - "GRAPHRAG_CHUNK_SIZE": "500", - "GRAPHRAG_CHUNK_OVERLAP": "12", - "GRAPHRAG_CHUNK_BY_COLUMNS": "a,b", - "GRAPHRAG_SNAPSHOT_GRAPHML": "true", - "GRAPHRAG_ENTITY_EXTRACTION_MAX_GLEANINGS": "112", - "GRAPHRAG_SUMMARIZE_DESCRIPTIONS_MAX_LENGTH": "12345", - "GRAPHRAG_COMMUNITY_REPORT_MAX_LENGTH": "23456", - "GRAPHRAG_CLAIM_EXTRACTION_DESCRIPTION": "test 123", - "GRAPHRAG_MAX_CLUSTER_SIZE": "123", - "GRAPHRAG_UMAP_ENABLED": "true", - "GRAPHRAG_ENCODING_MODEL": "test123", - "GRAPHRAG_SKIP_WORKFLOWS": "a,b,c", - }, + ALL_ENV_VARS, clear=True, ) def test_create_parameters_from_env_vars(self) -> None: parameters = default_config_parameters_from_env_vars(".") - assert parameters.llm["api_key"] == "test" - assert parameters.llm["model"] == "test-llm" - assert parameters.parallelization["num_threads"] == 987 - assert parameters.encoding_model == "test123" - assert parameters.skip_workflows == ["a", "b", "c"] - assert parameters.storage.type == PipelineStorageType.blob - assert parameters.storage.connection_string == "test_cs" - assert parameters.storage.container_name == "test_cn" - assert parameters.cache.type == PipelineCacheType.blob + assert parameters.async_mode == AsyncType.AsyncIO + assert parameters.cache.base_dir == "/some/cache/dir" assert parameters.cache.connection_string == "test_cs1" assert parameters.cache.container_name == "test_cn1" - assert parameters.reporting.type == PipelineReportingType.blob - assert parameters.reporting.connection_string == "test_cs2" - assert parameters.reporting.container_name == "test_cn2" - assert parameters.input.type == PipelineInputType.text - assert parameters.input.file_encoding == "utf-16" - assert parameters.input.document_attribute_columns == ["test1", "test2"] + assert parameters.cache.type == PipelineCacheType.blob + assert parameters.chunks.group_by_columns == ["a", "b"] + assert parameters.chunks.overlap == 12 + assert parameters.chunks.size == 500 + assert parameters.claim_extraction.description == "test 123" + assert parameters.claim_extraction.max_gleanings == 5000 + assert parameters.claim_extraction.prompt == "Hello, World! A" + assert parameters.cluster_graph.max_cluster_size == 123 + assert parameters.community_reports.max_length == 23456 + assert parameters.community_reports.prompt == "Hello, World! B" + assert parameters.embed_graph.is_enabled + assert parameters.embed_graph.iterations == 878787 assert parameters.embed_graph.num_walks == 5_000_000 + assert parameters.embed_graph.random_seed == 10101 + assert parameters.embed_graph.walk_length == 555111 + assert parameters.embed_graph.window_size == 12345 + assert parameters.embeddings.batch_max_tokens == 17 assert parameters.embeddings.batch_size == 1_000_000 - assert parameters.embeddings.parallelization["num_threads"] == 2345 + assert parameters.embeddings.llm["concurrent_requests"] == 12 + assert parameters.embeddings.llm["deployment_name"] == "model-deployment-name" + assert parameters.embeddings.llm["max_retries"] == 3 + assert parameters.embeddings.llm["max_retry_wait"] == 0.1123 assert parameters.embeddings.llm["model"] == "text-embedding-2" - assert parameters.chunks.size == 500 - assert parameters.chunks.overlap == 12 - assert parameters.chunks.group_by_columns == ["a", "b"] - assert parameters.snapshots.graphml + assert parameters.embeddings.llm["requests_per_minute"] == 500 + assert parameters.embeddings.llm["sleep_on_rate_limit_recommendation"] is False + assert parameters.embeddings.llm["tokens_per_minute"] == 7000 + assert parameters.embeddings.llm["type"] == "azure_openai_embedding" + assert parameters.embeddings.parallelization["num_threads"] == 2345 + assert parameters.embeddings.parallelization["stagger"] == 0.456 + assert parameters.embeddings.skip == ["a1", "b1", "c1"] + assert parameters.embeddings.target == "all" + assert parameters.encoding_model == "test123" + assert parameters.entity_extraction.entity_types == ["cat", "dog", "elephant"] + assert parameters.entity_extraction.llm["api_base"] == "http://some/base" assert parameters.entity_extraction.max_gleanings == 112 + assert parameters.entity_extraction.prompt == "Hello, World! C" + assert parameters.input.base_dir == "/some/input/dir" + assert parameters.input.connection_string == "input_cs" + assert parameters.input.container_name == "input_cn" + assert parameters.input.document_attribute_columns == ["test1", "test2"] + assert parameters.input.file_encoding == "utf-16" + assert parameters.input.file_pattern == ".*\\test\\.txt$" + assert parameters.input.source_column == "test_source" + assert parameters.input.storage_type == "blob" + assert parameters.input.text_column == "test_text" + assert parameters.input.timestamp_column == "test_timestamp" + assert parameters.input.timestamp_format == "test_format" + assert parameters.input.title_column == "test_title" + assert parameters.input.type == PipelineInputType.text + assert parameters.llm["api_base"] == "http://some/base" + assert parameters.llm["api_key"] == "test" + assert parameters.llm["api_version"] == "v1234" + assert parameters.llm["concurrent_requests"] == 12 + assert parameters.llm["deployment_name"] == "model-deployment-name-x" + assert parameters.llm["max_retries"] == 312 + assert parameters.llm["max_retry_wait"] == 0.1122 + assert parameters.llm["max_tokens"] == 15000 + assert parameters.llm["model"] == "test-llm" + assert parameters.llm["model_supports_json"] + assert parameters.llm["organization"] == "test_org" + assert parameters.llm["proxy"] == "http://some/proxy" + assert parameters.llm["request_timeout"] == 12.7 + assert parameters.llm["requests_per_minute"] == 900 + assert parameters.llm["sleep_on_rate_limit_recommendation"] is False + assert parameters.llm["tokens_per_minute"] == 8000 + assert parameters.llm["type"] == "azure_openai_chat" + assert parameters.parallelization["num_threads"] == 987 + assert parameters.parallelization["stagger"] == 0.123 + assert parameters.reporting.base_dir == "/some/reporting/dir" + assert parameters.reporting.connection_string == "test_cs2" + assert parameters.reporting.container_name == "test_cn2" + assert parameters.reporting.type == PipelineReportingType.blob + assert parameters.skip_workflows == ["a", "b", "c"] + assert parameters.snapshots.graphml + assert parameters.snapshots.raw_entities + assert parameters.snapshots.top_level_nodes + assert parameters.storage.base_dir == "/some/storage/dir" + assert parameters.storage.connection_string == "test_cs" + assert parameters.storage.container_name == "test_cn" + assert parameters.storage.type == PipelineStorageType.blob assert parameters.summarize_descriptions.max_length == 12345 - assert parameters.community_reports.max_length == 23456 - assert parameters.claim_extraction.description == "test 123" - assert parameters.cluster_graph.max_cluster_size == 123 + assert parameters.summarize_descriptions.prompt == "Hello, World! D" assert parameters.umap.enabled @mock.patch.dict(os.environ, {"API_KEY_X": "test"}, clear=True)