Skip to content

Commit

Permalink
test: Basic agents, function tools and models checks
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Dec 6, 2024
1 parent 52d0d2a commit 885c477
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 5 deletions.
8 changes: 5 additions & 3 deletions libertai_agents/libertai_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def __dump_api_generate_streamed_answer(
async for message in self.generate_answer(
messages, only_final_answer=only_final_answer
):
yield json.dumps(message.dict(), indent=4)
yield json.dumps(message.model_dump(), indent=4)

Check warning on line 182 in libertai_agents/libertai_agents/agents.py

View check run for this annotation

Codecov / codecov/patch

libertai_agents/libertai_agents/agents.py#L182

Added line #L182 was not covered by tests

async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
"""
Expand All @@ -189,9 +189,11 @@ async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
:param prompt: Prompt to give to the model
:return: String response (if no error)
"""
params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.dict())
params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.model_dump())

Check warning on line 192 in libertai_agents/libertai_agents/agents.py

View check run for this annotation

Codecov / codecov/patch

libertai_agents/libertai_agents/agents.py#L192

Added line #L192 was not covered by tests

async with session.post(self.model.vm_url, json=params.dict()) as response:
async with session.post(

Check warning on line 194 in libertai_agents/libertai_agents/agents.py

View check run for this annotation

Codecov / codecov/patch

libertai_agents/libertai_agents/agents.py#L194

Added line #L194 was not covered by tests
self.model.vm_url, json=params.model_dump()
) as response:
# TODO: handle errors and retries
if response.status == HTTPStatus.OK:
response_data = await response.json()
Expand Down
2 changes: 1 addition & 1 deletion libertai_agents/libertai_agents/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def generate_prompt(
if self.include_system_message and system_prompt is not None
else []
)
raw_messages = [x.dict() for x in messages]
raw_messages = [x.model_dump() for x in messages]

Check warning on line 77 in libertai_agents/libertai_agents/models/base.py

View check run for this annotation

Codecov / codecov/patch

libertai_agents/libertai_agents/models/base.py#L77

Added line #L77 was not covered by tests

for i in range(len(raw_messages)):
included_messages: list = system_messages + raw_messages[i:]
Expand Down
2 changes: 1 addition & 1 deletion libertai_agents/libertai_agents/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ def get_model(
)

return full_config.constructor(
model_id=model_id, **configuration.dict(exclude={"constructor"})
model_id=model_id, **configuration.model_dump(exclude={"constructor"})
)
1 change: 1 addition & 0 deletions libertai_agents/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tests.fixtures.fixtures_tools import * # noqa: F401, F403
1 change: 1 addition & 0 deletions libertai_agents/tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

20 changes: 20 additions & 0 deletions libertai_agents/tests/fixtures/fixtures_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Callable

import pytest


@pytest.fixture()
def basic_function_for_tool() -> Callable:
def get_current_temperature(location: str, unit: str) -> float:
"""
Get the current temperature at a location.
Args:
location: The location to get the temperature for, in the format "City, Country"
unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
Returns:
The current temperature at the specified location in the specified units, as a float.
"""
return 22.0 # A real function should probably actually get the temperature!

return get_current_temperature
35 changes: 35 additions & 0 deletions libertai_agents/tests/test_agents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from fastapi import FastAPI

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces.tools import Tool
from libertai_agents.models import get_model
from libertai_agents.models.base import ModelId
from libertai_agents.models.models import ModelConfiguration

MODEL_ID: ModelId = "NousResearch/Hermes-3-Llama-3.1-8B"


def test_create_chat_agent_minimal():
agent = ChatAgent(model=get_model(MODEL_ID))

assert len(agent.tools) == 0
assert agent.model.model_id == MODEL_ID
assert isinstance(agent.app, FastAPI)


def test_create_chat_agent_with_config(basic_function_for_tool):
context_length = 42

agent = ChatAgent(
model=get_model(
MODEL_ID,
custom_configuration=ModelConfiguration(
vm_url="https://example.org", context_length=context_length
),
),
tools=[Tool.from_function(basic_function_for_tool)],
expose_api=False,
)
assert agent.model.context_length == context_length
assert not hasattr(agent, "app")
assert len(agent.tools) == 1
14 changes: 14 additions & 0 deletions libertai_agents/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest

from libertai_agents.models import Model, get_model


def test_get_model_basic():
model = get_model("NousResearch/Hermes-3-Llama-3.1-8B")

assert isinstance(model, Model)


def test_get_model_invalid_id():
with pytest.raises(ValueError):
_model = get_model(model_id="random-string") # type: ignore
9 changes: 9 additions & 0 deletions libertai_agents/tests/tools/test_function_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from libertai_agents.interfaces.tools import Tool


def test_function_example_tool(basic_function_for_tool):
libertai_tool = Tool.from_function(basic_function_for_tool)
assert libertai_tool.name == basic_function_for_tool.__name__


# TODO: add test with Python 3.10+ union style when https://github.com/huggingface/transformers/pull/35103 merged + new release

0 comments on commit 885c477

Please sign in to comment.