Skip to content

Commit

Permalink
Add missing vars and fallbacks (#120)
Browse files Browse the repository at this point in the history
* Add missing vars and fallbacks

* Add optional flag to envar with fallbacks

* Fix smoke test by adding required flag to vars

* Fix pyright
  • Loading branch information
AlonsoGuevara authored Apr 5, 2024
1 parent 0bdef92 commit a04d5ec
Showing 1 changed file with 45 additions and 12 deletions.
57 changes: 45 additions & 12 deletions graphrag/query/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,12 @@
_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"


def _env_with_fallback(key: str, fallback: list[str]):
def _env_with_fallback(key: str, fallback: list[str], required=True):
for k in [key, *fallback]:
if k in os.environ:
return os.environ[k]
if not required:
return None
msg = f"None of the following environment variables found: {key}, {fallback}"
raise ValueError(msg)

Expand Down Expand Up @@ -282,29 +284,60 @@ def __get_text_units(data_dir: Path):


def __get_llm():
is_azure_client = (
os.environ.get("GRAPHRAG_LLM_TYPE", "openai_chat") != "openai_chat"
)

return ChatOpenAI(
api_key=_env_with_fallback("GRAPHRAG_API_KEY", ["OPENAI_API_KEY"]),
api_base=os.environ.get("GRAPHRAG_LLM_API_BASE", None),
api_key=_env_with_fallback(
"GRAPHRAG_LLM_API_KEY",
["GRAPHRAG_API_KEY", "OPENAI_API_KEY"],
required=True,
), # type: ignore | Since this is required, will always return a value or break
api_base=_env_with_fallback(
"GRAPHRAG_LLM_API_BASE", ["GRAPHRAG_API_BASE"], required=is_azure_client
),
model=os.environ.get("GRAPHRAG_LLM_MODEL", _DEFAULT_LLM_MODEL),
api_type=OpenaiApiType.OpenAI
if os.environ.get("GRAPHRAG_LLM_TYPE", "openai_chat") == "openai_chat"
else OpenaiApiType.AzureOpenAI,
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
deployment_name=os.environ.get(
"GRAPHRAG_LLM_DEPLOYMENT_NAME", _DEFAULT_LLM_MODEL
),
api_version=_env_with_fallback(
"GRAPHRAG_LLM_API_VERSION",
["GRAPHRAG_API_VERSION", "OPENAI_API_VERSION"],
required=is_azure_client,
),
max_retries=int(os.environ.get("GRAPHRAG_LLM_MAX_RETRIES", 20)),
)


def __get_text_embedder():
is_azure_client = (
os.environ.get("GRAPHRAG_EMBEDDING_TYPE", "openai_embedding")
!= "openai_embedding"
)

return OpenAIEmbedding(
api_key=_env_with_fallback("GRAPHRAG_API_KEY", ["OPENAI_API_KEY"]),
api_base=os.environ.get("GRAPHRAG_EMBEDDING_API_BASE", None),
api_type=OpenaiApiType.OpenAI
if os.environ.get("GRAPHRAG_EMBEDDING_TYPE", "openai_embedding")
== "openai_embedding"
else OpenaiApiType.AzureOpenAI,
api_key=_env_with_fallback(
"GRAPHRAG_EMBEDDING_API_KEY",
["GRAPHRAG_API_KEY", "OPENAI_API_KEY"],
required=True,
), # type: ignore | Since this is required, will always return a value or break
api_base=_env_with_fallback(
"GRAPHRAG_EMBEDDING_API_BASE",
["GRAPHRAG_API_BASE"],
required=is_azure_client,
),
api_type=OpenaiApiType.AzureOpenAI if is_azure_client else OpenaiApiType.OpenAI,
model=os.environ.get("GRAPHRAG_EMBEDDING_MODEL", _DEFAULT_EMBEDDING_MODEL),
deployment_name=os.environ.get(
"GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME", _DEFAULT_EMBEDDING_MODEL
),
api_version=_env_with_fallback(
"GRAPHRAG_EMBEDDING_API_VERSION",
["GRAPHRAG_API_VERSION", "OPENAI_API_VERSION"],
required=is_azure_client,
),
max_retries=int(os.environ.get("GRAPHRAG_EMBEDDING_MAX_RETRIES", 20)),
)

Expand Down

0 comments on commit a04d5ec

Please sign in to comment.