diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index 4fa32c65..d2cb5977 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -61,9 +61,9 @@ async def process( ) elif isinstance(me, platform_message.Image): if self.ap.provider_cfg.data['enable-vision'] and (self.ap.provider_cfg.data['runner'] != 'local-agent' or query.use_model.vision_supported): - if me.url is not None: + if me.base64 is not None: content_list.append( - llm_entities.ContentElement.from_image_url(str(me.url)) + llm_entities.ContentElement.from_image_base64(me.base64) ) query.user_message = llm_entities.Message( diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index 25d197e3..107bd326 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -6,6 +6,7 @@ import datetime import aiocqhttp +import aiohttp from .. import adapter from ...pipeline.longtext.strategies import forward @@ -13,12 +14,12 @@ from ..types import message as platform_message from ..types import events as platform_events from ..types import entities as platform_entities - +from ...utils import image class AiocqhttpMessageConverter(adapter.MessageConverter): @staticmethod - def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]: + async def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]: msg_list = aiocqhttp.Message() msg_id = 0 @@ -59,7 +60,7 @@ def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[li elif type(msg) is forward.Forward: for node in msg.node_list: - msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0]) + msg_list.extend(await AiocqhttpMessageConverter.yiri2target(node.message_chain)[0]) else: msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) @@ -67,7 +68,7 @@ def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[li return msg_list, msg_id, msg_time @staticmethod - def target2yiri(message: str, message_id: int = -1): + async def target2yiri(message: str, message_id: int = -1): message = aiocqhttp.Message(message) yiri_msg_list = [] @@ -89,7 +90,8 @@ def target2yiri(message: str, message_id: int = -1): elif msg.type == "text": yiri_msg_list.append(platform_message.Plain(text=msg.data["text"])) elif msg.type == "image": - yiri_msg_list.append(platform_message.Image(url=msg.data["url"])) + image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url']) + yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}")) chain = platform_message.MessageChain(yiri_msg_list) @@ -99,9 +101,9 @@ def target2yiri(message: str, message_id: int = -1): class AiocqhttpEventConverter(adapter.EventConverter): @staticmethod - def yiri2target(event: platform_events.Event, bot_account_id: int): + async def yiri2target(event: platform_events.Event, bot_account_id: int): - msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain) + msg, msg_id, msg_time = await AiocqhttpMessageConverter.yiri2target(event.message_chain) if type(event) is platform_events.GroupMessage: role = "member" @@ -164,8 +166,8 @@ def yiri2target(event: platform_events.Event, bot_account_id: int): return aiocqhttp.Event.from_payload(payload) @staticmethod - def target2yiri(event: aiocqhttp.Event): - yiri_chain = AiocqhttpMessageConverter.target2yiri( + async def target2yiri(event: aiocqhttp.Event): + yiri_chain = await AiocqhttpMessageConverter.target2yiri( event.message, event.message_id ) @@ -242,7 +244,7 @@ async def shutdown_trigger_placeholder(): async def send_message( self, target_type: str, target_id: str, message: platform_message.MessageChain ): - aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0] + aiocq_msg = await AiocqhttpMessageConverter.yiri2target(message)[0] if target_type == "group": await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg) @@ -255,8 +257,8 @@ async def reply_message( message: platform_message.MessageChain, quote_origin: bool = False, ): - aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) - aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0] + aiocq_event = await AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) + aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] if quote_origin: aiocq_msg = aiocqhttp.MessageSegment.reply(aiocq_event.message_id) + aiocq_msg @@ -276,7 +278,7 @@ def register_listener( async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id try: - return await callback(self.event_converter.target2yiri(event), self) + return await callback(await self.event_converter.target2yiri(event), self) except: traceback.print_exc() diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index a1f5df8d..dce55fd5 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -38,6 +38,8 @@ class ContentElement(pydantic.BaseModel): image_url: typing.Optional[ImageURLContentObject] = None + image_base64: typing.Optional[str] = None + def __str__(self): if self.type == 'text': return self.text @@ -53,6 +55,10 @@ def from_text(cls, text: str): @classmethod def from_image_url(cls, image_url: str): return cls(type='image_url', image_url=ImageURLContentObject(url=image_url)) + + @classmethod + def from_image_base64(cls, image_base64: str): + return cls(type='image_base64', image_base64=image_base64) class Message(pydantic.BaseModel): diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 930cf9e7..cf29bd4d 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -48,6 +48,7 @@ async def preprocess( @abc.abstractmethod async def call( self, + query: core_entities.Query, model: modelmgr_entities.LLMModelInfo, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, diff --git a/pkg/provider/modelmgr/requesters/anthropicmsgs.py b/pkg/provider/modelmgr/requesters/anthropicmsgs.py index cf2fa2da..25b7da17 100644 --- a/pkg/provider/modelmgr/requesters/anthropicmsgs.py +++ b/pkg/provider/modelmgr/requesters/anthropicmsgs.py @@ -2,6 +2,7 @@ import typing import traceback +import base64 import anthropic import httpx @@ -39,6 +40,7 @@ async def initialize(self): async def call( self, + query: core_entities.Query, model: entities.LLMModelInfo, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, @@ -70,24 +72,20 @@ async def call( if isinstance(m.content, str) and m.content.strip() != "": req_messages.append(m.dict(exclude_none=True)) elif isinstance(m.content, list): - # m.content = [ - # c for c in m.content if c.type == "text" - # ] - - # if len(m.content) > 0: - # req_messages.append(m.dict(exclude_none=True)) msg_dict = m.dict(exclude_none=True) for i, ce in enumerate(m.content): - if ce.type == "image_url": - base64_image, image_format = await image.qq_image_url_to_base64(ce.image_url.url) + + if ce.type == "image_base64": + image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) + alter_image_ele = { "type": "image", "source": { "type": "base64", "media_type": f"image/{image_format}", - "data": base64_image + "data": image_b64 } } msg_dict["content"][i] = alter_image_ele diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index da993ade..76af459f 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -65,6 +65,7 @@ async def _make_msg( async def _closure( self, + query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo, use_funcs: list[tools_entities.LLMFunction] = None, @@ -87,8 +88,12 @@ async def _closure( for msg in messages: if 'content' in msg and isinstance(msg["content"], list): for me in msg["content"]: - if me["type"] == "image_url": - me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url']) + if me["type"] == "image_base64": + me["image_url"] = { + "url": me["image_base64"] + } + me["type"] = "image_url" + del me["image_base64"] args["messages"] = messages @@ -102,6 +107,7 @@ async def _closure( async def call( self, + query: core_entities.Query, model: entities.LLMModelInfo, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, @@ -118,7 +124,7 @@ async def call( req_messages.append(msg_dict) try: - return await self._closure(req_messages, model, funcs) + return await self._closure(query, req_messages, model, funcs) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: @@ -134,11 +140,3 @@ async def call( raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}') except openai.APIError as e: raise errors.RequesterError(f'请求错误: {e.message}') - - @async_lru.alru_cache(maxsize=128) - async def get_base64_str( - self, - original_url: str, - ) -> str: - base64_image, image_format = await image.qq_image_url_to_base64(original_url) - return f"data:image/{image_format};base64,{base64_image}" diff --git a/pkg/provider/modelmgr/requesters/ollamachat.py b/pkg/provider/modelmgr/requesters/ollamachat.py index d7ec9614..ba2ee66a 100644 --- a/pkg/provider/modelmgr/requesters/ollamachat.py +++ b/pkg/provider/modelmgr/requesters/ollamachat.py @@ -6,6 +6,7 @@ from typing import Union, Mapping, Any, AsyncIterator import uuid import json +import base64 import async_lru import ollama @@ -13,7 +14,7 @@ from .. import entities, errors, requester from ... import entities as llm_entities from ...tools import entities as tools_entities -from ....core import app +from ....core import app, entities as core_entities from ....utils import image REQUESTER_NAME: str = "ollama-chat" @@ -43,7 +44,7 @@ async def _req(self, **args ) - async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelInfo, + async def _closure(self, query: core_entities.Query, req_messages: list[dict], use_model: entities.LLMModelInfo, user_funcs: list[tools_entities.LLMFunction] = None) -> ( llm_entities.Message): args: Any = self.request_cfg['args'].copy() @@ -57,9 +58,9 @@ async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelI for me in msg["content"]: if me["type"] == "text": text_content.append(me["text"]) - elif me["type"] == "image_url": - image_url = await self.get_base64_str(me["image_url"]['url']) - image_urls.append(image_url) + elif me["type"] == "image_base64": + image_urls.append(me["image_base64"]) + msg["content"] = "\n".join(text_content) msg["images"] = [url.split(',')[1] for url in image_urls] if 'tool_calls' in msg: # LangBot 内部以 str 存储 tool_calls 的参数,这里需要转换为 dict @@ -109,6 +110,7 @@ async def _make_msg( async def call( self, + query: core_entities.Query, model: entities.LLMModelInfo, messages: typing.List[llm_entities.Message], funcs: typing.List[tools_entities.LLMFunction] = None, @@ -122,14 +124,6 @@ async def call( msg_dict["content"] = "\n".join(part["text"] for part in content) req_messages.append(msg_dict) try: - return await self._closure(req_messages, model, funcs) + return await self._closure(query, req_messages, model, funcs) except asyncio.TimeoutError: raise errors.RequesterError('请求超时') - - @async_lru.alru_cache(maxsize=128) - async def get_base64_str( - self, - original_url: str, - ) -> str: - base64_image, image_format = await image.qq_image_url_to_base64(original_url) - return f"data:image/{image_format};base64,{base64_image}" diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index beb49115..733d6344 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -3,6 +3,7 @@ import typing import json import uuid +import base64 from .. import runner from ...core import entities as core_entities @@ -52,10 +53,9 @@ async def _preprocess_user_message( for ce in query.user_message.content: if ce.type == "text": plain_text += ce.text - elif ce.type == "image_url": - file_bytes, image_format = await image.get_qq_image_bytes( - ce.image_url.url - ) + elif ce.type == "image_base64": + image_b64, image_format = await image.extract_b64_and_format(ce.image_base64) + file_bytes = base64.b64decode(image_b64) file = ("img.png", file_bytes, f"image/{image_format}") file_upload_resp = await self.dify_client.upload_file( file, diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 84cda565..f05c82e3 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -23,7 +23,7 @@ async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_ent req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message] # 首次请求 - msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) + msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs) yield msg @@ -61,7 +61,7 @@ async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_ent req_messages.append(err_msg) # 处理完所有调用,再次请求 - msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs) + msg = await query.use_model.requester.call(query, query.use_model, req_messages, query.use_funcs) yield msg diff --git a/pkg/utils/image.py b/pkg/utils/image.py index dc13802b..06885175 100644 --- a/pkg/utils/image.py +++ b/pkg/utils/image.py @@ -1,9 +1,11 @@ import base64 import typing +import io from urllib.parse import urlparse, parse_qs import ssl import aiohttp +import PIL.Image def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]: @@ -13,9 +15,10 @@ def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]: return f"http://{parsed.netloc}{parsed.path}", query -async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]: - """获取QQ图片的bytes""" - image_url, query = get_qq_image_downloadable_url(image_url) +async def get_qq_image_bytes(image_url: str, query: dict={}) -> tuple[bytes, str]: + """[弃用]获取QQ图片的bytes""" + image_url, query_in_url = get_qq_image_downloadable_url(image_url) + query = {**query, **query_in_url} ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE @@ -24,8 +27,11 @@ async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]: resp.raise_for_status() file_bytes = await resp.read() content_type = resp.headers.get('Content-Type') - if not content_type or not content_type.startswith('image/'): + if not content_type: image_format = 'jpeg' + elif not content_type.startswith('image/'): + pil_img = PIL.Image.open(io.BytesIO(file_bytes)) + image_format = pil_img.format.lower() else: image_format = content_type.split('/')[-1] return file_bytes, image_format @@ -34,7 +40,7 @@ async def get_qq_image_bytes(image_url: str) -> tuple[bytes, str]: async def qq_image_url_to_base64( image_url: str ) -> typing.Tuple[str, str]: - """将QQ图片URL转为base64,并返回图片格式 + """[弃用]将QQ图片URL转为base64,并返回图片格式 Args: image_url (str): QQ图片URL @@ -47,8 +53,18 @@ async def qq_image_url_to_base64( # Flatten the query dictionary query = {k: v[0] for k, v in query.items()} - file_bytes, image_format = await get_qq_image_bytes(image_url) + file_bytes, image_format = await get_qq_image_bytes(image_url, query) base64_str = base64.b64encode(file_bytes).decode() return base64_str, image_format + +async def extract_b64_and_format(image_base64_data: str) -> typing.Tuple[str, str]: + """提取base64编码和图片格式 + + data:image/jpeg;base64,xxx + 提取出base64编码和图片格式 + """ + base64_str = image_base64_data.split(',')[-1] + image_format = image_base64_data.split(':')[-1].split(';')[0].split('/')[-1] + return base64_str, image_format \ No newline at end of file