feat: support vision and function calling for ollama (#950)
RockChinQ authored Dec 15, 2024
1 parent 9e7d9a9 commit 736f8b6
Showing 1 changed file with 36 additions and 6 deletions.
pkg/provider/modelmgr/requesters/ollamachat.py: 42 changes (36 additions, 6 deletions)
@@ -4,6 +4,8 @@
 import os
 import typing
 from typing import Union, Mapping, Any, AsyncIterator
+import uuid
+import json

 import async_lru
 import ollama
@@ -60,21 +62,49 @@ async def _closure(self, req_messages: list[dict], use_model: entities.LLMModelI
                         image_urls.append(image_url)
                 msg["content"] = "\n".join(text_content)
                 msg["images"] = [url.split(',')[1] for url in image_urls]
+            if 'tool_calls' in msg:  # LangBot stores tool_calls arguments as str internally; convert them to dict here
+                for tool_call in msg['tool_calls']:
+                    tool_call['function']['arguments'] = json.loads(tool_call['function']['arguments'])
         args["messages"] = messages

-        resp: Mapping[str, Any] | AsyncIterator[Mapping[str, Any]] = await self._req(args)
+        args["tools"] = []
+        if user_funcs:
+            tools = await self.ap.tool_mgr.generate_tools_for_openai(user_funcs)
+            if tools:
+                args["tools"] = tools
+
+        resp = await self._req(args)
         message: llm_entities.Message = await self._make_msg(resp)
         return message

     async def _make_msg(
             self,
-            chat_completions: Union[Mapping[str, Any], AsyncIterator[Mapping[str, Any]]]) -> llm_entities.Message:
-        message: Any = chat_completions.pop('message', None)
+            chat_completions: ollama.ChatResponse) -> llm_entities.Message:
+        message: ollama.Message = chat_completions.message
         if message is None:
             raise ValueError("chat_completions must contain a 'message' field")

-        message.update(chat_completions)
-        ret_msg: llm_entities.Message = llm_entities.Message(**message)
+        ret_msg: llm_entities.Message = None
+
+        if message.content is not None:
+            ret_msg = llm_entities.Message(
+                role="assistant",
+                content=message.content
+            )
+        if message.tool_calls is not None and len(message.tool_calls) > 0:
+            tool_calls: list[llm_entities.ToolCall] = []
+
+            for tool_call in message.tool_calls:
+                tool_calls.append(llm_entities.ToolCall(
+                    id=uuid.uuid4().hex,
+                    type="function",
+                    function=llm_entities.FunctionCall(
+                        name=tool_call.function.name,
+                        arguments=json.dumps(tool_call.function.arguments)
+                    )
+                ))
+            ret_msg.tool_calls = tool_calls
+
         return ret_msg

     async def call(
@@ -92,7 +122,7 @@ async def call(
                 msg_dict["content"] = "\n".join(part["text"] for part in content)
             req_messages.append(msg_dict)
         try:
-            return await self._closure(req_messages, model)
+            return await self._closure(req_messages, model, funcs)
         except asyncio.TimeoutError:
            raise errors.RequesterError('请求超时')

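For readers tracing the new tool-calling path, here is a minimal sketch of the round trip against the ollama Python client directly, outside LangBot's wrappers. The model name, prompt, and tool schema are illustrative assumptions; the OpenAI-style tool dict merely stands in for whatever generate_tools_for_openai returns, and a local Ollama server with a tool-capable model pulled is assumed to be running.

import asyncio
import json

import ollama


async def main():
    # OpenAI-style tool schema; the ollama client accepts this shape for `tools`.
    # (Stand-in for the output of LangBot's generate_tools_for_openai.)
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool for illustration
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]

    client = ollama.AsyncClient()
    resp: ollama.ChatResponse = await client.chat(
        model="llama3.1",  # assumption: any locally pulled, tool-capable model
        messages=[{"role": "user", "content": "What is the weather in Paris?"}],
        tools=tools,
    )

    # Ollama hands back tool-call arguments as a dict; LangBot stores them as a
    # JSON string, which is why _make_msg above wraps them in json.dumps(...).
    for tool_call in resp.message.tool_calls or []:
        print(tool_call.function.name, json.dumps(tool_call.function.arguments))


asyncio.run(main())

The outbound side of _closure is the mirror image: tool_calls stored by LangBot carry their arguments as JSON strings, so they are passed through json.loads before the request is sent to the client.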

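The vision path looks the same from the client's point of view: _closure strips the data:image/...;base64, prefix with url.split(',')[1] and sends the bare base64 strings under the message's images key, which is the form the ollama client accepts. A sketch, with an assumed local image file and an assumed vision-capable model:

import asyncio
import base64

import ollama


async def main():
    # Bare base64 without the "data:image/...;base64," prefix, i.e. the same
    # form that _closure produces via url.split(',')[1].
    with open("cat.jpg", "rb") as f:  # hypothetical local image
        image_b64 = base64.b64encode(f.read()).decode()

    client = ollama.AsyncClient()
    resp: ollama.ChatResponse = await client.chat(
        model="llava",  # assumption: any locally pulled, vision-capable model
        messages=[{
            "role": "user",
            "content": "Describe this picture.",
            "images": [image_b64],
        }],
    )
    print(resp.message.content)


asyncio.run(main())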