Fix parallel tool calls (#2779)
* Fix parallel tool calls

* remove comments
Weves authored Oct 13, 2024
1 parent 3365e0b commit 26bdb41
Showing 2 changed files with 293 additions and 2 deletions.
5 changes: 3 additions & 2 deletions backend/danswer/llm/chat_llm.py
@@ -109,7 +109,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
                     "arguments": json.dumps(tool_call["args"]),
                 },
                 "type": "function",
-                "index": 0,  # only support a single tool call atm
+                "index": tool_call.get("index", 0),
             }
             for tool_call in message.tool_calls
         ]
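
With the hardcoded index replaced by tool_call.get("index", 0), each serialized tool call now keeps its own position in the request payload. A minimal sketch of the effect (not part of the commit; values mirror the tests below, and the "index" key is optional on the tool-call dict, hence the .get fallback):

import json

# Two tool calls as they might appear on an AIMessage, each carrying its index
tool_calls = [
    {"id": "call_1", "name": "get_weather", "args": {"location": "New York"}, "index": 0},
    {"id": "call_2", "name": "get_time", "args": {"timezone": "EST"}, "index": 1},
]

serialized = [
    {
        "id": tool_call["id"],
        "function": {
            "name": tool_call["name"],
            "arguments": json.dumps(tool_call["args"]),
        },
        "type": "function",
        "index": tool_call.get("index", 0),  # previously hardcoded to 0
    }
    for tool_call in tool_calls
]

assert serialized[1]["index"] == 1  # the second call is no longer collapsed onto index 0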
@@ -158,12 +158,13 @@ def _convert_delta_to_message_chunk(
     if tool_calls:
         tool_call = tool_calls[0]
         tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
+        idx = tool_call.index
 
         tool_call_chunk = ToolCallChunk(
             name=tool_name,
             id=tool_call.id,
             args=tool_call.function.arguments,
-            index=0,  # only support a single tool call atm
+            index=idx,
         )
 
         return AIMessageChunk(
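
Propagating the delta's index into the ToolCallChunk matters because LangChain merges streamed chunks by index: when AIMessageChunk objects are added together, tool-call fragments that share an index are concatenated into a single call, while distinct indices remain separate calls. A rough sketch of that merging (assuming langchain_core's chunk semantics; values are hypothetical and mirror the streaming test below):

from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCallChunk

# Two fragments of the same call (index=0) followed by a separate call (index=1)
first = AIMessageChunk(
    content="",
    tool_call_chunks=[
        ToolCallChunk(name="get_weather", args='{"location": ', id="call_1", index=0)
    ],
)
second = AIMessageChunk(
    content="",
    tool_call_chunks=[ToolCallChunk(name=None, args='"New York"}', id=None, index=0)],
)
third = AIMessageChunk(
    content="",
    tool_call_chunks=[
        ToolCallChunk(name="get_time", args='{"timezone": "EST"}', id="call_2", index=1)
    ],
)

merged = first + second + third
# Fragments sharing index=0 are concatenated; index=1 stays a second call.
# With the old hardcoded index=0, all three fragments would have been merged into one.
assert len(merged.tool_calls) == 2
assert merged.tool_calls[0]["args"] == {"location": "New York"}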
290 changes: 290 additions & 0 deletions backend/tests/unit/danswer/llm/test_chat_llm.py
@@ -0,0 +1,290 @@
from unittest.mock import patch

import litellm
import pytest
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import HumanMessage
from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Delta
from litellm.types.utils import Function as LiteLLMFunction

from danswer.llm.chat_llm import DefaultMultiLLM


def _create_delta(
    role: str | None = None,
    content: str | None = None,
    tool_calls: list[ChatCompletionDeltaToolCall] | None = None,
) -> Delta:
    delta = Delta(role=role, content=content)
    # NOTE: for some reason, if you pass tool_calls to the constructor, it doesn't actually
    # get set, so we have to do it this way
    delta.tool_calls = tool_calls
    return delta


@pytest.fixture
def default_multi_llm() -> DefaultMultiLLM:
    return DefaultMultiLLM(
        api_key="test_key",
        timeout=30,
        model_provider="openai",
        model_name="gpt-3.5-turbo",
    )


def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
    # Mock the litellm.completion function
    with patch("danswer.llm.chat_llm.litellm.completion") as mock_completion:
        # Create a mock response with multiple tool calls using litellm objects
        mock_response = litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    finish_reason="tool_calls",
                    index=0,
                    message=litellm.Message(
                        content=None,
                        role="assistant",
                        tool_calls=[
                            litellm.ChatCompletionMessageToolCall(
                                id="call_1",
                                function=LiteLLMFunction(
                                    name="get_weather",
                                    arguments='{"location": "New York"}',
                                ),
                                type="function",
                            ),
                            litellm.ChatCompletionMessageToolCall(
                                id="call_2",
                                function=LiteLLMFunction(
                                    name="get_time", arguments='{"timezone": "EST"}'
                                ),
                                type="function",
                            ),
                        ],
                    ),
                )
            ],
            model="gpt-3.5-turbo",
            usage=litellm.Usage(
                prompt_tokens=50, completion_tokens=30, total_tokens=80
            ),
        )
        mock_completion.return_value = mock_response

        # Define input messages
        messages = [HumanMessage(content="What's the weather and time in New York?")]

        # Define available tools
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_time",
                    "description": "Get the current time for a timezone",
                    "parameters": {
                        "type": "object",
                        "properties": {"timezone": {"type": "string"}},
                        "required": ["timezone"],
                    },
                },
            },
        ]

        # Invoke the LLM (invoke delegates to _invoke_implementation)
        result = default_multi_llm.invoke(messages, tools)

        # Assert that the result is an AIMessage
        assert isinstance(result, AIMessage)

        # Assert that the content is empty (the mock response's None content
        # is normalized to an empty string)
        assert result.content == ""

        # Assert that there are two tool calls
        assert len(result.tool_calls) == 2

        # Assert the details of the first tool call
        assert result.tool_calls[0]["id"] == "call_1"
        assert result.tool_calls[0]["name"] == "get_weather"
        assert result.tool_calls[0]["args"] == {"location": "New York"}

        # Assert the details of the second tool call
        assert result.tool_calls[1]["id"] == "call_2"
        assert result.tool_calls[1]["name"] == "get_time"
        assert result.tool_calls[1]["args"] == {"timezone": "EST"}

        # Verify that litellm.completion was called with the correct arguments
        mock_completion.assert_called_once_with(
            model="openai/gpt-3.5-turbo",
            api_key="test_key",
            base_url=None,
            api_version=None,
            custom_llm_provider=None,
            messages=[
                {"role": "user", "content": "What's the weather and time in New York?"}
            ],
            tools=tools,
            tool_choice=None,
            stream=False,
            temperature=0.0,  # Default value from GEN_AI_TEMPERATURE
            timeout=30,
            parallel_tool_calls=False,
        )


def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> None:
    # Mock the litellm.completion function
    with patch("danswer.llm.chat_llm.litellm.completion") as mock_completion:
        # Create a mock response with multiple tool calls using litellm objects
        mock_response = [
            litellm.ModelResponse(
                id="chatcmpl-123",
                choices=[
                    litellm.Choices(
                        delta=_create_delta(
                            role="assistant",
                            tool_calls=[
                                ChatCompletionDeltaToolCall(
                                    id="call_1",
                                    function=LiteLLMFunction(
                                        name="get_weather", arguments='{"location": '
                                    ),
                                    type="function",
                                    index=0,
                                )
                            ],
                        ),
                        finish_reason=None,
                        index=0,
                    )
                ],
                model="gpt-3.5-turbo",
            ),
            litellm.ModelResponse(
                id="chatcmpl-123",
                choices=[
                    litellm.Choices(
                        delta=_create_delta(
                            tool_calls=[
                                ChatCompletionDeltaToolCall(
                                    id="",
                                    function=LiteLLMFunction(arguments='"New York"}'),
                                    type="function",
                                    index=0,
                                )
                            ]
                        ),
                        finish_reason=None,
                        index=0,
                    )
                ],
                model="gpt-3.5-turbo",
            ),
            litellm.ModelResponse(
                id="chatcmpl-123",
                choices=[
                    litellm.Choices(
                        delta=_create_delta(
                            tool_calls=[
                                ChatCompletionDeltaToolCall(
                                    id="call_2",
                                    function=LiteLLMFunction(
                                        name="get_time", arguments='{"timezone": "EST"}'
                                    ),
                                    type="function",
                                    index=1,
                                )
                            ]
                        ),
                        finish_reason="tool_calls",
                        index=0,
                    )
                ],
                model="gpt-3.5-turbo",
            ),
        ]
        mock_completion.return_value = mock_response

        # Define input messages and tools (same as in the non-streaming test)
        messages = [HumanMessage(content="What's the weather and time in New York?")]

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather for a location",
                    "parameters": {
                        "type": "object",
                        "properties": {"location": {"type": "string"}},
                        "required": ["location"],
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "get_time",
                    "description": "Get the current time for a timezone",
                    "parameters": {
                        "type": "object",
                        "properties": {"timezone": {"type": "string"}},
                        "required": ["timezone"],
                    },
                },
            },
        ]

        # Call the stream method
        stream_result = list(default_multi_llm.stream(messages, tools))

        # Assert that we received the correct number of chunks
        assert len(stream_result) == 3

        # Combine all chunks into a single AIMessage
        combined_result: AIMessage = AIMessageChunk(content="")
        for chunk in stream_result:
            combined_result += chunk  # type: ignore

        # Assert that the combined result matches our expectations
        assert isinstance(combined_result, AIMessage)
        assert combined_result.content == ""
        assert len(combined_result.tool_calls) == 2
        assert combined_result.tool_calls[0]["id"] == "call_1"
        assert combined_result.tool_calls[0]["name"] == "get_weather"
        assert combined_result.tool_calls[0]["args"] == {"location": "New York"}
        assert combined_result.tool_calls[1]["id"] == "call_2"
        assert combined_result.tool_calls[1]["name"] == "get_time"
        assert combined_result.tool_calls[1]["args"] == {"timezone": "EST"}

        # Verify that litellm.completion was called with the correct arguments
        mock_completion.assert_called_once_with(
            model="openai/gpt-3.5-turbo",
            api_key="test_key",
            base_url=None,
            api_version=None,
            custom_llm_provider=None,
            messages=[
                {"role": "user", "content": "What's the weather and time in New York?"}
            ],
            tools=tools,
            tool_choice=None,
            stream=True,
            temperature=0.0,  # Default value from GEN_AI_TEMPERATURE
            timeout=30,
            parallel_tool_calls=False,
        )
