Skip to content

Commit

Permalink
Merge pull request #963 from RockChinQ/feat/dl-image-by-adapters
Browse files Browse the repository at this point in the history
fix: 下载 QQ 图片时的400问题
  • Loading branch information
RockChinQ authored Dec 24, 2024
2 parents 535c4a8 + 07ca48d commit 8227e32
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 61 deletions.
4 changes: 2 additions & 2 deletions pkg/pipeline/preproc/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ async def process(
)
elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported):
if me.url is not None:
if me.base64 is not None:
content_list.append(
llm_entities.ContentElement.from_image_url(str(me.url))
llm_entities.ContentElement.from_image_base64(me.base64)
)

query.user_message = llm_entities.Message(
Expand Down
28 changes: 15 additions & 13 deletions pkg/platform/sources/aiocqhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@
import datetime

import aiocqhttp
import aiohttp

from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities

from ...utils import image

class AiocqhttpMessageConverter(adapter.MessageConverter):

@staticmethod
def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
async def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
msg_list = aiocqhttp.Message()

msg_id = 0
Expand Down Expand Up @@ -59,15 +60,15 @@ def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[li
elif type(msg) is forward.Forward:

for node in msg.node_list:
msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
msg_list.extend(await AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])

else:
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))

return msg_list, msg_id, msg_time

@staticmethod
def target2yiri(message: str, message_id: int = -1):
async def target2yiri(message: str, message_id: int = -1):
message = aiocqhttp.Message(message)

yiri_msg_list = []
Expand All @@ -89,7 +90,8 @@ def target2yiri(message: str, message_id: int = -1):
elif msg.type == "text":
yiri_msg_list.append(platform_message.Plain(text=msg.data["text"]))
elif msg.type == "image":
yiri_msg_list.append(platform_message.Image(url=msg.data["url"]))
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}"))

chain = platform_message.MessageChain(yiri_msg_list)

Expand All @@ -99,9 +101,9 @@ def target2yiri(message: str, message_id: int = -1):
class AiocqhttpEventConverter(adapter.EventConverter):

@staticmethod
def yiri2target(event: platform_events.Event, bot_account_id: int):
async def yiri2target(event: platform_events.Event, bot_account_id: int):

msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain)
msg, msg_id, msg_time = await AiocqhttpMessageConverter.yiri2target(event.message_chain)

if type(event) is platform_events.GroupMessage:
role = "member"
Expand Down Expand Up @@ -164,8 +166,8 @@ def yiri2target(event: platform_events.Event, bot_account_id: int):
return aiocqhttp.Event.from_payload(payload)

@staticmethod
def target2yiri(event: aiocqhttp.Event):
yiri_chain = AiocqhttpMessageConverter.target2yiri(
async def target2yiri(event: aiocqhttp.Event):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(
event.message, event.message_id
)

Expand Down Expand Up @@ -242,7 +244,7 @@ async def shutdown_trigger_placeholder():
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
aiocq_msg = await AiocqhttpMessageConverter.yiri2target(message)[0]

if target_type == "group":
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
Expand All @@ -255,8 +257,8 @@ async def reply_message(
message: platform_message.MessageChain,
quote_origin: bool = False,
):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if quote_origin:
aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg

Expand All @@ -276,7 +278,7 @@ def register_listener(
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
try:
return await callback(self.event_converter.target2yiri(event), self)
return await callback(await self.event_converter.target2yiri(event), self)
except:
traceback.print_exc()

Expand Down
6 changes: 6 additions & 0 deletions pkg/provider/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class ContentElement(pydantic.BaseModel):

image_url: typing.Optional[ImageURLContentObject] = None

image_base64: typing.Optional[str] = None

def __str__(self):
if self.type == 'text':
return self.text
Expand All @@ -53,6 +55,10 @@ def from_text(cls, text: str):
@classmethod
def from_image_url(cls, image_url: str):
return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))

@classmethod
def from_image_base64(cls, image_base64: str):
return cls(type='image_base64', image_base64=image_base64)


