Skip to content

Commit

Permalink
[Enhancement][AIStudio]Add json response (#308)
Browse files Browse the repository at this point in the history
* add json res

* add warning

* add warning

* reformat
  • Loading branch information
Southpika authored Jan 23, 2024
1 parent 5b044b1 commit f10ad87
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
31 changes: 30 additions & 1 deletion erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import json
import logging
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -44,6 +45,9 @@
_T = TypeVar("_T", AIMessage, AIMessageChunk)


_logger = logging.getLogger(__name__)


class BaseERNIEBot(ChatModel):
@overload
async def chat(
Expand Down Expand Up @@ -215,7 +219,15 @@ def _generate_config(self, messages: List[Message], functions, **kwargs) -> dict
if functions is not None:
cfg_dict["functions"] = functions

name_list = ["top_p", "temperature", "penalty_score", "system", "plugins", "tool_choice"]
name_list = [
"top_p",
"temperature",
"penalty_score",
"system",
"plugins",
"tool_choice",
"response_format",
]
for name in name_list:
if name in kwargs:
cfg_dict[name] = kwargs[name]
Expand All @@ -227,6 +239,23 @@ def _generate_config(self, messages: List[Message], functions, **kwargs) -> dict
# rm blank dict
if not cfg_dict["tool_choice"]:
cfg_dict.pop("tool_choice")

if "response_format" in cfg_dict:
if cfg_dict["response_format"] not in ("json_object", "text"):
if "json" in cfg_dict["response_format"]:
cfg_dict["response_format"] = "json_object"
_logger.warning(
f"`response_format` has invalid value:`{cfg_dict['response_format']}`, "
"use `json_object` instead. "
)
else:
# It will not raise error in request
_logger.warning(
f"`response_format` has invalid value:`{cfg_dict['response_format']}`, "
"use default value: `text`. "
"You can only choose `json_object` or `text`. "
)

return cfg_dict

def _maybe_validate_qianfan_auth(self) -> None:
Expand Down
12 changes: 12 additions & 0 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> "ChatCompletionResponse":
...
Expand All @@ -141,6 +142,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Iterator["ChatCompletionResponse"]:
...
Expand All @@ -167,6 +169,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
...
Expand All @@ -192,6 +195,7 @@ def create(
extra_params: Optional[dict] = None,
headers: Optional[HeadersType] = None,
request_timeout: Optional[float] = None,
response_format: Optional[Literal["json_object", "text"]] = None,
_config_: Optional[ConfigDictType] = None,
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
"""Creates a model response for the given conversation.
Expand Down Expand Up @@ -238,6 +242,7 @@ def create(
user_id=user_id,
tool_choice=tool_choice,
stream=stream,
response_format=response_format,
)
kwargs["validate_functions"] = validate_functions
if extra_params is not None:
Expand Down Expand Up @@ -271,6 +276,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = None,
_config_: Optional[ConfigDictType] = ...,
) -> EBResponse:
...
Expand All @@ -297,6 +303,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = None,
_config_: Optional[ConfigDictType] = ...,
) -> AsyncIterator["ChatCompletionResponse"]:
...
Expand All @@ -323,6 +330,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = None,
_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
...
Expand All @@ -348,6 +356,7 @@ async def acreate(
extra_params: Optional[dict] = None,
headers: Optional[HeadersType] = None,
request_timeout: Optional[float] = None,
response_format: Optional[Literal["json_object", "text"]] = None,
_config_: Optional[ConfigDictType] = None,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
"""Creates a model response for the given conversation.
Expand Down Expand Up @@ -394,6 +403,7 @@ async def acreate(
user_id=user_id,
tool_choice=tool_choice,
stream=stream,
response_format=response_format,
)
kwargs["validate_functions"] = validate_functions
if extra_params is not None:
Expand Down Expand Up @@ -438,6 +448,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
"extra_params",
"headers",
"request_timeout",
"response_format",
}

invalid_keys = kwargs.keys() - valid_keys
Expand Down Expand Up @@ -500,6 +511,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
_set_val_if_key_exists(kwargs, params, "user_id")
_set_val_if_key_exists(kwargs, params, "tool_choice")
_set_val_if_key_exists(kwargs, params, "stream")
_set_val_if_key_exists(kwargs, params, "response_format")
if "extra_params" in kwargs:
params.update(kwargs["extra_params"])

Expand Down

0 comments on commit f10ad87

Please sign in to comment.