Skip to content

Commit

Permalink
feat: 添加对 agent 应用的支持 (#951)
Browse files Browse the repository at this point in the history
  • Loading branch information
RockChinQ committed Dec 16, 2024
1 parent 32b400d commit 6642498
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 25 deletions.
6 changes: 3 additions & 3 deletions libs/dify_service_api/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ class TestDifyClient:
async def test_chat_messages(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))

resp = await cln.chat_messages(inputs={}, query="Who are you?", user="test")
print(json.dumps(resp, ensure_ascii=False, indent=4))
async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"):
print(json.dumps(chunk, ensure_ascii=False, indent=4))

async def test_upload_file(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
Expand Down Expand Up @@ -41,4 +41,4 @@ async def test_workflow_run(self):
print(json.dumps(chunks, ensure_ascii=False, indent=4))

if __name__ == "__main__":
asyncio.run(TestDifyClient().test_workflow_run())
asyncio.run(TestDifyClient().test_chat_messages())
27 changes: 16 additions & 11 deletions libs/dify_service_api/v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,22 @@ async def chat_messages(
inputs: dict[str, typing.Any],
query: str,
user: str,
response_mode: str = "blocking", # 当前不支持 streaming
response_mode: str = "streaming", # 当前不支持 blocking
conversation_id: str = "",
files: list[dict[str, typing.Any]] = [],
timeout: float = 30.0,
) -> dict[str, typing.Any]:
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""发送消息"""
if response_mode != "blocking":
raise DifyAPIError("当前仅支持 blocking 模式")
if response_mode != "streaming":
raise DifyAPIError("当前仅支持 streaming 模式")

async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
response = await client.post(
async with client.stream(
"POST",
"/chat-messages",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
json={
Expand All @@ -51,12 +52,14 @@ async def chat_messages(
"conversation_id": conversation_id,
"files": files,
},
)

if response.status_code != 200:
raise DifyAPIError(f"{response.status_code} {response.text}")

return response.json()
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}")
if chunk.strip() == "":
continue
if chunk.startswith("data:"):
yield json.loads(chunk[5:])

async def workflow_run(
self,
Expand Down Expand Up @@ -88,6 +91,8 @@ async def workflow_run(
},
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}")
if chunk.strip() == "":
continue
if chunk.startswith("data:"):
Expand Down
7 changes: 6 additions & 1 deletion pkg/core/migrations/m017_dify_api_timeout_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@ class DifyAPITimeoutParamsMigration(migration.Migration):

async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow']
return 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['chat'] or 'timeout' not in self.ap.provider_cfg.data['dify-service-api']['workflow'] \
or 'agent' not in self.ap.provider_cfg.data['dify-service-api']

async def run(self):
"""执行迁移"""
self.ap.provider_cfg.data['dify-service-api']['chat']['timeout'] = 120
self.ap.provider_cfg.data['dify-service-api']['workflow']['timeout'] = 120
self.ap.provider_cfg.data['dify-service-api']['agent'] = {
"api-key": "app-1234567890",
"timeout": 120
}

await self.ap.provider_cfg.dump_config()
84 changes: 74 additions & 10 deletions pkg/provider/runners/difysvapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):

async def initialize(self):
"""初始化"""
valid_app_types = ["chat", "workflow"]
valid_app_types = ["chat", "agent", "workflow"]
if (
self.ap.provider_cfg.data["dify-service-api"]["app-type"]
not in valid_app_types
Expand Down Expand Up @@ -85,23 +85,84 @@ async def _chat_messages(
for image_id in image_ids
]

resp = await self.dify_client.chat_messages(
async for chunk in self.dify_client.chat_messages(
inputs={},
query=plain_text,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
conversation_id=cov_id,
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
)
):
self.ap.logger.debug("dify-chat-chunk: "+chunk)
if chunk['event'] == 'node_finished':
if chunk['data']['node_type'] == 'answer':
yield llm_entities.Message(
role="assistant",
content=chunk['data']['outputs']['answer'],
)

msg = llm_entities.Message(
role="assistant",
content=resp["answer"],
)
query.session.using_conversation.uuid = chunk["conversation_id"]

async def _agent_chat_messages(
self, query: core_entities.Query
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""调用聊天助手"""
cov_id = query.session.using_conversation.uuid or ""

yield msg
plain_text, image_ids = await self._preprocess_user_message(query)

files = [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": image_id,
}
for image_id in image_ids
]

query.session.using_conversation.uuid = resp["conversation_id"]
ignored_events = ["agent_message"]

async for chunk in self.dify_client.chat_messages(
inputs={},
query=plain_text,
user=f"{query.session.launcher_type.value}_{query.session.launcher_id}",
response_mode="streaming",
conversation_id=cov_id,
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["chat"]["timeout"],
):
self.ap.logger.debug("dify-agent-chunk: "+chunk)
if chunk["event"] in ignored_events:
continue
if chunk["event"] == "agent_thought":

if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
continue

if chunk['thought'].strip() != '': # 文字回复内容
msg = llm_entities.Message(
role="assistant",
content=chunk["thought"],
)
yield msg

if chunk['tool']:
msg = llm_entities.Message(
role="assistant",
tool_calls=[
llm_entities.ToolCall(
id=chunk['id'],
type="function",
function=llm_entities.FunctionCall(
name=chunk["tool"],
arguments=json.dumps({}),
),
)
],
)
yield msg

query.session.using_conversation.uuid = chunk["conversation_id"]

async def _workflow_messages(
self, query: core_entities.Query
Expand Down Expand Up @@ -136,7 +197,7 @@ async def _workflow_messages(
files=files,
timeout=self.ap.provider_cfg.data["dify-service-api"]["workflow"]["timeout"],
):

self.ap.logger.debug("dify-workflow-chunk: "+chunk)
if chunk["event"] in ignored_events:
continue

Expand Down Expand Up @@ -185,6 +246,9 @@ async def run(
if self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "chat":
async for msg in self._chat_messages(query):
yield msg
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "agent":
async for msg in self._agent_chat_messages(query):
yield msg
elif self.ap.provider_cfg.data["dify-service-api"]["app-type"] == "workflow":
async for msg in self._workflow_messages(query):
yield msg
Expand Down
4 changes: 4 additions & 0 deletions templates/provider.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@
"api-key": "app-1234567890",
"timeout": 120
},
"agent": {
"api-key": "app-1234567890",
"timeout": 120
},
"workflow": {
"api-key": "app-1234567890",
"output-key": "summary",
Expand Down

0 comments on commit 6642498

Please sign in to comment.