class Message(pydantic.BaseModel):
Expand Down
1 change: 1 addition & 0 deletions pkg/provider/modelmgr/requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ async def preprocess(
@abc.abstractmethod
async def call(
self,
query: core_entities.Query,
model: modelmgr_entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
Expand Down
16 changes: 7 additions & 9 deletions pkg/provider/modelmgr/requesters/anthropicmsgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing
import traceback
import base64

import anthropic
import httpx
Expand Down Expand Up @@ -39,6 +40,7 @@ async def initialize(self):

async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
Expand Down Expand Up @@ -70,24 +72,20 @@ async def call(
if isinstance(m.content, str) and m.content.strip() != "":
req_messages.append(m.dict(exclude_none=True))
elif isinstance(m.content, list):
# m.content = [
# c for c in m.content if c.type == "text"
# ]

# if len(m.content) > 0:
# req_messages.append(m.dict(exclude_none=True))

msg_dict = m.dict(exclude_none=True)

for i, ce in enumerate(m.content):
if ce.type == "image_url":
base64_image, image_format = await image.qq_image_url_to_base64(ce.image_url.url)

if ce.type == "image_base64":
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)

alter_image_ele = {
"type": "image",
"source": {
"type": "base64",
"media_type": f"image/{image_format}",
"data": base64_image
"data": image_b64
}
}
msg_dict["content"][i] = alter_image_ele
Expand Down
20 changes: 9 additions & 11 deletions pkg/provider/modelmgr/requesters/chatcmpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ async def _make_msg(

async def _closure(
self,
query: core_entities.Query,
req_messages: list[dict],
use_model: entities.LLMModelInfo,
use_funcs: list[tools_entities.LLMFunction] = None,
Expand All @@ -87,8 +88,12 @@ async def _closure(
for msg in messages:
if 'content' in msg and isinstance(msg["content"], list):
for me in msg["content"]:
if me["type"] == "image_url":
me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])
if me["type"] == "image_base64":
me["image_url"] = {
"url": me["image_base64"]
}
me["type"] = "image_url"
del me["image_base64"]

args["messages"] = messages

Expand All @@ -102,6 +107,7 @@ async def _closure(

async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
Expand All @@ -118,7 +124,7 @@ async def call(
req_messages.append(msg_dict)

try:
return await self._closure(req_messages, model, funcs)
return await self._closure(query, req_messages, model, funcs)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:
Expand All @@ -134,11 +140,3 @@ async def call(
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}')

@async_lru.alru_cache(maxsize=128)
async def get_base64_str(
self,
original_url: str,
) -> str:
base64_image, image_format = await image.qq_image_url_to_base64(original_url)
return f"data:image/{image_format};base64,{base64_image}"
22 changes: 8 additions & 14 deletions pkg/provider/modelmgr/requesters/ollamachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from typing import Union, Mapping, Any, AsyncIterator
import uuid
import json
import base64

import async_lru
import ollama

from .. import entities, errors, requester
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....core import app
from ....core import app, entities as core_entities
from ....utils import image

REQUESTER_NAME: str = "ollama-chat"
Expand Down Expand Up @@ -43,7 +44,7 @@ async def _req(self,
**args
)

async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo,
async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo,
user_funcs: list[tools_entities.LLMFunction] = None) -> (
llm_entities.Message):
args: Any = self.request_cfg['args'].copy()
Expand All @@ -57,9 +58,9 @@ async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelI
for me in msg["content"]:
if me["type"] == "text":
text_content.append(me["text"])
elif me["type"] == "image_url":
image_url = await self.get_base64_str(me["image_url"]['url'])
image_urls.append(image_url)
elif me["type"] == "image_base64":
image_urls.append(me["image_base64"])

msg["content"] = "\n".join(text_content)
msg["images"] = [url.split(',')[1] for url in image_urls]
if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict
Expand Down Expand Up @@ -109,6 +110,7 @@ async def _make_msg(

async def call(
self,
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
Expand All @@ -122,14 +124,6 @@ async def call(
msg_dict["content"] = "\n".join(part["text"] for part in content)
req_messages.append(msg_dict)
try:
return await self._closure(req_messages, model, funcs)
return await self._closure(query, req_messages, model, funcs)
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')

@async_lru.alru_cache(maxsize=128)
async def get_base64_str(
self,
original_url: str,
) -> str:
base64_image, image_format = await image.qq_image_url_to_base64(original_url)
return f"data:image/{image_format};base64,{base64_image}"
8 changes: 4 additions & 4 deletions pkg/provider/runners/difysvapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing
import json
import uuid
import base64

from .. import runner
from ...core import entities as core_entities
Expand Down Expand Up @@ -52,10 +53,9 @@ async def _preprocess_user_message(
for ce in query.user_message.content:
if ce.type == "text":
plain_text += ce.text
elif ce.type == "image_url":
file_bytes, image_format = await image.get_qq_image_bytes(
ce.image_url.url
)
elif ce.type == "image_base64":
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
file_bytes = base64.b64decode(image_b64)
file = ("img.png", file_bytes, f"image/{image_format}")
file_upload_resp = await self.dify_client.upload_file(
file,
Expand Down
4 changes: 2 additions & 2 deletions pkg/provider/runners/localagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_ent
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]

# 首次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)

yield msg

Expand Down Expand Up @@ -61,7 +61,7 @@ async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_ent
req_messages.append(err_msg)

# 处理完所有调用,再次请求
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs)

yield msg

Expand Down
Loading

0 comments on commit 8227e32

Please sign in to comment.