
Commit 5b4bb6d
add configuration tests (#162)
add some configuration tests
darthtrevino authored Apr 18, 2024
1 parent 8e39a41 commit 5b4bb6d
Showing 12 changed files with 357 additions and 79 deletions.
16 changes: 14 additions & 2 deletions graphrag/index/__init__.py
@@ -34,6 +34,9 @@
PipelineWorkflowStep,
)
from .default_config import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureDeploymentNameMissingError,
CacheConfigModel,
ChunkingConfigModel,
ClaimExtractionConfigModel,
@@ -56,10 +59,18 @@
default_config_parameters,
default_config_parameters_from_env_vars,
)
from .errors import (
NoWorkflowsDefinedError,
UndefinedWorkflowError,
UnknownWorkflowError,
)
from .run import run_pipeline, run_pipeline_with_config
from .storage import PipelineStorage

__all__ = [
"ApiKeyMissingError",
"AzureApiBaseMissingError",
"AzureDeploymentNameMissingError",
"CacheConfigModel",
"ChunkingConfigModel",
"ClaimExtractionConfigModel",
@@ -71,6 +82,7 @@
"InputConfigModel",
"LLMConfigModel",
"LLMParametersModel",
"NoWorkflowsDefinedError",
"ParallelizationParametersModel",
"PipelineBlobCacheConfig",
"PipelineBlobCacheConfig",
@@ -80,7 +92,6 @@
"PipelineCache",
"PipelineCacheConfig",
"PipelineCacheConfigTypes",
# Deep Config Stack
"PipelineCacheType",
"PipelineConfig",
"PipelineConsoleReportingConfig",
@@ -110,7 +121,8 @@
"SummarizeDescriptionsConfigModel",
"TextEmbeddingConfigModel",
"UmapConfigModel",
# Default Config Stack
"UndefinedWorkflowError",
"UnknownWorkflowError",
"default_config",
"default_config_parameters",
"default_config_parameters_from_env_vars",
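
These additions re-export the new configuration error types from the package root, so callers can catch them without importing from graphrag.index.default_config.parameters directly. A minimal sketch of how the re-exports might be consumed; the call to default_config_parameters_from_env_vars and its root-directory argument are assumptions for illustration, not part of this diff:

from graphrag.index import (
    ApiKeyMissingError,
    default_config_parameters_from_env_vars,
)

try:
    # Assumed usage: build pipeline parameters from GRAPHRAG_* environment variables.
    parameters = default_config_parameters_from_env_vars("./ragtest")
except ApiKeyMissingError as err:
    # The error message names the exact environment variables that satisfy the check.
    print(f"configuration error: {err}")
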
6 changes: 6 additions & 0 deletions graphrag/index/default_config/__init__.py
@@ -5,6 +5,9 @@
from .default_config import default_config
from .load import load_pipeline_config
from .parameters import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureDeploymentNameMissingError,
CacheConfigModel,
ChunkingConfigModel,
ClaimExtractionConfigModel,
@@ -28,6 +31,9 @@
)

__all__ = [
"ApiKeyMissingError",
"AzureApiBaseMissingError",
"AzureDeploymentNameMissingError",
"CacheConfigModel",
"ChunkingConfigModel",
"ClaimExtractionConfigModel",
8 changes: 8 additions & 0 deletions graphrag/index/default_config/parameters/__init__.py
@@ -3,6 +3,11 @@
"""Configuration parameterization settings for the indexing pipeline."""

from .default_config_parameters_model import DefaultConfigParametersModel
from .errors import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureDeploymentNameMissingError,
)
from .factories import (
default_config_parameters,
default_config_parameters_from_env_vars,
@@ -29,6 +34,9 @@
from .read_dotenv import read_dotenv

__all__ = [
"ApiKeyMissingError",
"AzureApiBaseMissingError",
"AzureDeploymentNameMissingError",
"CacheConfigModel",
"ChunkingConfigModel",
"ClaimExtractionConfigModel",
39 changes: 39 additions & 0 deletions graphrag/index/default_config/parameters/errors.py
@@ -0,0 +1,39 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
"""Errors for the default configuration."""


class ApiKeyMissingError(ValueError):
"""LLM Key missing error."""

def __init__(self, embedding: bool = False) -> None:
"""Init method definition."""
api_type = "Embedding" if embedding else "Completion"
api_key = "GRAPHRAG_EMBEDDING_API_KEY" if embedding else "GRAPHRAG_LLM_API_KEY"
msg = f"API Key is required for {api_type} API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or {api_key} environment variable."
super().__init__(msg)


class AzureApiBaseMissingError(ValueError):
"""Azure API Base missing error."""

def __init__(self, embedding: bool = False) -> None:
"""Init method definition."""
api_type = "Embedding" if embedding else "Completion"
api_base = "GRAPHRAG_EMBEDDING_API_BASE" if embedding else "GRAPHRAG_API_BASE"
msg = f"API Base is required for {api_type} API. Please set either the OPENAI_API_BASE, GRAPHRAG_API_BASE or {api_base} environment variable."
super().__init__(msg)


class AzureDeploymentNameMissingError(ValueError):
"""Azure Deployment Name missing error."""

def __init__(self, embedding: bool = False) -> None:
"""Init method definition."""
api_type = "Embedding" if embedding else "Completion"
api_base = (
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME"
if embedding
else "GRAPHRAG_LLM_DEPLOYMENT_NAME"
)
msg = f"Deployment Name is required for {api_type} API. Please set either the OPENAI_DEPLOYMENT_NAME, GRAPHRAG_LLM_DEPLOYMENT_NAME or {api_base} environment variable."
super().__init__(msg)
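
Each error builds its message from a single embedding flag, so one class covers both the completion and embedding code paths. As a purely illustrative check (not part of the commit), the embedding variant of ApiKeyMissingError renders as:

from graphrag.index.default_config.parameters.errors import ApiKeyMissingError

err = ApiKeyMissingError(embedding=True)
print(err)
# API Key is required for Embedding API. Please set either the OPENAI_API_KEY,
# GRAPHRAG_API_KEY or GRAPHRAG_EMBEDDING_API_KEY environment variable.
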
35 changes: 15 additions & 20 deletions graphrag/index/default_config/parameters/factories.py
@@ -21,6 +21,11 @@

from .default_config_parameters import DefaultConfigParametersDict
from .default_config_parameters_model import DefaultConfigParametersModel
from .errors import (
ApiKeyMissingError,
AzureApiBaseMissingError,
AzureDeploymentNameMissingError,
)
from .models import (
CacheConfigModel,
ChunkingConfigModel,
@@ -40,19 +45,6 @@
UmapConfigModel,
)

LLM_KEY_REQUIRED = "API Key is required for Completion API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_LLM_API_KEY environment variable."
EMBEDDING_KEY_REQUIRED = "API Key is required for Embedding API. Please set either the OPENAI_API_KEY, GRAPHRAG_API_KEY or GRAPHRAG_EMBEDDING_API_KEY environment variable."
AZURE_LLM_DEPLOYMENT_NAME_REQUIRED = (
"GRAPHRAG_LLM_MODEL or GRAPHRAG_LLM_DEPLOYMENT_NAME is required for Azure OpenAI."
)
AZURE_LLM_API_BASE_REQUIRED = (
"GRAPHRAG_API_BASE or GRAPHRAG_LLM_API_BASE is required for Azure OpenAI."
)
AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED = "GRAPHRAG_EMBEDDING_MODEL or GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME is required for Azure OpenAI."
AZURE_EMBEDDING_API_BASE_REQUIRED = (
"GRAPHRAG_API_BASE or GRAPHRAG_EMBEDDING_API_BASE is required for Azure OpenAI."
)


def default_config_parameters(
values: DefaultConfigParametersModel, root_dir: str | None
@@ -103,7 +95,7 @@ def section(key: Section):
with section(Section.llm):
api_key = _str(Fragment.api_key, _api_key or fallback_oai_key)
if api_key is None:
raise ValueError(LLM_KEY_REQUIRED)
raise ApiKeyMissingError
llm_type = _str(Fragment.type)
llm_type = LLMType(llm_type) if llm_type else None
deployment_name = _str(Fragment.deployment_name)
@@ -112,9 +104,9 @@ def section(key: Section):
is_azure = _is_azure(llm_type)
api_base = _str(Fragment.api_base, _api_base)
if is_azure and deployment_name is None and model is None:
raise ValueError(AZURE_LLM_DEPLOYMENT_NAME_REQUIRED)
raise AzureDeploymentNameMissingError
if is_azure and api_base is None:
raise ValueError(AZURE_LLM_API_BASE_REQUIRED)
raise AzureApiBaseMissingError

llm_parameters = LLMParametersModel(
api_key=api_key,
@@ -143,7 +135,7 @@ def section(key: Section):
with section(Section.embedding):
api_key = _str(Fragment.api_key, _api_key)
if api_key is None:
raise ValueError(EMBEDDING_KEY_REQUIRED)
raise ApiKeyMissingError(embedding=True)

embedding_target = _str("TARGET")
embedding_target = (
@@ -159,9 +151,9 @@ def section(key: Section):
api_base = _str(Fragment.api_base, _api_base)

if is_azure and deployment_name is None and model is None:
raise ValueError(AZURE_EMBEDDING_DEPLOYMENT_NAME_REQUIRED)
raise AzureDeploymentNameMissingError(embedding=True)
if is_azure and api_base is None:
raise ValueError(AZURE_EMBEDDING_API_BASE_REQUIRED)
raise AzureApiBaseMissingError(embedding=True)

text_embeddings = TextEmbeddingConfigModel(
parallelization=ParallelizationParametersModel(
@@ -172,7 +164,7 @@ def section(key: Section):
target=embedding_target,
batch_size=_int("BATCH_SIZE"),
batch_max_tokens=_int("BATCH_MAX_TOKENS"),
skip=_array_string("SKIP"),
skip=_array_string(_str("SKIP")),
llm=LLMParametersModel(
api_key=_str(Fragment.api_key, _api_key),
type=llm_type,
@@ -244,6 +236,8 @@ def section(key: Section):
storage_type=storage_type,
file_encoding=_str(Fragment.encoding),
base_dir=_str(Fragment.base_dir),
connection_string=_str(Fragment.conn_string),
container_name=_str(Fragment.container_name),
file_pattern=_str("FILE_PATTERN"),
source_column=_str("SOURCE_COLUMN"),
timestamp_column=_str("TIMESTAMP_COLUMN"),
@@ -294,6 +288,7 @@ def section(key: Section):
enabled=_bool(Fragment.enabled),
)

async_mode = _str(Fragment.async_mode)
async_mode_enum = AsyncType(async_mode) if async_mode else None
return DefaultConfigParametersDict(
DefaultConfigParametersModel(
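
The net effect in factories.py is that the module-level message constants are removed and each validation failure raises one of the typed errors instead of a bare ValueError. Since the new classes subclass ValueError, existing except ValueError handlers keep working, while configuration tests can assert on the exception type rather than matching message strings. A hypothetical pytest-style check in that spirit; the test body, the environment setup, and the factory's root-directory argument are assumptions, not code from this commit:

import pytest

from graphrag.index.default_config.parameters import (
    ApiKeyMissingError,
    default_config_parameters_from_env_vars,
)


def test_missing_llm_api_key_raises_typed_error(monkeypatch):
    # Remove every variable the completion-API key lookup falls back to.
    for name in ("OPENAI_API_KEY", "GRAPHRAG_API_KEY", "GRAPHRAG_LLM_API_KEY"):
        monkeypatch.delenv(name, raising=False)
    with pytest.raises(ApiKeyMissingError):
        default_config_parameters_from_env_vars(".")
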
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -102,7 +102,7 @@ _pyright = "pyright"
coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
check_format = 'ruff format . --check --preview'
fix = "ruff --preview check --fix ."
fix_unsafe = "ruff --preview check --fix --unsafe-fixes ."
fix_unsafe = "ruff check --preview --fix --unsafe-fixes ."

_test_all = "coverage run -m pytest ./tests"
test_unit = "pytest ./tests/unit"
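
The only change here reorders the flags so --preview is passed to the check subcommand rather than to ruff itself. Assuming these entries are Poe the Poet task definitions, as the surrounding check_format and coverage_report tasks suggest, the corrected task would be invoked as:

poetry run poe fix_unsafe
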
2 changes: 2 additions & 0 deletions tests/smoke/test_fixtures.py
@@ -12,6 +12,7 @@
from unittest import mock

import pandas as pd
import pytest

from graphrag.index.storage.blob_pipeline_storage import BlobPipelineStorage

@@ -247,6 +248,7 @@ def __run_query(self, root: Path, query_config: dict[str, str]):
},
clear=True,
)
@pytest.mark.timeout(600) # Extend the timeout to 600 seconds (10 minutes)
def test_fixture(
self,
input_path: str,
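
The added @pytest.mark.timeout(600) marker relies on the pytest-timeout plugin, which is assumed to already be available in the test environment; it is not introduced in the files shown here. In isolation the pattern looks like:

import time

import pytest


@pytest.mark.timeout(5)  # fail this test if it runs for more than 5 seconds
def test_finishes_within_budget():
    time.sleep(1)
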
1 change: 1 addition & 0 deletions tests/unit/indexing/default_config/prompt-a.txt
@@ -0,0 +1 @@
Hello, World! A
1 change: 1 addition & 0 deletions tests/unit/indexing/default_config/prompt-b.txt
@@ -0,0 +1 @@
Hello, World! B
1 change: 1 addition & 0 deletions tests/unit/indexing/default_config/prompt-c.txt
@@ -0,0 +1 @@
Hello, World! C
1 change: 1 addition & 0 deletions tests/unit/indexing/default_config/prompt-d.txt
@@ -0,0 +1 @@
Hello, World! D
(1 remaining changed file not shown.)

0 comments on commit 5b4bb6d
