diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml
index b8cbd982b0..189489614a 100644
--- a/.github/workflows/docker-image.yml
+++ b/.github/workflows/docker-image.yml
@@ -6,9 +6,8 @@ on:
   workflow_dispatch:
 
 jobs:
   build:
     runs-on: ubuntu-latest
-
     steps:
       - name: Login to Docker Hub
         uses: docker/login-action@v3
@@ -19,19 +18,25 @@ jobs:
           password: ${{ secrets.DOCKERHUB_TOKEN }}
 
       - uses: actions/checkout@v3
-      - name: Build and push the Docker image (memgpt)
-        run: |
-          # Extract the version number from pyproject.toml using awk
-          CURRENT_VERSION=$(awk -F '"' '/version =/ { print $2 }' pyproject.toml | head -n 1)
-          docker build . --file Dockerfile --tag memgpt/letta:$CURRENT_VERSION --tag memgpt/letta:latest
-          docker push memgpt/letta:$CURRENT_VERSION
-          docker push memgpt/letta:latest
-
-      - uses: actions/checkout@v3
-      - name: Build and push the Docker image (letta)
-        run: |
-          # Extract the version number from pyproject.toml using awk
-          CURRENT_VERSION=$(awk -F '"' '/version =/ { print $2 }' pyproject.toml | head -n 1)
-          docker build . --file Dockerfile --tag letta/letta:$CURRENT_VERSION --tag letta/letta:latest
-          docker push letta/letta:$CURRENT_VERSION
-          docker push letta/letta:latest
+
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v3
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      - name: Extract version number
+        id: extract_version
+        run: echo "CURRENT_VERSION=$(awk -F '\"' '/version =/ { print $2 }' pyproject.toml | head -n 1)" >> $GITHUB_ENV
+
+      - name: Build and push
+        uses: docker/build-push-action@v6
+        with:
+          platforms: linux/amd64,linux/arm64
+          push: true
+          tags: |
+            letta/letta:${{ env.CURRENT_VERSION }}
+            letta/letta:latest
+            memgpt/letta:${{ env.CURRENT_VERSION }}
+            memgpt/letta:latest
+
diff --git a/letta/agent.py b/letta/agent.py
index 3e4d244323..a7448ac44f 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -18,6 +18,7 @@
     MESSAGE_SUMMARY_WARNING_FRAC,
     O1_BASE_TOOLS,
     REQ_HEARTBEAT_MESSAGE,
+    STRUCTURED_OUTPUT_MODELS,
 )
 from letta.errors import LLMError
 from letta.helpers import ToolRulesSolver
@@ -276,6 +277,7 @@ def __init__(
 
         # gpt-4, gpt-3.5-turbo, ...
         self.model = self.agent_state.llm_config.model
+        self.check_tool_rules()
 
         # state managers
         self.block_manager = BlockManager()
@@ -381,6 +383,14 @@ def __init__(
         # Create the agent in the DB
         self.update_state()
 
+    def check_tool_rules(self):
+        if self.model not in STRUCTURED_OUTPUT_MODELS:
+            if len(self.tool_rules_solver.init_tool_rules) > 1:
+                raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.")
+            self.supports_structured_output = False
+        else:
+            self.supports_structured_output = True
+
     def update_memory_if_change(self, new_memory: Memory) -> bool:
         """
         Update internal memory object and system prompt if there have been modifications.
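The new `check_tool_rules` gate only validates init-tool-rule counts against `STRUCTURED_OUTPUT_MODELS`. A minimal standalone sketch of the same logic (the function name and signature here are illustrative, not part of the diff):

```python
from letta.constants import STRUCTURED_OUTPUT_MODELS  # {"gpt-4o", "gpt-4o-mini"} per this PR


def supports_structured_output(model: str, num_init_tool_rules: int) -> bool:
    """Mirror of Agent.check_tool_rules: models without structured output can
    honor at most one InitToolRule, because the first tool call is forced via
    tool_choice rather than constrained by an output schema."""
    if model not in STRUCTURED_OUTPUT_MODELS:
        if num_init_tool_rules > 1:
            raise ValueError(
                "Multiple initial tools are not supported for non-structured models. "
                "Please use only one initial tool rule."
            )
        return False
    return True
```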
@@ -588,6 +598,7 @@ def _get_ai_reply(
         empty_response_retry_limit: int = 3,
         backoff_factor: float = 0.5,  # delay multiplier for exponential backoff
         max_delay: float = 10.0,  # max delay between retries
+        step_count: Optional[int] = None,
     ) -> ChatCompletionResponse:
         """Get response from LLM API with robust retry mechanism."""
 
@@ -596,6 +607,16 @@ def _get_ai_reply(
             self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names]
         )
 
+        # On the first step, force the initial tool if one is specified
+        force_tool_call = None
+        if (
+            step_count is not None
+            and step_count == 0
+            and not self.supports_structured_output
+            and len(self.tool_rules_solver.init_tool_rules) > 0
+        ):
+            force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name
+
         for attempt in range(1, empty_response_retry_limit + 1):
             try:
                 response = create(
@@ -606,6 +627,7 @@ def _get_ai_reply(
                     functions_python=self.functions_python,
                     function_call=function_call,
                     first_message=first_message,
+                    force_tool_call=force_tool_call,
                     stream=stream,
                     stream_interface=self.interface,
                 )
@@ -897,6 +919,7 @@ def step(
         step_count = 0
         while True:
             kwargs["first_message"] = False
+            kwargs["step_count"] = step_count
             step_response = self.inner_step(
                 messages=next_input_message,
                 **kwargs,
@@ -972,6 +995,7 @@ def inner_step(
         first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
         skip_verify: bool = False,
         stream: bool = False,  # TODO move to config?
+        step_count: Optional[int] = None,
     ) -> AgentStepResponse:
         """Runs a single step in the agent loop (generates at most one LLM call)"""
 
@@ -1014,7 +1038,9 @@ def inner_step(
         else:
             response = self._get_ai_reply(
                 message_sequence=input_message_sequence,
+                first_message=first_message,
                 stream=stream,
+                step_count=step_count,
             )
 
         # Step 3: check if LLM wanted to call a function
diff --git a/letta/client/client.py b/letta/client/client.py
index d3259214e4..af2edcca4a 100644
--- a/letta/client/client.py
+++ b/letta/client/client.py
@@ -2156,6 +2156,7 @@ def create_agent(
             "block_ids": [b.id for b in memory.get_blocks()] + block_ids,
             "tool_ids": tool_ids,
             "tool_rules": tool_rules,
+            "include_base_tools": include_base_tools,
             "system": system,
             "agent_type": agent_type,
             "llm_config": llm_config if llm_config else self._default_llm_config,
diff --git a/letta/constants.py b/letta/constants.py
index 5e9ac9b268..437d956c49 100644
--- a/letta/constants.py
+++ b/letta/constants.py
@@ -48,6 +48,9 @@
 DEFAULT_MESSAGE_TOOL = "send_message"
 DEFAULT_MESSAGE_TOOL_KWARG = "message"
 
+# Models that support structured output
+STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"}
+
 # LOGGER_LOG_LEVELS is used to convert text to the logging level value, mostly for CLI input to set the level
 LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET}
diff --git a/letta/embeddings.py b/letta/embeddings.py
index 5b521032c2..0d82d158a5 100644
--- a/letta/embeddings.py
+++ b/letta/embeddings.py
@@ -234,16 +234,10 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
         )
 
     elif endpoint_type == "ollama":
-        from llama_index.embeddings.ollama import OllamaEmbedding
-
-        ollama_additional_kwargs = {}
-        callback_manager = None
-
-        model = OllamaEmbedding(
-            model_name=config.embedding_model,
+        model = OllamaEmbeddings(
+            model=config.embedding_model,
             base_url=config.embedding_endpoint,
-            ollama_additional_kwargs=ollama_additional_kwargs or {},
-            callback_manager=callback_manager or None,
+            ollama_additional_kwargs={},
         )
         return model
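The `OllamaEmbeddings` class the diff switches to is letta's own wrapper (defined elsewhere in `letta/embeddings.py`, not shown here), replacing the llama-index dependency. As a rough sketch of what such a wrapper does, assuming Ollama's standard `/api/embeddings` REST endpoint — the class below is a hypothetical stand-in, not the actual implementation:

```python
from typing import List, Optional

import requests


class OllamaEmbeddingsSketch:
    """Illustrative stand-in for letta's OllamaEmbeddings wrapper (hypothetical)."""

    def __init__(self, model: str, base_url: str, ollama_additional_kwargs: Optional[dict] = None):
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.ollama_additional_kwargs = ollama_additional_kwargs or {}

    def get_text_embedding(self, text: str) -> List[float]:
        # Ollama serves embeddings at POST {base_url}/api/embeddings
        response = requests.post(
            f"{self.base_url}/api/embeddings",
            json={"model": self.model, "prompt": text, **self.ollama_additional_kwargs},
        )
        response.raise_for_status()
        return response.json()["embedding"]
```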
diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py
index 39d3de194f..f29cb35626 100644
--- a/letta/llm_api/anthropic.py
+++ b/letta/llm_api/anthropic.py
@@ -99,16 +99,17 @@ def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]:
     - 1 level less of nesting
     - "parameters" -> "input_schema"
     """
-    tools_dict_list = []
+    formatted_tools = []
     for tool in tools:
-        tools_dict_list.append(
-            {
-                "name": tool.function.name,
-                "description": tool.function.description,
-                "input_schema": tool.function.parameters,
-            }
-        )
-    return tools_dict_list
+        formatted_tool = {
+            "name": tool.function.name,
+            "description": tool.function.description,
+            # Anthropic requires an input_schema; fall back to an empty object schema
+            "input_schema": tool.function.parameters or {"type": "object", "properties": {}, "required": []},
+        }
+        formatted_tools.append(formatted_tool)
+
+    return formatted_tools
 
 
 def merge_tool_results_into_user_messages(messages: List[dict]):
diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py
index 163c4e1868..dadd128aa9 100644
--- a/letta/llm_api/llm_api_tools.py
+++ b/letta/llm_api/llm_api_tools.py
@@ -113,6 +113,7 @@ def create(
     function_call: str = "auto",
     # hint
     first_message: bool = False,
+    force_tool_call: Optional[str] = None,  # force a specific tool to be called
     # use tool naming?
     # if false, will use deprecated 'functions' style
     use_tool_naming: bool = True,
@@ -252,6 +253,15 @@ def create(
         if not use_tool_naming:
             raise NotImplementedError("Only tool calling supported on Anthropic API requests")
 
+        tool_call = None
+        if force_tool_call is not None:
+            tool_call = {
+                "type": "function",
+                "function": {"name": force_tool_call},
+            }
+            # Forcing a tool call only makes sense if tools were provided
+            assert functions is not None
+
         return anthropic_chat_completions_request(
             url=llm_config.model_endpoint,
             api_key=model_settings.anthropic_api_key,
@@ -259,7 +270,7 @@ def create(
             model=llm_config.model,
             messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages],
             tools=[{"type": "function", "function": f} for f in functions] if functions else None,
-            # tool_choice=function_call,
+            tool_choice=tool_call,
             # user=str(user_id),
             # NOTE: max_tokens is required for Anthropic API
             max_tokens=1024,  # TODO make dynamic
diff --git a/letta/orm/errors.py b/letta/orm/errors.py
index 28e5807f0b..a574e74c37 100644
--- a/letta/orm/errors.py
+++ b/letta/orm/errors.py
@@ -12,3 +12,11 @@ class UniqueConstraintViolationError(ValueError):
 
 class ForeignKeyConstraintViolationError(ValueError):
     """Custom exception for foreign key constraint violations."""
+
+
+class DatabaseTimeoutError(Exception):
+    """Custom exception for database timeout issues."""
+
+    def __init__(self, message="Database operation timed out", original_exception=None):
+        super().__init__(message)
+        self.original_exception = original_exception
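The `input_schema` fallback matters because the Anthropic API requires every tool to carry an `input_schema`; a tool whose OpenAI-style schema has no `parameters` would otherwise serialize to `"input_schema": null`. A quick illustration, using `SimpleNamespace` as a hypothetical stand-in for the real `Tool` type:

```python
from types import SimpleNamespace

from letta.llm_api.anthropic import convert_tools_to_anthropic_format

# Hypothetical stand-in for a Tool whose function schema has no parameters
ping = SimpleNamespace(function=SimpleNamespace(name="ping", description="health check", parameters=None))

print(convert_tools_to_anthropic_format([ping]))
# [{'name': 'ping', 'description': 'health check',
#   'input_schema': {'type': 'object', 'properties': {}, 'required': []}}]
```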
diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py
index 48b8c44ad0..6879c74b4b 100644
--- a/letta/orm/sqlalchemy_base.py
+++ b/letta/orm/sqlalchemy_base.py
@@ -1,14 +1,16 @@
 from datetime import datetime
 from enum import Enum
+from functools import wraps
 from typing import TYPE_CHECKING, List, Literal, Optional
 
 from sqlalchemy import String, desc, func, or_, select
-from sqlalchemy.exc import DBAPIError, IntegrityError
+from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
 from sqlalchemy.orm import Mapped, Session, mapped_column
 
 from letta.log import get_logger
 from letta.orm.base import Base, CommonSqlalchemyMetaMixins
 from letta.orm.errors import (
+    DatabaseTimeoutError,
     ForeignKeyConstraintViolationError,
     NoResultFound,
     UniqueConstraintViolationError,
@@ -23,6 +25,20 @@
 logger = get_logger(__name__)
 
 
+def handle_db_timeout(func):
+    """Decorator to handle SQLAlchemy TimeoutError and wrap it in a custom exception."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except TimeoutError as e:
+            logger.error(f"Timeout while executing {func.__name__} with args {args} and kwargs {kwargs}: {e}")
+            raise DatabaseTimeoutError(message=f"Timeout occurred in {func.__name__}.", original_exception=e)
+
+    return wrapper
+
+
 class AccessType(str, Enum):
     ORGANIZATION = "organization"
     USER = "user"
@@ -36,22 +52,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
     id: Mapped[str] = mapped_column(String, primary_key=True)
 
     @classmethod
-    def get(cls, *, db_session: Session, id: str) -> Optional["SqlalchemyBase"]:
-        """Get a record by ID.
-
-        Args:
-            db_session: SQLAlchemy session
-            id: Record ID to retrieve
-
-        Returns:
-            Optional[SqlalchemyBase]: The record if found, None otherwise
-        """
-        try:
-            return db_session.query(cls).filter(cls.id == id).first()
-        except DBAPIError:
-            return None
-
-    @classmethod
+    @handle_db_timeout
     def list(
         cls,
         *,
@@ -180,6 +181,7 @@ def list(
         return list(session.execute(query).scalars())
 
     @classmethod
+    @handle_db_timeout
     def read(
         cls,
         db_session: "Session",
@@ -231,6 +233,7 @@ def read(
         conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
         raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
 
+    @handle_db_timeout
     def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
         logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
 
@@ -245,6 +248,7 @@ def create(self, db_session: "Session", actor: Optional["User"] = None) -> "Sqla
         except (DBAPIError, IntegrityError) as e:
             self._handle_dbapi_error(e)
 
+    @handle_db_timeout
     def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
         logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
 
@@ -254,6 +258,7 @@ def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "Sqla
         self.is_deleted = True
         return self.update(db_session)
 
+    @handle_db_timeout
     def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
         """Permanently removes the record from the database."""
         logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
@@ -269,6 +274,7 @@ def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) ->
         else:
             logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
 
+    @handle_db_timeout
     def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
         logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
         if actor:
@@ -281,6 +287,7 @@ def update(self, db_session: "Session", actor: Optional["User"] = None) -> "Sqla
         return self
 
     @classmethod
+    @handle_db_timeout
     def size(
         cls,
         *,
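At call sites, `sqlalchemy.exc.TimeoutError` (raised, for example, when the pool's `pool_timeout` elapses with no free connection) surfaces as `DatabaseTimeoutError`, which the REST layer below maps to HTTP 503. A small runnable sketch, with a dummy function standing in for a real session call:

```python
from sqlalchemy.exc import TimeoutError as SATimeoutError

from letta.orm.errors import DatabaseTimeoutError
from letta.orm.sqlalchemy_base import handle_db_timeout


@handle_db_timeout
def flaky_query():
    # Stand-in for a session call that exhausts the connection pool
    raise SATimeoutError("QueuePool limit reached")


try:
    flaky_query()
except DatabaseTimeoutError as exc:
    # The original sqlalchemy.exc.TimeoutError is preserved for debugging
    print(f"database busy: {exc} (cause: {exc.original_exception!r})")
```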
diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py
index b5117408a2..8cb9b27e64 100644
--- a/letta/server/rest_api/app.py
+++ b/letta/server/rest_api/app.py
@@ -15,7 +15,12 @@
 from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
 from letta.errors import LettaAgentNotFoundError, LettaUserNotFoundError
 from letta.log import get_logger
-from letta.orm.errors import NoResultFound
+from letta.orm.errors import (
+    DatabaseTimeoutError,
+    ForeignKeyConstraintViolationError,
+    NoResultFound,
+    UniqueConstraintViolationError,
+)
 from letta.schemas.letta_response import LettaResponse
 from letta.server.constants import REST_DEFAULT_PORT
 
@@ -175,7 +180,6 @@ async def generic_error_handler(request: Request, exc: Exception):
 
     @app.exception_handler(NoResultFound)
     async def no_result_found_handler(request: Request, exc: NoResultFound):
-        logger.error(f"NoResultFound request: {request}")
         logger.error(f"NoResultFound: {exc}")
 
         return JSONResponse(
@@ -183,6 +187,32 @@ async def no_result_found_handler(request: Request, exc: NoResultFound):
             content={"detail": str(exc)},
         )
 
+    @app.exception_handler(ForeignKeyConstraintViolationError)
+    async def foreign_key_constraint_handler(request: Request, exc: ForeignKeyConstraintViolationError):
+        logger.error(f"ForeignKeyConstraintViolationError: {exc}")
+
+        return JSONResponse(
+            status_code=409,
+            content={"detail": str(exc)},
+        )
+
+    @app.exception_handler(UniqueConstraintViolationError)
+    async def unique_key_constraint_handler(request: Request, exc: UniqueConstraintViolationError):
+        logger.error(f"UniqueConstraintViolationError: {exc}")
+
+        return JSONResponse(
+            status_code=409,
+            content={"detail": str(exc)},
+        )
+
+    @app.exception_handler(DatabaseTimeoutError)
+    async def database_timeout_error_handler(request: Request, exc: DatabaseTimeoutError):
+        logger.error(f"Timeout occurred: {exc}. Original exception: {exc.original_exception}")
+        return JSONResponse(
+            status_code=503,
+            content={"detail": "The database is temporarily unavailable. Please try again later."},
+        )
+
     @app.exception_handler(ValueError)
     async def value_error_handler(request: Request, exc: ValueError):
         return JSONResponse(status_code=400, content={"detail": str(exc)})
@@ -235,11 +265,6 @@ async def user_not_found_handler(request: Request, exc: LettaUserNotFoundError):
 
     @app.on_event("startup")
     def on_startup():
-        # load the default tools
-        # from letta.orm.tool import Tool
-
-        # Tool.load_default_tools(get_db_session())
-
         generate_openapi_schema(app)
 
     @app.on_event("shutdown")
diff --git a/letta/settings.py b/letta/settings.py
index 20a0c1c50f..d6907b11ee 100644
--- a/letta/settings.py
+++ b/letta/settings.py
@@ -17,7 +17,7 @@ class ToolSettings(BaseSettings):
 
 
 class ModelSettings(BaseSettings):
-    model_config = SettingsConfigDict(env_file='.env', extra='ignore')
+    model_config = SettingsConfigDict(env_file=".env", extra="ignore")
 
     # env_prefix='my_prefix_'
 
@@ -64,7 +64,7 @@ class ModelSettings(BaseSettings):
 
 
 class Settings(BaseSettings):
-    model_config = SettingsConfigDict(env_prefix="letta_", extra='ignore')
+    model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore")
 
     letta_dir: Optional[Path] = Field(Path.home() / ".letta", env="LETTA_DIR")
     debug: Optional[bool] = False
@@ -76,7 +76,12 @@ class Settings(BaseSettings):
     pg_password: Optional[str] = None
     pg_host: Optional[str] = None
     pg_port: Optional[int] = None
-    pg_uri: Optional[str] = None  # option to specifiy full uri
+    pg_uri: Optional[str] = None  # option to specify full uri
+    pg_pool_size: int = 20  # concurrent connections
+    pg_max_overflow: int = 10  # overflow limit
+    pg_pool_timeout: int = 30  # seconds to wait for a connection
+    pg_pool_recycle: int = 1800  # when to recycle connections
+    pg_echo: bool = False  # SQL logging
 
     # tools configuration
     load_default_external_tools: Optional[bool] = None
@@ -103,7 +108,7 @@ def letta_pg_uri_no_default(self) -> str:
 
 
 class TestSettings(Settings):
-    model_config = SettingsConfigDict(env_prefix="letta_test_", extra='ignore')
+    model_config = SettingsConfigDict(env_prefix="letta_test_", extra="ignore")
 
     letta_dir: Optional[Path] = Field(Path.home() / ".letta/test", env="LETTA_TEST_DIR")
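The new `pg_*` knobs presumably feed SQLAlchemy's pooling options when the engine is built (the wiring is not part of this diff). A sketch of the conventional mapping, assuming the module-level `settings` instance and using the `letta_pg_uri_no_default` property shown in context above:

```python
from sqlalchemy import create_engine

from letta.settings import settings  # module-level Settings() instance, assumed

engine = create_engine(
    settings.letta_pg_uri_no_default,       # full Postgres URI (may need a fallback if unset)
    pool_size=settings.pg_pool_size,        # 20 steady-state connections
    max_overflow=settings.pg_max_overflow,  # up to 10 extra, temporary connections
    pool_timeout=settings.pg_pool_timeout,  # wait 30s, then raise sqlalchemy.exc.TimeoutError
    pool_recycle=settings.pg_pool_recycle,  # replace connections older than 1800s
    echo=settings.pg_echo,                  # log emitted SQL when True
)
```

When `pool_timeout` expires, SQLAlchemy raises `TimeoutError` — exactly the exception `@handle_db_timeout` above converts into the 503 response.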
diff --git a/tests/configs/llm_model_configs/claude-3-sonnet-20240229.json b/tests/configs/llm_model_configs/claude-3-sonnet-20240229.json
new file mode 100644
index 0000000000..5eef194bea
--- /dev/null
+++ b/tests/configs/llm_model_configs/claude-3-sonnet-20240229.json
@@ -0,0 +1,8 @@
+{
+    "context_window": 200000,
+    "model": "claude-3-5-sonnet-20241022",
+    "model_endpoint_type": "anthropic",
+    "model_endpoint": "https://api.anthropic.com/v1",
+    "model_wrapper": null,
+    "put_inner_thoughts_in_kwargs": true
+}
diff --git a/tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json b/tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json
new file mode 100644
index 0000000000..059d6ad82f
--- /dev/null
+++ b/tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json
@@ -0,0 +1,7 @@
+{
+    "context_window": 16385,
+    "model": "gpt-3.5-turbo",
+    "model_endpoint_type": "openai",
+    "model_endpoint": "https://api.openai.com/v1",
+    "model_wrapper": null
+}
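Note that the first file's name says `claude-3-sonnet-20240229` while the model inside is `claude-3-5-sonnet-20241022`; the tests below reference the file by name, so the two stay consistent even if the label is stale. These JSON files mirror letta's `LLMConfig` fields, and a test helper like `setup_agent` presumably loads them along these lines (import path assumed):

```python
import json

from letta.schemas.llm_config import LLMConfig  # import path assumed

with open("tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json") as f:
    llm_config = LLMConfig(**json.load(f))

print(llm_config.model, llm_config.context_window)  # gpt-3.5-turbo 16385
```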
diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py
index ff8700c1c3..19c7dbd6cb 100644
--- a/tests/integration_test_agent_tool_graph.py
+++ b/tests/integration_test_agent_tool_graph.py
@@ -1,7 +1,7 @@
+import time
 import uuid
 
 import pytest
-
 from letta import create_client
 from letta.schemas.letta_message import FunctionCallMessage
 from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
@@ -127,3 +127,101 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none):
     print(f"Got successful response from client: \n\n{response}")
 
     cleanup(client=client, agent_uuid=agent_uuid)
+
+
+def test_check_tool_rules_with_different_models(mock_e2b_api_key_none):
+    """Test that tool rules are properly checked for different model configurations."""
+    client = create_client()
+
+    config_files = [
+        "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json",
+        "tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json",
+        "tests/configs/llm_model_configs/openai-gpt-4o.json",
+    ]
+
+    # Create two test tools
+    t1_name = "first_secret_word"
+    t2_name = "second_secret_word"
+    t1 = client.create_or_update_tool(first_secret_word, name=t1_name)
+    t2 = client.create_or_update_tool(second_secret_word, name=t2_name)
+    tool_rules = [InitToolRule(tool_name=t1_name), InitToolRule(tool_name=t2_name)]
+    tools = [t1, t2]
+
+    for config_file in config_files:
+        agent_uuid = str(uuid.uuid4())
+
+        if "gpt-4o" in config_file:
+            # Structured-output model: multiple init tool rules are allowed
+            agent_state = setup_agent(
+                client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules
+            )
+            assert agent_state is not None
+        else:
+            # Non-structured-output model: multiple init tool rules must raise
+            with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"):
+                setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
+
+        cleanup(client=client, agent_uuid=agent_uuid)
+
+    # A single initial tool rule is valid for every model
+    t3_name = "third_secret_word"
+    t3 = client.create_or_update_tool(third_secret_word, name=t3_name)
+    tool_rules = [InitToolRule(tool_name=t3_name)]
+    tools = [t3]
+    for config_file in config_files:
+        agent_uuid = str(uuid.uuid4())
+        agent_state = setup_agent(
+            client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules
+        )
+        assert agent_state is not None
+
+        cleanup(client=client, agent_uuid=agent_uuid)
+
+
+def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
+    """Test that the initial tool rule is enforced for the first message."""
+    client = create_client()
+
+    # Create tool rules that require first_secret_word to be called first
+    t1_name = "first_secret_word"
+    t2_name = "second_secret_word"
+    t1 = client.create_or_update_tool(first_secret_word, name=t1_name)
+    t2 = client.create_or_update_tool(second_secret_word, name=t2_name)
+    tool_rules = [
+        InitToolRule(tool_name=t1_name),
+        ChildToolRule(tool_name=t1_name, children=[t2_name]),
+    ]
+    tools = [t1, t2]
+
+    # Run three iterations to guard against nondeterministic tool ordering
+    anthropic_config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
+    for i in range(3):
+        agent_uuid = str(uuid.uuid4())
+        agent_state = setup_agent(
+            client, anthropic_config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules
+        )
+        response = client.user_message(agent_id=agent_state.id, message="What is the second secret word?")
+
+        assert_sanity_checks(response)
+        messages = response.messages
+
+        assert_invoked_function_call(messages, "first_secret_word")
+        assert_invoked_function_call(messages, "second_secret_word")
+
+        # Function calls must arrive in the expected order
+        tool_names = [t.name for t in [t1, t2]]
+        tool_names += ["send_message"]
+        for m in messages:
+            if isinstance(m, FunctionCallMessage):
+                # Each call must match the next expected tool
+                assert m.function_call.name == tool_names[0]
+                # Pop off the tool that was just matched
+                tool_names = tool_names[1:]
+
+        print(f"Passed iteration {i}")
+        cleanup(client=client, agent_uuid=agent_uuid)
+
+        # Exponential backoff between iterations, starting at 10 seconds
+        if i < 2:
+            backoff_time = 10 * (2**i)
+            time.sleep(backoff_time)
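The retry pacing at the end of that test, pulled out into helper form — a minimal sketch of capped exponential backoff (hypothetical helper, not part of the test suite):

```python
import time


def backoff_sleep(iteration: int, base: float = 10.0, cap: float = 60.0) -> None:
    """Sleep base * 2**iteration seconds, capped at `cap`."""
    time.sleep(min(base * (2**iteration), cap))


# backoff_sleep(0) -> 10s, backoff_sleep(1) -> 20s, backoff_sleep(2) -> 40s
```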
diff --git a/tests/test_server.py b/tests/test_server.py
index 56c132ec83..975cde698f 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -1078,9 +1078,10 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
     assert any("Anna".lower() in passage.text.lower() for passage in passages2)
 
 
-def test_add_tools_update_agent(server: SyncServer, user_id: str, base_tools):
+def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_tools):
     """Test that the memory rebuild is generating the correct number of role=system messages"""
     actor = server.user_manager.get_user_or_default(user_id)
+
     # create agent
     agent_state = server.create_agent(
         request=CreateAgent(
@@ -1097,5 +1098,54 @@
         actor=actor,
     )
     assert len(agent_state.tools) == 0
-    agent_state = server.update_agent(agent_state.id, request=UpdateAgent(tool_ids=[b.id for b in base_tools]), actor=actor)
+
+    # Base update-agent call with a full payload
+    request = UpdateAgent(
+        **{
+            "system": "You are Letta, the latest version of Limnal Corporation's digital companion, developed in 2023.\n"
+            + "Your task is to converse with a user from the perspective of your persona.\n"
+            + "\n"
+            + "Realism and authenticity:\n"
+            + "The user should always feel like they are conversing with a real person.\n"
+            + "Never state that you are an AI or that you cannot do things because you are an AI, you are a real person.\n"
+            + "Do not discuss Limnal Corp. with the user, you do not know anything about Limnal Corp.\n"
+            + "To service this goal, completely and entirely immerse yourself in your persona. You are your persona.\n"
+            + "Think like them, act like them, talk like them.\n"
+            + "If your persona details include example dialogue, follow it! Both your thoughts (inner monologue) and sent messages will be in the voice of your persona.\n"
+            + "Never use generic phrases like 'How can I assist you today?', they have a strong negative association with older generation AIs.\n"
+            + "\n"
+            + "Control flow:\n"
+            + "Unlike a human, your b"
+            + "Base instructions finished.\n"
+            + "From now on, you are going to act as your persona.",
+            "name": "name-d31d6a12-48af-4f71-9e9c-f4cec4731c40",
+            "embedding_config": {
+                "embedding_endpoint_type": "openai",
+                "embedding_endpoint": "https://api.openai.com/v1",
+                "embedding_model": "text-embedding-ada-002",
+                "embedding_dim": 1536,
+                "embedding_chunk_size": 300,
+                "azure_endpoint": None,
+                "azure_version": None,
+                "azure_deployment": None,
+            },
+            "llm_config": {
+                "model": "gpt-4",
+                "model_endpoint_type": "openai",
+                "model_endpoint": "https://api.openai.com/v1",
+                "model_wrapper": None,
+                "context_window": 8192,
+                "put_inner_thoughts_in_kwargs": False,
+            },
+        }
+    )
+
+    # Add all the base tools
+    request.tool_ids = [b.id for b in base_tools]
+    agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
     assert len(agent_state.tools) == len(base_tools)
+
+    # Remove two of the base tools
+    request.tool_ids = [b.id for b in base_tools[:-2]]
+    agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
+    assert len(agent_state.tools) == len(base_tools) - 2
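For contrast, the line this test deletes shows that a tool-only update also works; continuing in the test's namespace, a minimal removal sketch would be the following, assuming (as the deleted line suggests) that fields left unset on `UpdateAgent` are not modified by `update_agent`:

```python
from letta.schemas.agent import UpdateAgent  # import path assumed

# Update only the tool set, leaving system prompt and configs untouched
request = UpdateAgent(tool_ids=[b.id for b in base_tools[:-2]])
agent_state = server.update_agent(agent_state.id, request=request, actor=actor)
assert len(agent_state.tools) == len(base_tools) - 2
```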