diff --git a/libertai_agents/libertai_agents/agents.py b/libertai_agents/libertai_agents/agents.py index 2ac94c7..e36079f 100644 --- a/libertai_agents/libertai_agents/agents.py +++ b/libertai_agents/libertai_agents/agents.py @@ -179,7 +179,7 @@ async def __dump_api_generate_streamed_answer( async for message in self.generate_answer( messages, only_final_answer=only_final_answer ): - yield json.dumps(message.dict(), indent=4) + yield json.dumps(message.model_dump(), indent=4) async def __call_model(self, session: ClientSession, prompt: str) -> str | None: """ @@ -189,9 +189,11 @@ async def __call_model(self, session: ClientSession, prompt: str) -> str | None: :param prompt: Prompt to give to the model :return: String response (if no error) """ - params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.dict()) + params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.model_dump()) - async with session.post(self.model.vm_url, json=params.dict()) as response: + async with session.post( + self.model.vm_url, json=params.model_dump() + ) as response: # TODO: handle errors and retries if response.status == HTTPStatus.OK: response_data = await response.json() diff --git a/libertai_agents/libertai_agents/models/base.py b/libertai_agents/libertai_agents/models/base.py index 4b721d1..29c0c94 100644 --- a/libertai_agents/libertai_agents/models/base.py +++ b/libertai_agents/libertai_agents/models/base.py @@ -74,7 +74,7 @@ def generate_prompt( if self.include_system_message and system_prompt is not None else [] ) - raw_messages = [x.dict() for x in messages] + raw_messages = [x.model_dump() for x in messages] for i in range(len(raw_messages)): included_messages: list = system_messages + raw_messages[i:] diff --git a/libertai_agents/libertai_agents/models/models.py b/libertai_agents/libertai_agents/models/models.py index 97bfbfa..2e41709 100644 --- a/libertai_agents/libertai_agents/models/models.py +++ b/libertai_agents/libertai_agents/models/models.py @@ -62,5 +62,5 @@ def get_model( ) return full_config.constructor( - model_id=model_id, **configuration.dict(exclude={"constructor"}) + model_id=model_id, **configuration.model_dump(exclude={"constructor"}) ) diff --git a/libertai_agents/tests/conftest.py b/libertai_agents/tests/conftest.py new file mode 100644 index 0000000..ec7fead --- /dev/null +++ b/libertai_agents/tests/conftest.py @@ -0,0 +1 @@ +from tests.fixtures.fixtures_tools import * # noqa: F401, F403 diff --git a/libertai_agents/tests/fixtures/__init__.py b/libertai_agents/tests/fixtures/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/libertai_agents/tests/fixtures/__init__.py @@ -0,0 +1 @@ + diff --git a/libertai_agents/tests/fixtures/fixtures_tools.py b/libertai_agents/tests/fixtures/fixtures_tools.py new file mode 100644 index 0000000..8226100 --- /dev/null +++ b/libertai_agents/tests/fixtures/fixtures_tools.py @@ -0,0 +1,20 @@ +from typing import Callable + +import pytest + + +@pytest.fixture() +def basic_function_for_tool() -> Callable: + def get_current_temperature(location: str, unit: str) -> float: + """ + Get the current temperature at a location. + + Args: + location: The location to get the temperature for, in the format "City, Country" + unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"]) + Returns: + The current temperature at the specified location in the specified units, as a float. + """ + return 22.0 # A real function should probably actually get the temperature! + + return get_current_temperature diff --git a/libertai_agents/tests/test_agents.py b/libertai_agents/tests/test_agents.py new file mode 100644 index 0000000..f8dda51 --- /dev/null +++ b/libertai_agents/tests/test_agents.py @@ -0,0 +1,35 @@ +from fastapi import FastAPI + +from libertai_agents.agents import ChatAgent +from libertai_agents.interfaces.tools import Tool +from libertai_agents.models import get_model +from libertai_agents.models.base import ModelId +from libertai_agents.models.models import ModelConfiguration + +MODEL_ID: ModelId = "NousResearch/Hermes-3-Llama-3.1-8B" + + +def test_create_chat_agent_minimal(): + agent = ChatAgent(model=get_model(MODEL_ID)) + + assert len(agent.tools) == 0 + assert agent.model.model_id == MODEL_ID + assert isinstance(agent.app, FastAPI) + + +def test_create_chat_agent_with_config(basic_function_for_tool): + context_length = 42 + + agent = ChatAgent( + model=get_model( + MODEL_ID, + custom_configuration=ModelConfiguration( + vm_url="https://example.org", context_length=context_length + ), + ), + tools=[Tool.from_function(basic_function_for_tool)], + expose_api=False, + ) + assert agent.model.context_length == context_length + assert not hasattr(agent, "app") + assert len(agent.tools) == 1 diff --git a/libertai_agents/tests/test_models.py b/libertai_agents/tests/test_models.py new file mode 100644 index 0000000..78b267d --- /dev/null +++ b/libertai_agents/tests/test_models.py @@ -0,0 +1,14 @@ +import pytest + +from libertai_agents.models import Model, get_model + + +def test_get_model_basic(): + model = get_model("NousResearch/Hermes-3-Llama-3.1-8B") + + assert isinstance(model, Model) + + +def test_get_model_invalid_id(): + with pytest.raises(ValueError): + _model = get_model(model_id="random-string") # type: ignore diff --git a/libertai_agents/tests/tools/test_function_tools.py b/libertai_agents/tests/tools/test_function_tools.py new file mode 100644 index 0000000..f043378 --- /dev/null +++ b/libertai_agents/tests/tools/test_function_tools.py @@ -0,0 +1,9 @@ +from libertai_agents.interfaces.tools import Tool + + +def test_function_example_tool(basic_function_for_tool): + libertai_tool = Tool.from_function(basic_function_for_tool) + assert libertai_tool.name == basic_function_for_tool.__name__ + + +# TODO: add test with Python 3.10+ union style when https://github.com/huggingface/transformers/pull/35103 merged + new release