Fix LLM model_parameters argument passing (#88)
* fix model parameter passing

* formatting

* clean up json-parsing, exception handling
darthtrevino authored Apr 4, 2024
1 parent 7b1ffcb commit d348224
Showing 4 changed files with 46 additions and 52 deletions.
76 changes: 30 additions & 46 deletions graphrag/llm/openai/openai_chat_llm.py
@@ -2,9 +2,8 @@
 
 """The Chat-based language model."""
 
-import json
 import logging
-import traceback
+from json import JSONDecodeError
 
 from typing_extensions import Unpack
 
@@ -44,7 +43,9 @@ def __init__(self, client: OpenAIClientTypes, configuration: OpenAIConfiguration
     async def _execute_llm(
         self, input: CompletionInput, **kwargs: Unpack[LLMInput]
     ) -> CompletionOutput | None:
-        args = get_completion_llm_args(kwargs.get("parameters"), self.configuration)
+        args = get_completion_llm_args(
+            kwargs.get("model_parameters"), self.configuration
+        )
         history = kwargs.get("history") or []
         messages = [
             *history,
@@ -91,34 +92,25 @@ async def _native_json(
         self, input: CompletionInput, **kwargs: Unpack[LLMInput]
     ) -> LLMOutput[CompletionOutput]:
         """Generate JSON output using a model's native JSON-output support."""
-        try:
-            result = await self._invoke(
-                input,
-                **{
-                    **kwargs,
-                    "model_parameters": {
-                        **(kwargs.get("model_parameters") or {}),
-                        "response_format": {"type": "json_object"},
-                    },
-                },
-            )
-
-            raw_output = result.output or ""
-            json_output = json.loads(raw_output)
-            return LLMOutput[CompletionOutput](
-                output=raw_output,
-                json=json_output,
-                history=result.history,
-            )
-        except BaseException as e:
-            log.exception("error parsing llm json, emitting none")
-            if self._on_error:
-                self._on_error(
-                    e,
-                    traceback.format_exc(),
-                    {"input": input, "operation": "native_json"},
-                )
-            raise
+        result = await self._invoke(
+            input,
+            **{
+                **kwargs,
+                "model_parameters": {
+                    **(kwargs.get("model_parameters") or {}),
+                    "response_format": {"type": "json_object"},
+                },
+            },
+        )
+
+        raw_output = result.output or ""
+        json_output = try_parse_json_object(raw_output)
+
+        return LLMOutput[CompletionOutput](
+            output=raw_output,
+            json=json_output,
+            history=result.history,
+        )
 
     async def _manual_json(
         self, input: CompletionInput, **kwargs: Unpack[LLMInput]
@@ -132,26 +124,18 @@ async def _manual_json(
             return LLMOutput[CompletionOutput](
                 output=output, json=json_output, history=history
             )
-        except BaseException:
-            log.exception("error parsing llm json, retrying")
+        except (TypeError, JSONDecodeError):
+            log.warning("error parsing llm json, retrying")
             # If cleaned up json is unparsable, use the LLM to reformat it (may throw)
             result = await self._try_clean_json_with_llm(output, **kwargs)
             output = clean_up_json(result.output or "")
-            try:
-                return LLMOutput[CompletionOutput](
-                    output=output,
-                    json=try_parse_json_object(output),
-                    history=history,
-                )
-            except Exception as e:
-                log.exception("error parsing llm json, emitting none")
-                if self._on_error:
-                    self._on_error(
-                        e,
-                        traceback.format_exc(),
-                        {"input": input, "operation": "manual_json"},
-                    )
-                raise
+            json = try_parse_json_object(output)
+
+            return LLMOutput[CompletionOutput](
+                output=output,
+                json=json,
+                history=history,
+            )
 
     async def _try_clean_json_with_llm(
         self, output: str, **kwargs: Unpack[LLMInput]
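The _native_json hunk above makes two changes: it layers the model's native JSON mode (response_format={"type": "json_object"}) on top of any caller-supplied model_parameters, and it stops catching BaseException locally, so parse failures now propagate to the caller. A minimal sketch of the same pattern against the openai v1 Python client — the default model and helper name are illustrative, not graphrag's actual API:

import json

from openai import AsyncOpenAI


async def chat_json(
    client: AsyncOpenAI, prompt: str, model_parameters: dict | None = None
) -> dict:
    # Layer JSON mode on top of the caller's parameters, as _native_json does.
    params = {
        "model": "gpt-4-turbo-preview",  # hypothetical default
        **(model_parameters or {}),
        "response_format": {"type": "json_object"},
    }
    response = await client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        **params,
    )
    raw = response.choices[0].message.content or ""
    # Parse errors surface to the caller rather than being swallowed here.
    return json.loads(raw)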
4 changes: 3 additions & 1 deletion graphrag/llm/openai/openai_completion_llm.py
@@ -35,6 +35,8 @@ async def _execute_llm(
         input: CompletionInput,
         **kwargs: Unpack[LLMInput],
     ) -> CompletionOutput | None:
-        args = get_completion_llm_args(kwargs.get("parameters"), self.configuration)
+        args = get_completion_llm_args(
+            kwargs.get("model_parameters"), self.configuration
+        )
         completion = self.client.completions.create(prompt=input, **args)
         return completion.choices[0].text
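Both the chat and completion paths build their request arguments through get_completion_llm_args, so this one-key fix is what restores per-call overrides: callers populate LLMInput under "model_parameters", but the old code looked up "parameters" and always got None. A self-contained illustration with hypothetical values and a simplified stand-in for the helper:

def get_completion_llm_args(parameters: dict | None, configuration: dict) -> dict:
    # Simplified stand-in: configured defaults overlaid with per-call parameters.
    return {**configuration, **(parameters or {})}


kwargs = {"model_parameters": {"temperature": 0.0}}
configuration = {"model": "gpt-4", "temperature": 0.7}

# Before the fix: the lookup key didn't match, so overrides were silently dropped.
assert get_completion_llm_args(kwargs.get("parameters"), configuration)["temperature"] == 0.7

# After the fix: the caller's model_parameters take effect.
assert get_completion_llm_args(kwargs.get("model_parameters"), configuration)["temperature"] == 0.0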
2 changes: 1 addition & 1 deletion graphrag/llm/openai/openai_embeddings_llm.py
@@ -30,7 +30,7 @@ async def _execute_llm(
     ) -> EmbeddingOutput | None:
         args = {
             "model": self.configuration.model,
-            **(kwargs.get("parameters") or {}),
+            **(kwargs.get("model_parameters") or {}),
         }
         embedding = await self.client.embeddings.create(
             input=input,
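In the embeddings path the merge order is visible in the diff: the configured model is the default, and anything in model_parameters overrides it. A small sketch (model names are hypothetical):

def build_embedding_args(configured_model: str, model_parameters: dict | None) -> dict:
    # Mirrors the args dict above: configured default first, per-call overrides last.
    return {
        "model": configured_model,
        **(model_parameters or {}),
    }


print(build_embedding_args("text-embedding-3-small", None))
# {'model': 'text-embedding-3-small'}
print(build_embedding_args("text-embedding-3-small", {"model": "text-embedding-3-large"}))
# {'model': 'text-embedding-3-large'}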
16 changes: 12 additions & 4 deletions graphrag/llm/openai/utils.py
@@ -3,6 +3,7 @@
 """Utility functions for the OpenAI API."""
 
 import json
+import logging
 from collections.abc import Callable
 from typing import Any
 
@@ -26,6 +27,8 @@
 ]
 RATE_LIMIT_ERRORS: list[type[Exception]] = [RateLimitError]
 
+log = logging.getLogger(__name__)
+
 
 def get_token_counter(config: OpenAIConfiguration) -> Callable[[str], int]:
     """Get a function that counts the number of tokens in a string."""
@@ -85,10 +88,15 @@ def get_completion_llm_args(
 
 def try_parse_json_object(input: str) -> dict:
     """Generate JSON-string output using best-attempt prompting & parsing techniques."""
-    result = json.loads(input)
-    if not isinstance(result, dict):
-        raise TypeError
-    return result
+    try:
+        result = json.loads(input)
+    except json.JSONDecodeError:
+        log.exception("error loading json, json=%s", input)
+        raise
+    else:
+        if not isinstance(result, dict):
+            raise TypeError
+        return result
 
 
 def get_sleep_time_from_error(e: Any) -> float:
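The revised try_parse_json_object separates two failure modes: input that is not valid JSON (logged via log.exception, then the JSONDecodeError is re-raised) and valid JSON that is not an object (TypeError). A behavior sketch with illustrative inputs, assuming the function is importable from the module above:

import json

from graphrag.llm.openai.utils import try_parse_json_object

assert try_parse_json_object('{"a": 1}') == {"a": 1}

try:
    try_parse_json_object("[1, 2]")  # valid JSON, but not an object
except TypeError:
    pass  # non-dict results are rejected

try:
    try_parse_json_object("not json")
except json.JSONDecodeError:
    pass  # logged, then re-raised to the caller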
