From d3482244a7c88868b41ef985578214e584a617e3 Mon Sep 17 00:00:00 2001 From: Chris Trevino Date: Thu, 4 Apr 2024 11:34:51 -0700 Subject: [PATCH] Fix LLM `model_parameters` argument passing (#88) * fix model parameter passing * formatting * clean up json-parsing, exception handling --- graphrag/llm/openai/openai_chat_llm.py | 76 ++++++++------------ graphrag/llm/openai/openai_completion_llm.py | 4 +- graphrag/llm/openai/openai_embeddings_llm.py | 2 +- graphrag/llm/openai/utils.py | 16 +++-- 4 files changed, 46 insertions(+), 52 deletions(-) diff --git a/graphrag/llm/openai/openai_chat_llm.py b/graphrag/llm/openai/openai_chat_llm.py index 25d64e32ef..0aa84129b7 100644 --- a/graphrag/llm/openai/openai_chat_llm.py +++ b/graphrag/llm/openai/openai_chat_llm.py @@ -2,9 +2,8 @@ """The Chat-based language model.""" -import json import logging -import traceback +from json import JSONDecodeError from typing_extensions import Unpack @@ -44,7 +43,9 @@ def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration async def _execute_llm( self, input: CompletionInput, **kwargs: Unpack[LLMInput] ) -> CompletionOutput | None: - args = get_completion_llm_args(kwargs.get("parameters"), self.configuration) + args = get_completion_llm_args( + kwargs.get("model_parameters"), self.configuration + ) history = kwargs.get("history") or [] messages = [ *history, @@ -91,34 +92,25 @@ async def _native_json( self, input: CompletionInput, **kwargs: Unpack[LLMInput] ) -> LLMOutput[CompletionOutput]: """Generate JSON output using a model's native JSON-output support.""" - try: - result = await self._invoke( - input, - **{ - **kwargs, - "model_parameters": { - **(kwargs.get("model_parameters") or {}), - "response_format": {"type": "json_object"}, - }, + result = await self._invoke( + input, + **{ + **kwargs, + "model_parameters": { + **(kwargs.get("model_parameters") or {}), + "response_format": {"type": "json_object"}, }, - ) + }, + ) - raw_output = result.output or "" - json_output = json.loads(raw_output) - return LLMOutput[CompletionOutput]( - output=raw_output, - json=json_output, - history=result.history, - ) - except BaseException as e: - log.exception("error parsing llm json, emitting none") - if self._on_error: - self._on_error( - e, - traceback.format_exc(), - {"input": input, "operation": "native_json"}, - ) - raise + raw_output = result.output or "" + json_output = try_parse_json_object(raw_output) + + return LLMOutput[CompletionOutput]( + output=raw_output, + json=json_output, + history=result.history, + ) async def _manual_json( self, input: CompletionInput, **kwargs: Unpack[LLMInput] @@ -132,26 +124,18 @@ async def _manual_json( return LLMOutput[CompletionOutput]( output=output, json=json_output, history=history ) - except BaseException: - log.exception("error parsing llm json, retrying") + except (TypeError, JSONDecodeError): + log.warning("error parsing llm json, retrying") # If cleaned up json is unparsable, use the LLM to reformat it (may throw) result = await self._try_clean_json_with_llm(output, **kwargs) output = clean_up_json(result.output or "") - try: - return LLMOutput[CompletionOutput]( - output=output, - json=try_parse_json_object(output), - history=history, - ) - except Exception as e: - log.exception("error parsing llm json, emitting none") - if self._on_error: - self._on_error( - e, - traceback.format_exc(), - {"input": input, "operation": "manual_json"}, - ) - raise + json = try_parse_json_object(output) + + return LLMOutput[CompletionOutput]( + output=output, + json=json, 
+                history=history,
+            )
 
     async def _try_clean_json_with_llm(
         self, output: str, **kwargs: Unpack[LLMInput]
diff --git a/graphrag/llm/openai/openai_completion_llm.py b/graphrag/llm/openai/openai_completion_llm.py
index e6521e088b..9240beba21 100644
--- a/graphrag/llm/openai/openai_completion_llm.py
+++ b/graphrag/llm/openai/openai_completion_llm.py
@@ -35,6 +35,8 @@ async def _execute_llm(
         input: CompletionInput,
         **kwargs: Unpack[LLMInput],
     ) -> CompletionOutput | None:
-        args = get_completion_llm_args(kwargs.get("parameters"), self.configuration)
+        args = get_completion_llm_args(
+            kwargs.get("model_parameters"), self.configuration
+        )
         completion = self.client.completions.create(prompt=input, **args)
         return completion.choices[0].text
diff --git a/graphrag/llm/openai/openai_embeddings_llm.py b/graphrag/llm/openai/openai_embeddings_llm.py
index 0f18b9da71..0f33777956 100644
--- a/graphrag/llm/openai/openai_embeddings_llm.py
+++ b/graphrag/llm/openai/openai_embeddings_llm.py
@@ -30,7 +30,7 @@ async def _execute_llm(
     ) -> EmbeddingOutput | None:
         args = {
             "model": self.configuration.model,
-            **(kwargs.get("parameters") or {}),
+            **(kwargs.get("model_parameters") or {}),
         }
         embedding = await self.client.embeddings.create(
             input=input,
diff --git a/graphrag/llm/openai/utils.py b/graphrag/llm/openai/utils.py
index d8fe4abac6..ad14cb9169 100644
--- a/graphrag/llm/openai/utils.py
+++ b/graphrag/llm/openai/utils.py
@@ -3,6 +3,7 @@
 """Utility functions for the OpenAI API."""
 
 import json
+import logging
 from collections.abc import Callable
 from typing import Any
 
@@ -26,6 +27,8 @@
 ]
 RATE_LIMIT_ERRORS: list[type[Exception]] = [RateLimitError]
 
+log = logging.getLogger(__name__)
+
 
 def get_token_counter(config: OpenAIConfiguration) -> Callable[[str], int]:
     """Get a function that counts the number of tokens in a string."""
@@ -85,10 +88,15 @@ def get_completion_llm_args(
 
 def try_parse_json_object(input: str) -> dict:
     """Parse a JSON object from a string, raising if the input is malformed or not a dict."""
-    result = json.loads(input)
-    if not isinstance(result, dict):
-        raise TypeError
-    return result
+    try:
+        result = json.loads(input)
+    except json.JSONDecodeError:
+        log.exception("error loading json, json=%s", input)
+        raise
+    else:
+        if not isinstance(result, dict):
+            raise TypeError("JSON output is not a dict")
+        return result
 
 
 def get_sleep_time_from_error(e: Any) -> float:
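
Note for reviewers (not part of the patch): the root cause is a key mismatch. Callers populate the `model_parameters` entry of `LLMInput` (as `_native_json` above already does), but the old code read `kwargs.get("parameters")`, so per-call overrides were silently dropped and only the configuration defaults applied. Below is a minimal sketch of the merge pattern the fix makes reachable, mirroring the `openai_embeddings_llm.py` change; the helper name and example values are hypothetical:

```python
from typing import Any


def merge_call_args(
    configured_model: str, model_parameters: dict[str, Any] | None
) -> dict[str, Any]:
    """Hypothetical helper mirroring the patched pattern: per-call
    model_parameters are spread over the configured defaults."""
    return {
        "model": configured_model,
        **(model_parameters or {}),
    }


# With the old "parameters" key, the override dict was never found and
# this call would have returned only {"model": "gpt-4"}.
assert merge_call_args("gpt-4", {"temperature": 0.0}) == {
    "model": "gpt-4",
    "temperature": 0.0,
}
```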
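Relatedly, `_manual_json` now catches exactly `(TypeError, JSONDecodeError)` because that is the full error contract of `try_parse_json_object`: `json.JSONDecodeError` for malformed text, `TypeError` for valid JSON that is not an object. A standalone restatement of the patched helper for illustration (sample inputs invented; logging omitted for brevity):

```python
import json


def try_parse_json_object(input: str) -> dict:
    """Standalone copy of the patched helper, minus logging."""
    try:
        result = json.loads(input)
    except json.JSONDecodeError:
        # Malformed JSON: re-raised so _manual_json can retry via the LLM.
        raise
    else:
        if not isinstance(result, dict):
            # Parsable but not an object (e.g. a list): also retried.
            raise TypeError("JSON output is not a dict")
        return result


assert try_parse_json_object('{"ok": true}') == {"ok": True}

for bad in ('{"truncated": ', '["a", "list"]'):
    try:
        try_parse_json_object(bad)
    except (TypeError, json.JSONDecodeError) as err:
        print(type(err).__name__)  # JSONDecodeError, then TypeError
```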