diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index d04a9c96..ed02f714 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -361,9 +361,37 @@ async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIt ) opts["system"] = self.system.content if self.system is not None else None opts["plugins"] = self._plugins - llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts) - async for msg in llm_ret: - yield LLMResponse(message=msg) + # llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts) + # async for msg in llm_ret: + # yield LLMResponse(message=msg) + # print(self.llm.extra_params.get("enable_multi_step_tool_call")) + # 流式时,无法同时处理多个工具调用 + # 所以只有关闭多步工具调用时才用流式 + if self.llm.extra_data.get("multi_step_tool_call_close", True): + llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts) + async for msg in llm_ret: + print("_run_llm_stream", msg) + yield LLMResponse(message=msg) + else: + llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts) + class MyAsyncIterator: + def __init__(self, data): + self.data = data + self.index = 0 + async def __anext__(self): + if self.index < len(self.data): + result = self.data[self.index] + self.index += 1 + return result + else: + raise StopAsyncIteration + def __aiter__(self): + return self + + llm_ret = MyAsyncIterator([llm_ret]) + async for msg in llm_ret: + print("_run_llm_stream", msg) + yield LLMResponse(message=msg) async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse: parsed_tool_args = self._parse_tool_args(tool_args) diff --git a/erniebot-agent/src/erniebot_agent/tools/remote_tool.py b/erniebot-agent/src/erniebot_agent/tools/remote_tool.py index da5ffec6..9315b8a8 100644 --- a/erniebot-agent/src/erniebot_agent/tools/remote_tool.py +++ b/erniebot-agent/src/erniebot_agent/tools/remote_tool.py @@ -6,7 +6,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Type -import requests +import httpx from erniebot_agent.file import ( FileManager, @@ -113,6 +113,9 @@ async def __call__(self, **tool_arguments: Dict[str, Any]) -> Any: return await self.__post_process__(tool_response) async def send_request(self, tool_arguments: Dict[str, Any]) -> dict: + # async http request + requests = httpx.AsyncClient(timeout=None) + url = "/".join([self.server_url.strip("/"), self.tool_view.uri.strip("/")]) url += "?version=" + self.version @@ -147,13 +150,13 @@ async def send_request(self, tool_arguments: Dict[str, Any]) -> dict: ) if self.tool_view.method == "get": - response = requests.get(url, **requests_inputs) # type: ignore + response = await requests.get(url, **requests_inputs) # type: ignore elif self.tool_view.method == "post": - response = requests.post(url, **requests_inputs) # type: ignore + response = await requests.post(url, **requests_inputs) # type: ignore elif self.tool_view.method == "put": - response = requests.put(url, **requests_inputs) # type: ignore + response = await requests.put(url, **requests_inputs) # type: ignore elif self.tool_view.method == "delete": - response = requests.delete(url, **requests_inputs) # type: ignore + response = await requests.delete(url, **requests_inputs) # type: ignore else: raise RemoteToolError(f"method<{self.tool_view.method}> is invalid", stage="Executing") diff --git a/erniebot-agent/src/erniebot_agent/tools/utils.py b/erniebot-agent/src/erniebot_agent/tools/utils.py index 85148c21..caa30a45 100644 --- a/erniebot-agent/src/erniebot_agent/tools/utils.py +++ b/erniebot-agent/src/erniebot_agent/tools/utils.py @@ -5,9 +5,9 @@ from copy import deepcopy from typing import Any, Dict, Optional, Type, no_type_check +from httpx._models import Response from openapi_spec_validator import validate from openapi_spec_validator.readers import read_from_filename -from requests import Response from erniebot_agent.file import File, FileManager from erniebot_agent.file.protocol import (