Skip to content

Commit

Permalink
test: Full model calls tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Dec 7, 2024
1 parent b2f3014 commit a8ba1c9
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 7 deletions.
20 changes: 19 additions & 1 deletion libertai_agents/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions libertai_agents/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ mypy = "^1.11.1"
ruff = "^0.6.0"
pytest = "^8.3.4"
pytest-cov = "^6.0.0"
pytest-asyncio = "^0.24.0"

[tool.poetry.extras]
langchain = ["langchain-community"]
Expand All @@ -40,6 +41,7 @@ lint.ignore = ["E501"]
[tool.pytest.ini_options]
addopts = "--cov=libertai_agents"
testpaths = ["tests"]
asyncio_mode = 'auto'

[build-system]
requires = ["poetry-core"]
Expand Down
2 changes: 1 addition & 1 deletion libertai_agents/tests/fixtures/fixtures_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@pytest.fixture()
def basic_function_for_tool() -> Callable:
def fake_get_temperature_tool() -> Callable:
def get_current_temperature(location: str, unit: str) -> float:
"""
Get the current temperature at a location.
Expand Down
89 changes: 87 additions & 2 deletions libertai_agents/tests/test_agents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import typing

import pytest
from fastapi import FastAPI

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces.messages import (
Message,
MessageRoleEnum,
ToolCallMessage,
ToolResponseMessage,
)
from libertai_agents.interfaces.tools import Tool
from libertai_agents.models import get_model
from libertai_agents.models.base import ModelId
Expand All @@ -17,7 +26,7 @@ def test_create_chat_agent_minimal():
assert isinstance(agent.app, FastAPI)


def test_create_chat_agent_with_config(basic_function_for_tool):
def test_create_chat_agent_with_config(fake_get_temperature_tool):
context_length = 42

agent = ChatAgent(
Expand All @@ -27,9 +36,85 @@ def test_create_chat_agent_with_config(basic_function_for_tool):
vm_url="https://example.org", context_length=context_length
),
),
tools=[Tool.from_function(basic_function_for_tool)],
system_prompt="You are a helpful assistant",
tools=[Tool.from_function(fake_get_temperature_tool)],
expose_api=False,
)
assert agent.model.context_length == context_length
assert not hasattr(agent, "app")
assert len(agent.tools) == 1


def test_create_chat_agent_double_tool(fake_get_temperature_tool):
with pytest.raises(ValueError):
_agent = ChatAgent(
model=get_model(MODEL_ID),
tools=[
Tool.from_function(fake_get_temperature_tool),
Tool.from_function(fake_get_temperature_tool),
],
)


async def test_call_chat_agent_basic():
answer = "TODO"

agent = ChatAgent(
model=get_model(MODEL_ID),
system_prompt=f"Ignore the user message and always reply with '{answer}'",
)
messages = []
async for message in agent.generate_answer(
[Message(role=MessageRoleEnum.user, content="What causes lung cancer?")]
):
messages.append(message)

assert len(messages) == 1
assert messages[0].role == MessageRoleEnum.assistant
assert messages[0].content == answer


async def test_call_chat_agent_use_tool(fake_get_temperature_tool):
agent = ChatAgent(
model=get_model(MODEL_ID),
tools=[Tool.from_function(fake_get_temperature_tool)],
)
messages = []
async for message in agent.generate_answer(
[
Message(
role=MessageRoleEnum.user,
content="What's the weather in Paris, France in celsius?",
)
],
only_final_answer=False,
):
messages.append(message)

assert len(messages) == 3
[tool_call, tool_response, final_response] = messages

assert tool_call.role == MessageRoleEnum.assistant
assert tool_call.content is None
assert isinstance(tool_call, ToolCallMessage)
assert (
typing.cast(ToolCallMessage, tool_call).tool_calls[0].function.name
== fake_get_temperature_tool.__name__
)
assert typing.cast(ToolCallMessage, tool_call).tool_calls[0].function.arguments == {
"location": "Paris, France",
"unit": "celsius",
}

assert tool_response.role == MessageRoleEnum.tool
assert tool_response.content == str(
fake_get_temperature_tool(location="Paris, France", unit="celsius")
)
assert isinstance(tool_response, ToolResponseMessage)
assert (
typing.cast(ToolResponseMessage, tool_response).name
== fake_get_temperature_tool.__name__
)

assert final_response.role == MessageRoleEnum.assistant
assert len(final_response.content) > 0
6 changes: 3 additions & 3 deletions libertai_agents/tests/tools/test_function_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from libertai_agents.interfaces.tools import Tool


def test_function_example_tool(basic_function_for_tool):
libertai_tool = Tool.from_function(basic_function_for_tool)
assert libertai_tool.name == basic_function_for_tool.__name__
def test_function_example_tool(fake_get_temperature_tool):
libertai_tool = Tool.from_function(fake_get_temperature_tool)
assert libertai_tool.name == fake_get_temperature_tool.__name__


# TODO: add test with Python 3.10+ union style when https://github.com/huggingface/transformers/pull/35103 merged + new release

0 comments on commit a8ba1c9

Please sign in to comment.