From a898a2e4cd21c0e47838b450be5b1b4e33c3ecd3 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Thu, 19 Dec 2024 13:24:24 -0800 Subject: [PATCH 01/11] Add multiple message getter --- letta/services/message_manager.py | 8 ++++++++ tests/test_managers.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 6c317f6498..7eedf95f09 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -28,6 +28,14 @@ def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[Py except NoResultFound: return None + @enforce_types + def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: + """Fetch a message by ID.""" + with self.session_maker() as session: + results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id) + + return [msg.to_pydantic() for msg in results] + @enforce_types def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage: """Create a new message.""" diff --git a/tests/test_managers.py b/tests/test_managers.py index 37c6f2ac08..588f59ce4f 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -1565,6 +1565,15 @@ def create_test_messages(server: SyncServer, base_message: PydanticMessage, defa return messages +def test_get_messages_by_ids(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent): + """Test basic message listing with limit""" + messages = create_test_messages(server, hello_world_message_fixture, default_user) + message_ids = [m.id for m in messages] + + results = server.message_manager.get_messages_by_ids(message_ids=message_ids, actor=default_user) + assert sorted(message_ids) == sorted([r.id for r in results]) + + def test_message_listing_basic(server: SyncServer, hello_world_message_fixture, default_user, sarah_agent): """Test basic message listing with limit""" create_test_messages(server, hello_world_message_fixture, default_user) From 37100ec49576ffa103fcd5dfd07df1b41a684c16 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 10:56:36 -0800 Subject: [PATCH 02/11] Move initialize message to agent_manager create --- letta/agent.py | 608 ++---------------- letta/server/server.py | 9 - letta/services/agent_manager.py | 126 +++- .../services/helpers/agent_manager_helper.py | 135 +++- tests/test_managers.py | 64 +- 5 files changed, 375 insertions(+), 567 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index b61dce7641..6ed305a171 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1,17 +1,15 @@ -import datetime import inspect import time import traceback import warnings from abc import ABC, abstractmethod -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union from letta.constants import ( BASE_TOOLS, CLI_WARNING_PREFIX, FIRST_MESSAGE_ATTEMPTS, FUNC_FAILED_HEARTBEAT_MESSAGE, - IN_CONTEXT_MEMORY_KEYWORD, LLM_MAX_TOKENS, MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC, @@ -55,8 +53,6 @@ from letta.streaming_interface import StreamingRefreshCLIInterface from letta.system import ( get_heartbeat, - get_initial_boot_messages, - get_login_event, get_token_limit_warning, package_function_response, package_summarize_message, @@ -65,162 +61,17 @@ from letta.utils import ( count_tokens, get_friendly_error_msg, - get_local_time, get_tool_call_id, get_utc_time, - is_utc_datetime, json_dumps, json_loads, parse_json, printd, united_diff, validate_function_response, - verify_first_message_correctness, ) -def compile_memory_metadata_block( - actor: PydanticUser, - agent_id: str, - memory_edit_timestamp: datetime.datetime, - agent_manager: Optional[AgentManager] = None, - message_manager: Optional[MessageManager] = None, -) -> str: - # Put the timestamp in the local timezone (mimicking get_local_time()) - timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip() - - # Create a metadata block of info so the agent knows about the metadata of out-of-context memories - memory_metadata_block = "\n".join( - [ - f"### Memory [last modified: {timestamp_str}]", - f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)", - f"{agent_manager.passage_size(actor=actor, agent_id=agent_id) if agent_manager else 0} total memories you created are stored in archival memory (use functions to access them)", - "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):", - ] - ) - return memory_metadata_block - - -def compile_system_message( - system_prompt: str, - agent_id: str, - in_context_memory: Memory, - in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory? - actor: PydanticUser, - agent_manager: Optional[AgentManager] = None, - message_manager: Optional[MessageManager] = None, - user_defined_variables: Optional[dict] = None, - append_icm_if_missing: bool = True, - template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", -) -> str: - """Prepare the final/full system message that will be fed into the LLM API - - The base system message may be templated, in which case we need to render the variables. - - The following are reserved variables: - - CORE_MEMORY: the in-context memory of the LLM - """ - - if user_defined_variables is not None: - # TODO eventually support the user defining their own variables to inject - raise NotImplementedError - else: - variables = {} - - # Add the protected memory variable - if IN_CONTEXT_MEMORY_KEYWORD in variables: - raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}") - else: - # TODO should this all put into the memory.__repr__ function? - memory_metadata_string = compile_memory_metadata_block( - actor=actor, - agent_id=agent_id, - memory_edit_timestamp=in_context_memory_last_edit, - agent_manager=agent_manager, - message_manager=message_manager, - ) - full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile() - - # Add to the variables list to inject - variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string - - if template_format == "f-string": - - # Catch the special case where the system prompt is unformatted - if append_icm_if_missing: - memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}" - if memory_variable_string not in system_prompt: - # In this case, append it to the end to make sure memory is still injected - # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead") - system_prompt += "\n" + memory_variable_string - - # render the variables using the built-in templater - try: - formatted_prompt = system_prompt.format_map(variables) - except Exception as e: - raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}") - - else: - # TODO support for mustache and jinja2 - raise NotImplementedError(template_format) - - return formatted_prompt - - -def initialize_message_sequence( - model: str, - system: str, - agent_id: str, - memory: Memory, - actor: PydanticUser, - agent_manager: Optional[AgentManager] = None, - message_manager: Optional[MessageManager] = None, - memory_edit_timestamp: Optional[datetime.datetime] = None, - include_initial_boot_message: bool = True, -) -> List[dict]: - if memory_edit_timestamp is None: - memory_edit_timestamp = get_local_time() - - # full_system_message = construct_system_with_memory( - # system, memory, memory_edit_timestamp, agent_manager=agent_manager, recall_memory=recall_memory - # ) - full_system_message = compile_system_message( - agent_id=agent_id, - system_prompt=system, - in_context_memory=memory, - in_context_memory_last_edit=memory_edit_timestamp, - actor=actor, - agent_manager=agent_manager, - message_manager=message_manager, - user_defined_variables=None, - append_icm_if_missing=True, - ) - first_user_message = get_login_event() # event letting Letta know the user just logged in - - if include_initial_boot_message: - if model is not None and "gpt-3.5" in model: - initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35") - else: - initial_boot_messages = get_initial_boot_messages("startup_with_send_message") - messages = ( - [ - {"role": "system", "content": full_system_message}, - ] - + initial_boot_messages - + [ - {"role": "user", "content": first_user_message}, - ] - ) - - else: - messages = [ - {"role": "system", "content": full_system_message}, - {"role": "user", "content": first_user_message}, - ] - - return messages - - class BaseAgent(ABC): """ Abstract class for all agents. @@ -249,7 +100,6 @@ def __init__( agent_state: AgentState, # in-memory representation of the agent state (read from multiple tables) user: User, # extras - messages_total: Optional[int] = None, # TODO remove? first_message_verify_mono: bool = True, # TODO move to config? initial_message_sequence: Optional[List[Message]] = None, ): @@ -303,82 +153,6 @@ def __init__( # When the summarizer is run, set this back to False (to reset) self.agent_alerted_about_memory_pressure = False - self._messages: List[Message] = [] - - # Once the memory object is initialized, use it to "bake" the system message - if self.agent_state.message_ids is not None: - self.set_message_buffer(message_ids=self.agent_state.message_ids) - - else: - printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}") - assert self.agent_state.id is not None and self.agent_state.created_by_id is not None - - # Generate a sequence of initial messages to put in the buffer - init_messages = initialize_message_sequence( - model=self.model, - system=self.agent_state.system, - agent_id=self.agent_state.id, - memory=self.agent_state.memory, - actor=self.user, - agent_manager=None, - message_manager=None, - memory_edit_timestamp=get_utc_time(), - include_initial_boot_message=True, - ) - - if initial_message_sequence is not None: - # We always need the system prompt up front - system_message_obj = Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict=init_messages[0], - ) - # Don't use anything else in the pregen sequence, instead use the provided sequence - init_messages = [system_message_obj] + initial_message_sequence - - else: - # Basic "more human than human" initial message sequence - init_messages = initialize_message_sequence( - model=self.model, - system=self.agent_state.system, - memory=self.agent_state.memory, - agent_id=self.agent_state.id, - actor=self.user, - agent_manager=None, - message_manager=None, - memory_edit_timestamp=get_utc_time(), - include_initial_boot_message=True, - ) - # Cast to Message objects - init_messages = [ - Message.dict_to_message( - agent_id=self.agent_state.id, user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=msg - ) - for msg in init_messages - ] - - # Cast the messages to actual Message objects to be synced to the DB - init_messages_objs = [] - for msg in init_messages: - init_messages_objs.append(msg) - for msg in init_messages_objs: - assert isinstance(msg, Message), f"Message object is not of type Message: {type(msg)}" - assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages) - - # Put the messages inside the message buffer - self.messages_total = 0 - self._append_to_messages(added_messages=init_messages_objs) - self._validate_message_buffer_is_utc() - - # Keep track of the total number of messages throughout all time - self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) - self.messages_total_init = len(self._messages) - 1 - printd(f"Agent initialized, self.messages_total={self.messages_total}") - - # Create the agent in the DB - self.update_state() - def check_tool_rules(self): if self.model not in STRUCTURED_OUTPUT_MODELS: if len(self.tool_rules_solver.init_tool_rules) > 1: @@ -470,109 +244,6 @@ def execute_tool_and_persist_state(self, function_name: str, function_args: dict return function_response - @property - def messages(self) -> List[dict]: - """Getter method that converts the internal Message list into OpenAI-style dicts""" - return [msg.to_openai_dict() for msg in self._messages] - - @messages.setter - def messages(self, value): - raise Exception("Modifying message list directly not allowed") - - def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]: - """Load a list of messages from recall storage""" - - # Pull the message objects from the database - message_objs = [] - for msg_id in message_ids: - msg_obj = self.message_manager.get_message_by_id(msg_id, actor=self.user) - if msg_obj: - if isinstance(msg_obj, Message): - message_objs.append(msg_obj) - else: - printd(f"Warning - message ID {msg_id} is not a Message object") - warnings.warn(f"Warning - message ID {msg_id} is not a Message object") - else: - printd(f"Warning - message ID {msg_id} not found in recall storage") - warnings.warn(f"Warning - message ID {msg_id} not found in recall storage") - - return message_objs - - def _validate_message_buffer_is_utc(self): - """Iterate over the message buffer and force all messages to be UTC stamped""" - - for m in self._messages: - # assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}" - # TODO eventually do casting via an edit_message function - if m.created_at: - if not is_utc_datetime(m.created_at): - printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')") - m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc) - - def set_message_buffer(self, message_ids: List[str], force_utc: bool = True): - """Set the messages in the buffer to the message IDs list""" - - message_objs = self._load_messages_from_recall(message_ids=message_ids) - - # set the objects in the buffer - self._messages = message_objs - - # bugfix for old agents that may not have had UTC specified in their timestamps - if force_utc: - self._validate_message_buffer_is_utc() - - # also sync the message IDs attribute - self.agent_state.message_ids = message_ids - - def refresh_message_buffer(self): - """Refresh the message buffer from the database""" - - messages_to_sync = self.agent_state.message_ids - assert messages_to_sync and all([isinstance(msg_id, str) for msg_id in messages_to_sync]) - - self.set_message_buffer(message_ids=messages_to_sync) - - def _trim_messages(self, num): - """Trim messages from the front, not including the system message""" - new_messages = [self._messages[0]] + self._messages[num:] - self._messages = new_messages - - def _prepend_to_messages(self, added_messages: List[Message]): - """Wrapper around self.messages.prepend to allow additional calls to a state/persistence manager""" - assert all([isinstance(msg, Message) for msg in added_messages]) - self.message_manager.create_many_messages(added_messages, actor=self.user) - - new_messages = [self._messages[0]] + added_messages + self._messages[1:] # prepend (no system) - self._messages = new_messages - self.messages_total += len(added_messages) # still should increment the message counter (summaries are additions too) - - def _append_to_messages(self, added_messages: List[Message]): - """Wrapper around self.messages.append to allow additional calls to a state/persistence manager""" - assert all([isinstance(msg, Message) for msg in added_messages]) - self.message_manager.create_many_messages(added_messages, actor=self.user) - - # strip extra metadata if it exists - # for msg in added_messages: - # msg.pop("api_response", None) - # msg.pop("api_args", None) - new_messages = self._messages + added_messages # append - - self._messages = new_messages - self.messages_total += len(added_messages) - - def append_to_messages(self, added_messages: List[dict]): - """An external-facing message append, where dict-like messages are first converted to Message objects""" - added_messages_objs = [ - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict=msg, - ) - for msg in added_messages - ] - self._append_to_messages(added_messages_objs) - def _get_ai_reply( self, message_sequence: List[Message], @@ -1002,33 +673,19 @@ def inner_step( if not all(isinstance(m, Message) for m in messages): raise ValueError(f"messages should be a Message or a list of Message, got {type(messages)}") - input_message_sequence = self._messages + messages + in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) + input_message_sequence = in_context_messages + messages if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user": printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue") # Step 2: send the conversation and available functions to the LLM - if not skip_verify and (first_message or self.messages_total == self.messages_total_init): - printd(f"This is the first message. Running extra verifier on AI response.") - counter = 0 - while True: - response = self._get_ai_reply( - message_sequence=input_message_sequence, first_message=True, stream=stream # passed through to the prompt formatter - ) - if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): - break - - counter += 1 - if counter > first_message_retry_limit: - raise Exception(f"Hit first message retry limit ({first_message_retry_limit})") - - else: - response = self._get_ai_reply( - message_sequence=input_message_sequence, - first_message=first_message, - stream=stream, - step_count=step_count, - ) + response = self._get_ai_reply( + message_sequence=input_message_sequence, + first_message=first_message, + stream=stream, + step_count=step_count, + ) # Step 3: check if LLM wanted to call a function # (if yes) Step 4: call the function @@ -1076,10 +733,7 @@ def inner_step( f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}" ) - self._append_to_messages(all_new_messages) - - # update state after each step - self.update_state() + self.agent_manager.append_to_in_context_messages(all_new_messages, agent_id=agent_state.id, actor=self.user) return AgentStepResponse( messages=all_new_messages, @@ -1094,7 +748,9 @@ def inner_step( # If we got a context alert, try trimming the messages length, then try again if is_context_overflow_error(e): - printd(f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages") + printd( + f"context window exceeded with limit {self.agent_state.llm_config.context_window}, running summarizer to trim messages" + ) # A separate API call to run a summarizer self.summarize_messages_inplace() @@ -1146,15 +802,19 @@ def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepRespons return self.inner_step(messages=[user_message], **kwargs) def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True): - assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})" + in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) + in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] + + if in_context_messages_openai[0]["role"] != "system": + raise RuntimeError(f"in_context_messages_openai[0] should be system (instead got {in_context_messages_openai[0]})") # Start at index 1 (past the system message), # and collect messages for summarization until we reach the desired truncation token fraction (eg 50%) # Do not allow truncation of the last N messages, since these are needed for in-context examples of function calling - token_counts = [count_tokens(str(msg)) for msg in self.messages] + token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai] message_buffer_token_count = sum(token_counts[1:]) # no system message desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC) - candidate_messages_to_summarize = self.messages[1:] + candidate_messages_to_summarize = in_context_messages_openai[1:] token_counts = token_counts[1:] if preserve_last_N_messages: @@ -1174,7 +834,7 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, "Not enough messages to compress for summarization", details={ "num_candidate_messages": len(candidate_messages_to_summarize), - "num_total_messages": len(self.messages), + "num_total_messages": len(in_context_messages_openai), "preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, }, ) @@ -1193,9 +853,9 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, # Try to make an assistant message come after the cutoff try: printd(f"Selected cutoff {cutoff} was a 'user', shifting one...") - if self.messages[cutoff]["role"] == "user": + if in_context_messages_openai[cutoff]["role"] == "user": new_cutoff = cutoff + 1 - if self.messages[new_cutoff]["role"] == "user": + if in_context_messages_openai[new_cutoff]["role"] == "user": printd(f"Shifted cutoff {new_cutoff} is still a 'user', ignoring...") cutoff = new_cutoff except IndexError: @@ -1203,23 +863,23 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, # Make sure the cutoff isn't on a 'tool' or 'function' if disallow_tool_as_first: - while self.messages[cutoff]["role"] in ["tool", "function"] and cutoff < len(self.messages): + while in_context_messages_openai[cutoff]["role"] in ["tool", "function"] and cutoff < len(in_context_messages_openai): printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...") cutoff += 1 - message_sequence_to_summarize = self._messages[1:cutoff] # do NOT get rid of the system message + message_sequence_to_summarize = in_context_messages[1:cutoff] # do NOT get rid of the system message if len(message_sequence_to_summarize) <= 1: # This prevents a potential infinite loop of summarizing the same message over and over raise ContextWindowExceededError( "Not enough messages to compress for summarization after determining cutoff", details={ "num_candidate_messages": len(message_sequence_to_summarize), - "num_total_messages": len(self.messages), + "num_total_messages": len(in_context_messages_openai), "preserve_N": MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, }, ) else: - printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self._messages)}") + printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(in_context_messages)}") # We can't do summarize logic properly if context_window is undefined if self.agent_state.llm_config.context_window is None: @@ -1235,48 +895,32 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, # Metadata that's useful for the agent to see all_time_message_count = self.messages_total - remaining_message_count = len(self.messages[cutoff:]) + remaining_message_count = len(in_context_messages_openai[cutoff:]) hidden_message_count = all_time_message_count - remaining_message_count summary_message_count = len(message_sequence_to_summarize) summary_message = package_summarize_message(summary, summary_message_count, hidden_message_count, all_time_message_count) printd(f"Packaged into message: {summary_message}") - prior_len = len(self.messages) - self._trim_messages(cutoff) + prior_len = len(in_context_messages_openai) + self.agent_manager.trim_older_in_context_messages(cutoff, agent_id=self.agent_state.id, actor=self.user) packed_summary_message = {"role": "user", "content": summary_message} - self._prepend_to_messages( - [ + self.agent_manager.prepend_to_in_context_messages( + messages=[ Message.dict_to_message( agent_id=self.agent_state.id, user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=packed_summary_message, ) - ] + ], + agent_id=self.agent_state.id, + actor=self.user, ) # reset alert self.agent_alerted_about_memory_pressure = False - printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}") - - def _swap_system_message_in_buffer(self, new_system_message: str): - """Update the system message (NOT prompt) of the Agent (requires updating the internal buffer)""" - assert isinstance(new_system_message, str) - new_system_message_obj = Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={"role": "system", "content": new_system_message}, - ) - - assert new_system_message_obj.role == "system", new_system_message_obj - assert self._messages[0].role == "system", self._messages - - self.message_manager.create_message(new_system_message_obj, actor=self.user) - - new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system) - self._messages = new_messages + printd(f"Ran summarizer, messages length {prior_len} -> {len(in_context_messages_openai)}") def rebuild_system_prompt(self, force=False, update_timestamp=True): """Rebuilds the system message with the latest memory object and any shared memory block updates @@ -1286,12 +930,15 @@ def rebuild_system_prompt(self, force=False, update_timestamp=True): Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages """ - curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt + curr_system_message = self.agent_manager.get_system_message( + agent_id=self.agent_state.id, actor=self.user + ) # this is the system + memory bank, not just the system prompt + curr_system_message_openai = curr_system_message.to_openai_dict() # note: we only update the system prompt if the core memory is changed # this means that the archival/recall memory statistics may be someout out of date curr_memory_str = self.agent_state.memory.compile() - if curr_memory_str in curr_system_message["content"] and not force: + if curr_memory_str in curr_system_message_openai["content"] and not force: # NOTE: could this cause issues if a block is removed? (substring match would still work) printd(f"Memory hasn't changed, skipping system prompt rebuild") return @@ -1302,7 +949,7 @@ def rebuild_system_prompt(self, force=False, update_timestamp=True): memory_edit_timestamp = get_utc_time() else: # NOTE: a bit of a hack - we pull the timestamp from the message created_by - memory_edit_timestamp = self._messages[0].created_at + memory_edit_timestamp = curr_system_message.created_at # update memory (TODO: potentially update recall/archival stats separately) new_system_message_str = compile_system_message( @@ -1316,21 +963,13 @@ def rebuild_system_prompt(self, force=False, update_timestamp=True): user_defined_variables=None, append_icm_if_missing=True, ) - new_system_message = { - "role": "system", - "content": new_system_message_str, - } - diff = united_diff(curr_system_message["content"], new_system_message["content"]) + diff = united_diff(curr_system_message_openai["content"], new_system_message_str) if len(diff) > 0: # there was a diff printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") # Swap the system message out (only if there is a diff) - self._swap_system_message_in_buffer(new_system_message=new_system_message_str) - assert self.messages[0]["content"] == new_system_message["content"], ( - self.messages[0]["content"], - new_system_message["content"], - ) + self.agent_manager.swap_system_message(new_system_message=new_system_message_str, agent_id=self.agent_state.id, actor=self.user) def update_system_prompt(self, new_system_prompt: str): """Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)""" @@ -1344,9 +983,6 @@ def update_system_prompt(self, new_system_prompt: str): # updating the system prompt requires rebuilding the memory block inside the compiled system message self.rebuild_system_prompt(force=True, update_timestamp=False) - # make sure to persist the change - _ = self.update_state() - def add_function(self, function_name: str) -> str: # TODO: refactor raise NotImplementedError @@ -1355,20 +991,6 @@ def remove_function(self, function_name: str) -> str: # TODO: refactor raise NotImplementedError - def update_state(self) -> AgentState: - # TODO: this should be removed and self._messages should be moved into self.agent_state.in_context_messages - message_ids = [msg.id for msg in self._messages] - - # Assert that these are all strings - if any(not isinstance(m_id, str) for m_id in message_ids): - warnings.warn(f"Non-string message IDs found in agent state: {message_ids}") - message_ids = [m_id for m_id in message_ids if isinstance(m_id, str)] - - # override any fields that may have been updated - self.agent_state.message_ids = message_ids - - return self.agent_state - def migrate_embedding(self, embedding_config: EmbeddingConfig): """Migrate the agent to a new embedding""" # TODO: archival memory @@ -1408,117 +1030,6 @@ def update_message(self, message_id: str, request: MessageUpdate) -> Message: updated_message = self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=self.user) return updated_message - # TODO(sarah): should we be creating a new message here, or just editing a message? - def rethink_message(self, new_thought: str) -> Message: - """Rethink / update the last message""" - for x in range(len(self.messages) - 1, 0, -1): - msg_obj = self._messages[x] - if msg_obj.role == MessageRole.assistant: - updated_message = self.update_message( - message_id=msg_obj.id, - request=MessageUpdate( - text=new_thought, - ), - ) - self.refresh_message_buffer() - return updated_message - raise ValueError(f"No assistant message found to update") - - # TODO(sarah): should we be creating a new message here, or just editing a message? - def rewrite_message(self, new_text: str) -> Message: - """Rewrite / update the send_message text on the last message""" - - # Walk backwards through the messages until we find an assistant message - for x in range(len(self._messages) - 1, 0, -1): - if self._messages[x].role == MessageRole.assistant: - # Get the current message content - message_obj = self._messages[x] - - # The rewrite target is the output of send_message - if message_obj.tool_calls is not None and len(message_obj.tool_calls) > 0: - - # Check that we hit an assistant send_message call - name_string = message_obj.tool_calls[0].function.name - if name_string is None or name_string != "send_message": - raise ValueError("Assistant missing send_message function call") - - args_string = message_obj.tool_calls[0].function.arguments - if args_string is None: - raise ValueError("Assistant missing send_message function arguments") - - args_json = json_loads(args_string) - if "message" not in args_json: - raise ValueError("Assistant missing send_message message argument") - - # Once we found our target, rewrite it - args_json["message"] = new_text - new_args_string = json_dumps(args_json) - message_obj.tool_calls[0].function.arguments = new_args_string - - # Write the update to the DB - updated_message = self.update_message( - message_id=message_obj.id, - request=MessageUpdate( - tool_calls=message_obj.tool_calls, - ), - ) - self.refresh_message_buffer() - return updated_message - - raise ValueError("No assistant message found to update") - - def pop_message(self, count: int = 1) -> List[Message]: - """Pop the last N messages from the agent's memory""" - n_messages = len(self._messages) - popped_messages = [] - MIN_MESSAGES = 2 - if n_messages <= MIN_MESSAGES: - raise ValueError(f"Agent only has {n_messages} messages in stack, none left to pop") - elif n_messages - count < MIN_MESSAGES: - raise ValueError(f"Agent only has {n_messages} messages in stack, cannot pop more than {n_messages - MIN_MESSAGES}") - else: - # print(f"Popping last {count} messages from stack") - for _ in range(min(count, len(self._messages))): - # remove the message from the internal state of the agent - deleted_message = self._messages.pop() - # then also remove it from recall storage - try: - self.message_manager.delete_message_by_id(deleted_message.id, actor=self.user) - popped_messages.append(deleted_message) - except Exception as e: - warnings.warn(f"Error deleting message {deleted_message.id} from recall memory: {e}") - self._messages.append(deleted_message) - break - - return popped_messages - - def pop_until_user(self) -> List[Message]: - """Pop all messages until the last user message""" - if MessageRole.user not in [msg.role for msg in self._messages]: - raise ValueError("No user message found in buffer") - - popped_messages = [] - while len(self._messages) > 0: - if self._messages[-1].role == MessageRole.user: - # we want to pop up to the last user message - return popped_messages - else: - popped_messages.append(self.pop_message(count=1)) - - raise ValueError("No user message found in buffer") - - def retry_message(self) -> List[Message]: - """Retry / regenerate the last message""" - self.pop_until_user() - user_message = self.pop_message(count=1)[0] - assert user_message.text is not None, "User message text is None" - step_response = self.step_user_message(user_message_str=user_message.text) - messages = step_response.messages - - assert messages is not None - assert all(isinstance(msg, Message) for msg in messages), "step() returned non-Message objects" - return messages - def get_context_window(self) -> ContextWindowOverview: """Get the context window of the agent""" @@ -1527,24 +1038,28 @@ def get_context_window(self) -> ContextWindowOverview: core_memory = self.agent_state.memory.compile() num_tokens_core_memory = count_tokens(core_memory) + # Grab the in-context messages # conversion of messages to OpenAI dict format, which is passed to the token counter - messages_openai_format = self.messages + in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) + in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] # Check if there's a summary message in the message queue if ( - len(self._messages) > 1 - and self._messages[1].role == MessageRole.user - and isinstance(self._messages[1].text, str) + len(in_context_messages) > 1 + and in_context_messages[1].role == MessageRole.user + and isinstance(in_context_messages[1].text, str) # TODO remove hardcoding - and "The following is a summary of the previous " in self._messages[1].text + and "The following is a summary of the previous " in in_context_messages[1].text ): # Summary message exists - assert self._messages[1].text is not None - summary_memory = self._messages[1].text - num_tokens_summary_memory = count_tokens(self._messages[1].text) + assert in_context_messages[1].text is not None + summary_memory = in_context_messages[1].text + num_tokens_summary_memory = count_tokens(in_context_messages[1].text) # with a summary message, the real messages start at index 2 num_tokens_messages = ( - num_tokens_from_messages(messages=messages_openai_format[2:], model=self.model) if len(messages_openai_format) > 2 else 0 + num_tokens_from_messages(messages=in_context_messages_openai[2:], model=self.model) + if len(in_context_messages_openai) > 2 + else 0 ) else: @@ -1552,7 +1067,9 @@ def get_context_window(self) -> ContextWindowOverview: num_tokens_summary_memory = 0 # with no summary message, the real messages start at index 1 num_tokens_messages = ( - num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0 + num_tokens_from_messages(messages=in_context_messages_openai[1:], model=self.model) + if len(in_context_messages_openai) > 1 + else 0 ) agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id) @@ -1587,7 +1104,7 @@ def get_context_window(self) -> ContextWindowOverview: return ContextWindowOverview( # context window breakdown (in messages) - num_messages=len(self._messages), + num_messages=len(in_context_messages), num_archival_memory=agent_manager_passage_size, num_recall_memory=message_manager_size, num_tokens_external_memory_summary=num_tokens_external_memory_summary, @@ -1602,7 +1119,7 @@ def get_context_window(self) -> ContextWindowOverview: num_tokens_summary_memory=num_tokens_summary_memory, summary_memory=summary_memory, num_tokens_messages=num_tokens_messages, - messages=self._messages, + messages=in_context_messages, # related to functions num_tokens_functions_definitions=num_tokens_available_functions_definitions, functions_definitions=available_functions_definitions, @@ -1616,7 +1133,6 @@ def count_tokens(self) -> int: def save_agent(agent: Agent): """Save agent to metadata store""" - agent.update_state() agent_state = agent.agent_state assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}" diff --git a/letta/server/server.py b/letta/server/server.py index 1b8ff9cdc7..8a6cf54844 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -846,15 +846,6 @@ def update_agent( if request.system: letta_agent.update_system_prompt(request.system) - # update in-context messages - if request.message_ids: - # This means the user is trying to change what messages are in the message buffer - # Internally this requires (1) pulling from recall, - # then (2) setting the attributes ._messages and .state.message_ids - letta_agent.set_message_buffer(message_ids=request.message_ids) - - letta_agent.update_state() - return agent_state def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]: diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index aacad8ae49..87c2679afc 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -14,12 +14,14 @@ from letta.orm import SourcePassage, SourcesAgents from letta.orm import Tool as ToolModel from letta.orm.errors import NoResultFound +from letta.orm.message import Message as MessageModel from letta.orm.sqlite_functions import adapt_array from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message as PydanticMessage from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.source import Source as PydanticSource from letta.schemas.tool_rule import ToolRule as PydanticToolRule @@ -29,11 +31,13 @@ _process_relationship, _process_tags, derive_system_message, + initialize_message_sequence, ) +from letta.services.message_manager import MessageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings -from letta.utils import enforce_types +from letta.utils import enforce_types, get_utc_time logger = get_logger(__name__) @@ -49,6 +53,7 @@ def __init__(self): self.block_manager = BlockManager() self.tool_manager = ToolManager() self.source_manager = SourceManager() + self.message_manager = MessageManager() # ====================================================================================================================== # Basic CRUD operations @@ -88,7 +93,8 @@ def create_agent( # Remove duplicates tool_ids = list(set(tool_ids)) - return self._create_agent( + # Create the agent + agent_state = self._create_agent( name=agent_create.name, system=system, agent_type=agent_create.agent_type, @@ -104,6 +110,50 @@ def create_agent( actor=actor, ) + # TODO: See if we can merge this into the above SQL create call for performance reasons + # Generate a sequence of initial messages to put in the buffer + init_messages = initialize_message_sequence( + agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True + ) + + if agent_create.initial_message_sequence is not None: + # We always need the system prompt up front + system_message_obj = PydanticMessage.dict_to_message( + agent_id=agent_state.id, + user_id=agent_state.created_by_id, + model=agent_state.llm_config.model, + openai_message_dict=init_messages[0], + ) + # Don't use anything else in the pregen sequence, instead use the provided sequence + init_messages = [system_message_obj] + # TODO: Get rid of this ad-hoc message creation, and make message_manager ONLY accept CreateMessages + for create_req in agent_create.initial_message_sequence: + init_messages.append( + PydanticMessage( + role=create_req.role, + text=create_req.text, + organization_id=actor.organization_id, + agent_id=agent_state.id, + model=agent_state.llm_config.model, + ) + ) + else: + # Basic "more human than human" initial message sequence + init_messages = initialize_message_sequence( + agent_state=agent_state, + memory_edit_timestamp=get_utc_time(), + include_initial_boot_message=True, + ) + # Cast to Message objects + init_messages = [ + PydanticMessage.dict_to_message( + agent_id=agent_state.id, user_id=agent_state.created_by_id, model=agent_state.llm_config.model, openai_message_dict=msg + ) + for msg in init_messages + ] + + return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor) + @enforce_types def _create_agent( self, @@ -247,6 +297,78 @@ def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState agent.hard_delete(session) return agent_state + def get_total_message_count(self, agent_id: str, actor: PydanticUser) -> int: + """ + Get the total number of messages associated with a specific agent using the relationship. + + :param session: The SQLAlchemy session object. + :param agent_id: The ID of the agent. + :return: The total number of messages for the agent. + """ + with self.session_maker() as session: + return ( + session.query(func.count(MessageModel.id)) + .join(AgentModel.messages) + .filter(AgentModel.id == agent_id) + .filter(AgentModel.organization_id == actor.organization_id) + .scalar() + ) + + # ====================================================================================================================== + # In Context Messages Management + # ====================================================================================================================== + # TODO: There are several assumptions here that are not explicitly checked + # TODO: 1) These message ids are valid + # TODO: 2) These messages are ordered from oldest to newest + # TODO: This can be fixed by having an actual relationship in the ORM for message_ids + # TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query. + @enforce_types + def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: + message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids + return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor) + + @enforce_types + def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: + message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids + return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor) + + @enforce_types + def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: + return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) + + @enforce_types + def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState: + message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids + new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message + return self.set_in_context_messages(agent_id=agent_id, message_ids=[m.id for m in new_messages], actor=actor) + + @enforce_types + def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: + message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids + new_messages = self.message_manager.create_many_messages(messages, actor=actor) + message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:] + return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + + @enforce_types + def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: + messages = self.message_manager.create_many_messages(messages, actor=actor) + message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or [] + message_ids += [m.id for m in messages] + return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + + @enforce_types + def swap_system_message(self, new_system_message: str, agent_id: str, actor: PydanticUser): + agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor) + message = PydanticMessage.dict_to_message( + agent_id=agent_id, + user_id=actor.id, + model=agent_state.llm_config.model, + openai_message_dict={"role": "system", "content": new_system_message}, + ) + message = self.message_manager.create_message(message, actor=actor) + message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system) + return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + # ====================================================================================================================== # Source Management # ====================================================================================================================== diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 95ad26beeb..c7049d5f44 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -1,10 +1,15 @@ -from typing import List, Optional +import datetime +from typing import List, Literal, Optional +from letta.constants import IN_CONTEXT_MEMORY_KEYWORD from letta.orm.agent import Agent as AgentModel from letta.orm.agents_tags import AgentsTags from letta.orm.errors import NoResultFound from letta.prompts import gpt_system -from letta.schemas.agent import AgentType +from letta.schemas.agent import AgentState, AgentType +from letta.schemas.memory import Memory +from letta.system import get_initial_boot_messages, get_login_event +from letta.utils import get_local_time # Static methods @@ -88,3 +93,129 @@ def derive_system_message(agent_type: AgentType, system: Optional[str] = None): raise ValueError(f"Invalid agent type: {agent_type}") return system + + +# TODO: This code is kind of wonky and deserves a rewrite +def compile_memory_metadata_block( + memory_edit_timestamp: datetime.datetime, previous_message_count: int = 0, archival_memory_size: int = 0 +) -> str: + # Put the timestamp in the local timezone (mimicking get_local_time()) + timestamp_str = memory_edit_timestamp.astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z").strip() + + # Create a metadata block of info so the agent knows about the metadata of out-of-context memories + memory_metadata_block = "\n".join( + [ + f"### Memory [last modified: {timestamp_str}]", + f"{previous_message_count} previous messages between you and the user are stored in recall memory (use functions to access them)", + f"{archival_memory_size} total memories you created are stored in archival memory (use functions to access them)", + "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):", + ] + ) + return memory_metadata_block + + +def compile_system_message( + system_prompt: str, + in_context_memory: Memory, + in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory? + user_defined_variables: Optional[dict] = None, + append_icm_if_missing: bool = True, + template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", + previous_message_count: int = 0, + archival_memory_size: int = 0, +) -> str: + """Prepare the final/full system message that will be fed into the LLM API + + The base system message may be templated, in which case we need to render the variables. + + The following are reserved variables: + - CORE_MEMORY: the in-context memory of the LLM + """ + + if user_defined_variables is not None: + # TODO eventually support the user defining their own variables to inject + raise NotImplementedError + else: + variables = {} + + # Add the protected memory variable + if IN_CONTEXT_MEMORY_KEYWORD in variables: + raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}") + else: + # TODO should this all put into the memory.__repr__ function? + memory_metadata_string = compile_memory_metadata_block( + memory_edit_timestamp=in_context_memory_last_edit, + previous_message_count=previous_message_count, + archival_memory_size=archival_memory_size, + ) + full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile() + + # Add to the variables list to inject + variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string + + if template_format == "f-string": + + # Catch the special case where the system prompt is unformatted + if append_icm_if_missing: + memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}" + if memory_variable_string not in system_prompt: + # In this case, append it to the end to make sure memory is still injected + # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead") + system_prompt += "\n" + memory_variable_string + + # render the variables using the built-in templater + try: + formatted_prompt = system_prompt.format_map(variables) + except Exception as e: + raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}") + + else: + # TODO support for mustache and jinja2 + raise NotImplementedError(template_format) + + return formatted_prompt + + +def initialize_message_sequence( + agent_state: AgentState, + memory_edit_timestamp: Optional[datetime.datetime] = None, + include_initial_boot_message: bool = True, + previous_message_count: int = 0, + archival_memory_size: int = 0, +) -> List[dict]: + if memory_edit_timestamp is None: + memory_edit_timestamp = get_local_time() + + full_system_message = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + user_defined_variables=None, + append_icm_if_missing=True, + previous_message_count=previous_message_count, + archival_memory_size=archival_memory_size, + ) + first_user_message = get_login_event() # event letting Letta know the user just logged in + + if include_initial_boot_message: + if agent_state.llm_config.model is not None and "gpt-3.5" in agent_state.llm_config.model: + initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35") + else: + initial_boot_messages = get_initial_boot_messages("startup_with_send_message") + messages = ( + [ + {"role": "system", "content": full_system_message}, + ] + + initial_boot_messages + + [ + {"role": "user", "content": first_user_message}, + ] + ) + + else: + messages = [ + {"role": "system", "content": full_system_message}, + {"role": "user", "content": first_user_message}, + ] + + return messages diff --git a/tests/test_managers.py b/tests/test_managers.py index 588f59ce4f..a7b385ea02 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -370,8 +370,8 @@ def print_other_tool(message: str): @pytest.fixture def sarah_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" - agent_state = server.create_agent( - request=CreateAgent( + agent_state = server.agent_manager.create_agent( + agent_create=CreateAgent( name="sarah_agent", memory_blocks=[], llm_config=LLMConfig.default_config("gpt-4"), @@ -385,8 +385,8 @@ def sarah_agent(server: SyncServer, default_user, default_organization): @pytest.fixture def charles_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" - agent_state = server.create_agent( - request=CreateAgent( + agent_state = server.agent_manager.create_agent( + agent_create=CreateAgent( name="charles_agent", memory_blocks=[CreateBlock(label="human", value="Charles"), CreateBlock(label="persona", value="I am a helpful assistant")], llm_config=LLMConfig.default_config("gpt-4"), @@ -503,6 +503,54 @@ def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixt assert len(list_agents) == 0 +def test_create_agent_passed_in_initial_messages(server: SyncServer, default_user, default_block): + memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] + create_agent_request = CreateAgent( + system="test system", + memory_blocks=memory_blocks, + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + block_ids=[default_block.id], + tags=["a", "b"], + description="test_description", + initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")], + ) + agent_state = server.agent_manager.create_agent( + create_agent_request, + actor=default_user, + ) + assert server.agent_manager.get_total_message_count(agent_state.id, default_user) == 2 + init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) + # Check that the system appears in the first initial message + assert create_agent_request.system in init_messages[0].text + assert create_agent_request.memory_blocks[0].value in init_messages[0].text + # Check that the second message is the passed in initial message seq + assert create_agent_request.initial_message_sequence[0].role == init_messages[1].role + assert create_agent_request.initial_message_sequence[0].text == init_messages[1].text + + +def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block): + memory_blocks = [CreateBlock(label="human", value="BananaBoy"), CreateBlock(label="persona", value="I am a helpful assistant")] + create_agent_request = CreateAgent( + system="test system", + memory_blocks=memory_blocks, + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + block_ids=[default_block.id], + tags=["a", "b"], + description="test_description", + ) + agent_state = server.agent_manager.create_agent( + create_agent_request, + actor=default_user, + ) + assert server.agent_manager.get_total_message_count(agent_state.id, default_user) == 4 + init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) + # Check that the system appears in the first initial message + assert create_agent_request.system in init_messages[0].text + assert create_agent_request.memory_blocks[0].value in init_messages[0].text + + def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, other_tool, other_source, other_block, default_user): agent, _ = comprehensive_test_agent_fixture update_agent_request = UpdateAgent( @@ -794,8 +842,8 @@ def test_list_agents_by_tags_with_other_filters(server: SyncServer, sarah_agent, def test_list_agents_by_tags_pagination(server: SyncServer, default_user, default_organization): """Test pagination when listing agents by tags.""" # Create first agent - agent1 = server.create_agent( - request=CreateAgent( + agent1 = server.agent_manager.create_agent( + agent_create=CreateAgent( name="agent1", tags=["pagination_test", "tag1"], llm_config=LLMConfig.default_config("gpt-4"), @@ -809,8 +857,8 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul time.sleep(CREATE_DELAY_SQLITE) # Ensure distinct created_at timestamps # Create second agent - agent2 = server.create_agent( - request=CreateAgent( + agent2 = server.agent_manager.create_agent( + agent_create=CreateAgent( name="agent2", tags=["pagination_test", "tag2"], llm_config=LLMConfig.default_config("gpt-4"), From 9dc97b15548aec4f8fd8a00ef01d99ab884a4ea3 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 12:02:37 -0800 Subject: [PATCH 03/11] First pass at removing self.load_agent --- letta/agent.py | 94 +-------- letta/client/client.py | 10 +- letta/main.py | 76 ------- letta/server/rest_api/routers/v1/agents.py | 17 +- letta/server/server.py | 194 ++---------------- letta/services/agent_manager.py | 121 ++++++----- .../services/helpers/agent_manager_helper.py | 28 +++ tests/helpers/endpoints_helper.py | 11 +- tests/test_managers.py | 6 +- tests/test_server.py | 73 +------ 10 files changed, 154 insertions(+), 476 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 6ed305a171..44a424e196 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -31,7 +31,7 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.memory import ContextWindowOverview, Memory -from letta.schemas.message import Message, MessageUpdate +from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import ( Tool as ChatCompletionRequestTool, ) @@ -46,6 +46,7 @@ from letta.schemas.user import User as PydanticUser from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager +from letta.services.helpers.agent_manager_helper import compile_memory_metadata_block from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager @@ -67,7 +68,6 @@ json_loads, parse_json, printd, - united_diff, validate_function_response, ) @@ -75,7 +75,7 @@ class BaseAgent(ABC): """ Abstract class for all agents. - Only two interfaces are required: step and update_state. + Only one interface is required: step. """ @abstractmethod @@ -88,10 +88,6 @@ def step( """ raise NotImplementedError - @abstractmethod - def update_state(self) -> AgentState: - raise NotImplementedError - class Agent(BaseAgent): def __init__( @@ -101,7 +97,6 @@ def __init__( user: User, # extras first_message_verify_mono: bool = True, # TODO move to config? - initial_message_sequence: Optional[List[Message]] = None, ): assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}" # Hold a copy of the state that was used to init the agent @@ -192,7 +187,7 @@ def update_memory_if_change(self, new_memory: Memory) -> bool: # NOTE: don't do this since re-buildin the memory is handled at the start of the step # rebuild memory - this records the last edited timestamp of the memory # TODO: pass in update timestamp from block edit time - self.rebuild_system_prompt() + self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user) return True return False @@ -550,7 +545,7 @@ def _handle_ai_response( # rebuild memory # TODO: @charles please check this - self.rebuild_system_prompt() + self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user) # Update ToolRulesSolver state with last called function self.tool_rules_solver.update_tool_usage(function_name) @@ -733,7 +728,7 @@ def inner_step( f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}" ) - self.agent_manager.append_to_in_context_messages(all_new_messages, agent_id=agent_state.id, actor=self.user) + self.agent_manager.append_to_in_context_messages(all_new_messages, agent_id=self.agent_state.id, actor=self.user) return AgentStepResponse( messages=all_new_messages, @@ -894,7 +889,7 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, printd(f"Got summary: {summary}") # Metadata that's useful for the agent to see - all_time_message_count = self.messages_total + all_time_message_count = self.message_manager.size(agent_id=self.agent_state.id, actor=self.user) remaining_message_count = len(in_context_messages_openai[cutoff:]) hidden_message_count = all_time_message_count - remaining_message_count summary_message_count = len(message_sequence_to_summarize) @@ -922,67 +917,6 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, printd(f"Ran summarizer, messages length {prior_len} -> {len(in_context_messages_openai)}") - def rebuild_system_prompt(self, force=False, update_timestamp=True): - """Rebuilds the system message with the latest memory object and any shared memory block updates - - Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object - - Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages - """ - - curr_system_message = self.agent_manager.get_system_message( - agent_id=self.agent_state.id, actor=self.user - ) # this is the system + memory bank, not just the system prompt - curr_system_message_openai = curr_system_message.to_openai_dict() - - # note: we only update the system prompt if the core memory is changed - # this means that the archival/recall memory statistics may be someout out of date - curr_memory_str = self.agent_state.memory.compile() - if curr_memory_str in curr_system_message_openai["content"] and not force: - # NOTE: could this cause issues if a block is removed? (substring match would still work) - printd(f"Memory hasn't changed, skipping system prompt rebuild") - return - - # If the memory didn't update, we probably don't want to update the timestamp inside - # For example, if we're doing a system prompt swap, this should probably be False - if update_timestamp: - memory_edit_timestamp = get_utc_time() - else: - # NOTE: a bit of a hack - we pull the timestamp from the message created_by - memory_edit_timestamp = curr_system_message.created_at - - # update memory (TODO: potentially update recall/archival stats separately) - new_system_message_str = compile_system_message( - agent_id=self.agent_state.id, - system_prompt=self.agent_state.system, - in_context_memory=self.agent_state.memory, - in_context_memory_last_edit=memory_edit_timestamp, - actor=self.user, - agent_manager=self.agent_manager, - message_manager=self.message_manager, - user_defined_variables=None, - append_icm_if_missing=True, - ) - - diff = united_diff(curr_system_message_openai["content"], new_system_message_str) - if len(diff) > 0: # there was a diff - printd(f"Rebuilding system with new memory...\nDiff:\n{diff}") - - # Swap the system message out (only if there is a diff) - self.agent_manager.swap_system_message(new_system_message=new_system_message_str, agent_id=self.agent_state.id, actor=self.user) - - def update_system_prompt(self, new_system_prompt: str): - """Update the system prompt of the agent (requires rebuilding the memory block if there's a difference)""" - assert isinstance(new_system_prompt, str) - - if new_system_prompt == self.agent_state.system: - return - - self.agent_state.system = new_system_prompt - - # updating the system prompt requires rebuilding the memory block inside the compiled system message - self.rebuild_system_prompt(force=True, update_timestamp=False) - def add_function(self, function_name: str) -> str: # TODO: refactor raise NotImplementedError @@ -1024,12 +958,6 @@ def attach_source( f"Attached data source {source.name} to agent {self.agent_state.name}.", ) - def update_message(self, message_id: str, request: MessageUpdate) -> Message: - """Update the details of a message associated with an agent""" - # Save the updated message - updated_message = self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=self.user) - return updated_message - def get_context_window(self) -> ContextWindowOverview: """Get the context window of the agent""" @@ -1075,11 +1003,9 @@ def get_context_window(self) -> ContextWindowOverview: agent_manager_passage_size = self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id) message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id) external_memory_summary = compile_memory_metadata_block( - actor=self.user, - agent_id=self.agent_state.id, - memory_edit_timestamp=get_utc_time(), # dummy timestamp - agent_manager=self.agent_manager, - message_manager=self.message_manager, + memory_edit_timestamp=get_utc_time(), + previous_message_count=self.message_manager.size(actor=self.user, agent_id=self.agent_state.id), + archival_memory_size=self.agent_manager.passage_size(actor=self.user, agent_id=self.agent_state.id), ) num_tokens_external_memory_summary = count_tokens(external_memory_summary) diff --git a/letta/client/client.py b/letta/client/client.py index 8a9d3e700a..106d1cc3a5 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2234,7 +2234,7 @@ def update_agent( """ # TODO: add the abilitty to reset linked block_ids self.interface.clear() - agent_state = self.server.update_agent( + agent_state = self.server.agent_manager.update_agent( agent_id, UpdateAgent( name=name, @@ -2262,7 +2262,7 @@ def get_tools_from_agent(self, agent_id: str) -> List[Tool]: List[Tool]: A list of Tool objs """ self.interface.clear() - return self.server.get_tools_from_agent(agent_id=agent_id, user_id=self.user_id) + return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user).tools def add_tool_to_agent(self, agent_id: str, tool_id: str): """ @@ -2276,7 +2276,7 @@ def add_tool_to_agent(self, agent_id: str, tool_id: str): agent_state (AgentState): State of the updated agent """ self.interface.clear() - agent_state = self.server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id) + agent_state = self.server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user) return agent_state def remove_tool_from_agent(self, agent_id: str, tool_id: str): @@ -2291,7 +2291,7 @@ def remove_tool_from_agent(self, agent_id: str, tool_id: str): agent_state (AgentState): State of the updated agent """ self.interface.clear() - agent_state = self.server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id) + agent_state = self.server.agent_manager.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, actor=self.user) return agent_state def rename_agent(self, agent_id: str, new_name: str): @@ -3075,7 +3075,7 @@ def delete_archival_memory(self, agent_id: str, memory_id: str): agent_id (str): ID of the agent memory_id (str): ID of the memory """ - self.server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=self.user) + self.server.delete_archival_memory(memory_id=memory_id, actor=self.user) def get_archival_memory( self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000 diff --git a/letta/main.py b/letta/main.py index c426917092..de1b4028ab 100644 --- a/letta/main.py +++ b/letta/main.py @@ -194,46 +194,6 @@ def run_agent_loop( print(f"Current model: {letta_agent.agent_state.llm_config.model}") continue - elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "): - # Check if there's an additional argument that's an integer - command = user_input.strip().split() - pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3 - try: - popped_messages = letta_agent.pop_message(count=pop_amount) - except ValueError as e: - print(f"Error popping messages: {e}") - continue - - elif user_input.lower() == "/retry": - print(f"Retrying for another answer...") - try: - letta_agent.retry_message() - except Exception as e: - print(f"Error retrying message: {e}") - continue - - elif user_input.lower() == "/rethink" or user_input.lower().startswith("/rethink "): - if len(user_input) < len("/rethink "): - print("Missing text after the command") - continue - try: - letta_agent.rethink_message(new_thought=user_input[len("/rethink ") :].strip()) - except Exception as e: - print(f"Error rethinking message: {e}") - continue - - elif user_input.lower() == "/rewrite" or user_input.lower().startswith("/rewrite "): - if len(user_input) < len("/rewrite "): - print("Missing text after the command") - continue - - text = user_input[len("/rewrite ") :].strip() - try: - letta_agent.rewrite_message(new_text=text) - except Exception as e: - print(f"Error rewriting message: {e}") - continue - elif user_input.lower() == "/summarize": try: letta_agent.summarize_messages_inplace() @@ -319,42 +279,6 @@ def run_agent_loop( questionary.print(cmd, "bold") questionary.print(f" {desc}") continue - - elif user_input.lower().startswith("/systemswap"): - if len(user_input) < len("/systemswap "): - print("Missing new system prompt after the command") - continue - old_system_prompt = letta_agent.system - new_system_prompt = user_input[len("/systemswap ") :].strip() - - # Show warning and prompts to user - typer.secho( - "\nWARNING: You are about to change the system prompt.", - # fg=typer.colors.BRIGHT_YELLOW, - bold=True, - ) - typer.secho( - f"\nOld system prompt:\n{old_system_prompt}", - fg=typer.colors.RED, - bold=True, - ) - typer.secho( - f"\nNew system prompt:\n{new_system_prompt}", - fg=typer.colors.GREEN, - bold=True, - ) - - # Ask for confirmation - confirm = questionary.confirm("Do you want to proceed with the swap?").ask() - - if confirm: - letta_agent.update_system_prompt(new_system_prompt=new_system_prompt) - print("System prompt updated successfully.") - else: - print("System prompt swap cancelled.") - - continue - else: print(f"Unrecognized command: {user_input}") continue diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 69b97c764e..c46a24573a 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -104,7 +104,7 @@ def get_agent_context_window( """ actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id) + return server.get_agent_context_window(agent_id=agent_id, actor=actor) class CreateAgentRequest(CreateAgent): @@ -138,7 +138,7 @@ def update_agent( ): """Update an exsiting agent""" actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.update_agent(agent_id, update_agent, actor=actor) + return server.agent_manager.update_agent(agent_id=agent_id, agent_update=update_agent, actor=actor) @router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent") @@ -149,7 +149,7 @@ def get_tools_from_agent( ): """Get tools from an existing agent""" actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id) + return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).tools @router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent") @@ -161,7 +161,7 @@ def add_tool_to_agent( ): """Add tools to an existing agent""" actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) + return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, user_id=actor) @router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent") @@ -173,7 +173,7 @@ def remove_tool_from_agent( ): """Add tools to an existing agent""" actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id) + return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) @router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent") @@ -232,7 +232,7 @@ def get_agent_in_context_messages( Retrieve the messages in the context of a specific agent. """ actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.get_in_context_messages(agent_id=agent_id, actor=actor) + return server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor) # TODO: remove? can also get with agent blocks @@ -429,7 +429,7 @@ def delete_agent_archival_memory( """ actor = server.user_manager.get_user_or_default(user_id=user_id) - server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=actor) + server.delete_archival_memory(memory_id=memory_id, actor=actor) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) @@ -479,8 +479,9 @@ def update_message( """ Update the details of a message associated with an agent. """ + # TODO: Get rid of agent_id here, it's not really relevant actor = server.user_manager.get_user_or_default(user_id=user_id) - return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request, actor=actor) + return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor) @router.post( diff --git a/letta/server/server.py b/letta/server/server.py index 8a6cf54844..114341d3ef 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -40,7 +40,7 @@ VLLMChatCompletionsProvider, VLLMCompletionsProvider, ) -from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent +from letta.schemas.agent import AgentState, AgentType, CreateAgent from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig @@ -376,25 +376,6 @@ def __init__( ) ) - def initialize_agent(self, agent_id, actor, interface: Union[AgentInterface, None] = None, initial_message_sequence=None) -> Agent: - """Initialize an agent from the database""" - agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - - interface = interface or self.default_interface_factory() - if agent_state.agent_type == AgentType.memgpt_agent: - agent = Agent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence) - elif agent_state.agent_type == AgentType.offline_memory_agent: - agent = OfflineMemoryAgent( - agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence - ) - else: - assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents" - agent = O1Agent(agent_state=agent_state, interface=interface, user=actor) - - # Persist to agent - save_agent(agent) - return agent - def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent: """Updated method to load agents from persisted storage""" agent_lock = self.per_agent_lock_manager.get_lock(agent_id) @@ -413,11 +394,6 @@ def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface else: raise ValueError(f"Invalid agent type {agent_state.agent_type}") - # Rebuild the system prompt - may be linked to new blocks now - agent.rebuild_system_prompt() - - # Persist to agent - save_agent(agent) return agent def _step( @@ -790,120 +766,23 @@ def create_agent( """Create a new agent using a config""" # Invoke manager - agent_state = self.agent_manager.create_agent( + return self.agent_manager.create_agent( agent_create=request, actor=actor, ) - # create the agent object - if request.initial_message_sequence is not None: - # init_messages = [Message(user_id=user_id, agent_id=agent_state.id, role=message.role, text=message.text) for message in request.initial_message_sequence] - init_messages = [] - for message in request.initial_message_sequence: - - if message.role == MessageRole.user: - packed_message = system.package_user_message( - user_message=message.text, - ) - elif message.role == MessageRole.system: - packed_message = system.package_system_message( - system_message=message.text, - ) - else: - raise ValueError(f"Invalid message role: {message.role}") - - init_messages.append(Message(role=message.role, text=packed_message, agent_id=agent_state.id)) - # init_messages = [Message.dict_to_message(user_id=user_id, agent_id=agent_state.id, openai_message_dict=message.model_dump()) for message in request.initial_message_sequence] - else: - init_messages = None - - # initialize the agent (generates initial message list with system prompt) - if interface is None: - interface = self.default_interface_factory() - self.initialize_agent(agent_id=agent_state.id, interface=interface, initial_message_sequence=init_messages, actor=actor) - - in_memory_agent_state = self.agent_manager.get_agent_by_id(agent_state.id, actor=actor) - return in_memory_agent_state - - # TODO: This is not good! - # TODO: Ideally, this should ALL be handled by the ORM - # TODO: The main blocker here IS the _message updates - def update_agent( - self, - agent_id: str, - request: UpdateAgent, - actor: User, - ) -> AgentState: - """Update the agents core memory block, return the new state""" - # Update agent state in the db first - agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor) - - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - - # TODO: Everything below needs to get removed, no updating anything in memory - # update the system prompt - if request.system: - letta_agent.update_system_prompt(request.system) - - return agent_state - - def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]: - """Get tools from an existing agent""" - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) - - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - return letta_agent.agent_state.tools - - def add_tool_to_agent( - self, - agent_id: str, - tool_id: str, - user_id: str, - ): - """Add tools from an existing agent""" - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) - - agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) - - return agent_state - - def remove_tool_from_agent( - self, - agent_id: str, - tool_id: str, - user_id: str, - ): - """Remove tools from an existing agent""" - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) - agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) - - return agent_state - # convert name->id + # TODO: These can be moved to agent_manager def get_agent_memory(self, agent_id: str, actor: User) -> Memory: """Return the memory of an agent (core memory)""" - agent = self.load_agent(agent_id=agent_id, actor=actor) - return agent.agent_state.memory + return self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).memory def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary: - agent = self.load_agent(agent_id=agent_id, actor=actor) return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id)) def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary: - agent = self.load_agent(agent_id=agent_id, actor=actor) - return RecallMemorySummary(size=len(agent.message_manager)) - - def get_in_context_messages(self, agent_id: str, actor: User) -> List[Message]: - """Get the in-context messages in the agent's memory""" - # Get the agent object (loaded in memory) - agent = self.load_agent(agent_id=agent_id, actor=actor) - return agent._messages + return RecallMemorySummary(size=self.message_manager.size(actor=actor, agent_id=agent_id)) def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[Passage]: """Paginated query of all messages in agent archival memory""" @@ -938,24 +817,17 @@ def get_agent_archival_cursor( def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]: # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - + agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) # Insert into archival memory - passages = self.passage_manager.insert_passage( - agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor - ) - - save_agent(letta_agent) + # TODO: @mindy look at moving this to agent_manager to avoid above extra call + passages = self.passage_manager.insert_passage(agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor) return passages - def delete_archival_memory(self, agent_id: str, memory_id: str, actor: User): - # Get the agent object (loaded in memory) - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - - # Delete by ID + def delete_archival_memory(self, memory_id: str, actor: User): # TODO check if it exists first, and throw error if not - letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor) + # TODO: @mindy make this return the deleted passage instead + self.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor) # TODO: return archival memory @@ -1033,9 +905,8 @@ def update_agent_core_memory(self, agent_id: str, label: str, value: str, actor: # update the block self.block_manager.update_block(block_id=block.id, block_update=BlockUpdate(value=value), actor=actor) - # load agent - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - return letta_agent.agent_state.memory + # rebuild system prompt for agent, potentially changed + return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory def delete_source(self, source_id: str, actor: User): """Delete a data source""" @@ -1205,36 +1076,11 @@ def add_default_external_tools(self, actor: User) -> bool: return success - def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate, actor: User) -> Message: + def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message: """Update the details of a message associated with an agent""" # Get the current message - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - response = letta_agent.update_message(message_id=message_id, request=request) - save_agent(letta_agent) - return response - - def rewrite_agent_message(self, agent_id: str, new_text: str, actor: User) -> Message: - - # Get the current message - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - response = letta_agent.rewrite_message(new_text=new_text) - save_agent(letta_agent) - return response - - def rethink_agent_message(self, agent_id: str, new_thought: str, actor: User) -> Message: - # Get the current message - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - response = letta_agent.rethink_message(new_thought=new_thought) - save_agent(letta_agent) - return response - - def retry_agent_message(self, agent_id: str, actor: User) -> List[Message]: - # Get the current message - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - response = letta_agent.retry_message() - save_agent(letta_agent) - return response + return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor) def get_organization_or_default(self, org_id: Optional[str]) -> Organization: """Get the organization object for org_id if it exists, otherwise return the default organization object""" @@ -1322,15 +1168,7 @@ def add_llm_model(self, request: LLMConfig) -> LLMConfig: def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig: """Add a new embedding model""" - def get_agent_context_window( - self, - user_id: str, - agent_id: str, - ) -> ContextWindowOverview: - # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user - actor = self.user_manager.get_user_or_default(user_id=user_id) - - # Get the current message + def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowOverview: letta_agent = self.load_agent(agent_id=agent_id, actor=actor) return letta_agent.get_context_window() diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 87c2679afc..fd0ac5bf31 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -14,7 +14,6 @@ from letta.orm import SourcePassage, SourcesAgents from letta.orm import Tool as ToolModel from letta.orm.errors import NoResultFound -from letta.orm.message import Message as MessageModel from letta.orm.sqlite_functions import adapt_array from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent @@ -30,14 +29,16 @@ from letta.services.helpers.agent_manager_helper import ( _process_relationship, _process_tags, + compile_system_message, derive_system_message, initialize_message_sequence, + package_initial_message_sequence, ) from letta.services.message_manager import MessageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings -from letta.utils import enforce_types, get_utc_time +from letta.utils import enforce_types, get_utc_time, united_diff logger = get_logger(__name__) @@ -126,25 +127,10 @@ def create_agent( ) # Don't use anything else in the pregen sequence, instead use the provided sequence init_messages = [system_message_obj] - # TODO: Get rid of this ad-hoc message creation, and make message_manager ONLY accept CreateMessages - for create_req in agent_create.initial_message_sequence: - init_messages.append( - PydanticMessage( - role=create_req.role, - text=create_req.text, - organization_id=actor.organization_id, - agent_id=agent_state.id, - model=agent_state.llm_config.model, - ) - ) - else: - # Basic "more human than human" initial message sequence - init_messages = initialize_message_sequence( - agent_state=agent_state, - memory_edit_timestamp=get_utc_time(), - include_initial_boot_message=True, + init_messages.extend( + package_initial_message_sequence(agent_state.id, agent_create.initial_message_sequence, agent_state.llm_config.model, actor) ) - # Cast to Message objects + else: init_messages = [ PydanticMessage.dict_to_message( agent_id=agent_state.id, user_id=agent_state.created_by_id, model=agent_state.llm_config.model, openai_message_dict=msg @@ -199,6 +185,16 @@ def _create_agent( @enforce_types def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState: + agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor) + + # Rebuild the system prompt if it's different + if agent_update.system and agent_update.system != agent_state.system: + agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False) + + return agent_state + + @enforce_types + def _update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState: """ Update an existing agent. @@ -297,23 +293,6 @@ def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState agent.hard_delete(session) return agent_state - def get_total_message_count(self, agent_id: str, actor: PydanticUser) -> int: - """ - Get the total number of messages associated with a specific agent using the relationship. - - :param session: The SQLAlchemy session object. - :param agent_id: The ID of the agent. - :return: The total number of messages for the agent. - """ - with self.session_maker() as session: - return ( - session.query(func.count(MessageModel.id)) - .join(AgentModel.messages) - .filter(AgentModel.id == agent_id) - .filter(AgentModel.organization_id == actor.organization_id) - .scalar() - ) - # ====================================================================================================================== # In Context Messages Management # ====================================================================================================================== @@ -332,6 +311,61 @@ def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMess message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor) + @enforce_types + def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState: + """Rebuilds the system message with the latest memory object and any shared memory block updates + + Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object + + Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages + """ + agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor) + + curr_system_message = self.get_system_message( + agent_id=agent_id, actor=actor + ) # this is the system + memory bank, not just the system prompt + curr_system_message_openai = curr_system_message.to_openai_dict() + + # note: we only update the system prompt if the core memory is changed + # this means that the archival/recall memory statistics may be someout out of date + curr_memory_str = agent_state.memory.compile() + if curr_memory_str in curr_system_message_openai["content"] and not force: + # NOTE: could this cause issues if a block is removed? (substring match would still work) + logger.info(f"Memory hasn't changed, skipping system prompt rebuild") + return agent_state + + # If the memory didn't update, we probably don't want to update the timestamp inside + # For example, if we're doing a system prompt swap, this should probably be False + if update_timestamp: + memory_edit_timestamp = get_utc_time() + else: + # NOTE: a bit of a hack - we pull the timestamp from the message created_by + memory_edit_timestamp = curr_system_message.created_at + + # update memory (TODO: potentially update recall/archival stats separately) + new_system_message_str = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + ) + + diff = united_diff(curr_system_message_openai["content"], new_system_message_str) + if len(diff) > 0: # there was a diff + logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}") + + # Swap the system message out (only if there is a diff) + message = PydanticMessage.dict_to_message( + agent_id=agent_id, + user_id=actor.id, + model=agent_state.llm_config.model, + openai_message_dict={"role": "system", "content": new_system_message_str}, + ) + message = self.message_manager.create_message(message, actor=actor) + message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system) + return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + else: + return agent_state + @enforce_types def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) @@ -356,19 +390,6 @@ def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_i message_ids += [m.id for m in messages] return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) - @enforce_types - def swap_system_message(self, new_system_message: str, agent_id: str, actor: PydanticUser): - agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor) - message = PydanticMessage.dict_to_message( - agent_id=agent_id, - user_id=actor.id, - model=agent_state.llm_config.model, - openai_message_dict={"role": "system", "content": new_system_message}, - ) - message = self.message_manager.create_message(message, actor=actor) - message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system) - return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) - # ====================================================================================================================== # Source Management # ====================================================================================================================== diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index c7049d5f44..375dea69f8 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -1,13 +1,17 @@ import datetime from typing import List, Literal, Optional +from letta import system from letta.constants import IN_CONTEXT_MEMORY_KEYWORD from letta.orm.agent import Agent as AgentModel from letta.orm.agents_tags import AgentsTags from letta.orm.errors import NoResultFound from letta.prompts import gpt_system from letta.schemas.agent import AgentState, AgentType +from letta.schemas.enums import MessageRole from letta.schemas.memory import Memory +from letta.schemas.message import Message, MessageCreate +from letta.schemas.user import User from letta.system import get_initial_boot_messages, get_login_event from letta.utils import get_local_time @@ -219,3 +223,27 @@ def initialize_message_sequence( ] return messages + + +def package_initial_message_sequence( + agent_id: str, initial_message_sequence: List[MessageCreate], model: str, actor: User +) -> List[Message]: + # create the agent object + init_messages = [] + for message_create in initial_message_sequence: + + if message_create.role == MessageRole.user: + packed_message = system.package_user_message( + user_message=message_create.text, + ) + elif message_create.role == MessageRole.system: + packed_message = system.package_system_message( + system_message=message_create.text, + ) + else: + raise ValueError(f"Invalid message role: {message_create.role}") + + init_messages.append( + Message(role=message_create.role, text=packed_message, organization_id=actor.organization_id, agent_id=agent_id, model=model) + ) + return init_messages diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 8f1aa99c74..d796f0e8d9 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -78,7 +78,13 @@ def setup_agent( memory = ChatMemory(human=memory_human_str, persona=memory_persona_str) agent_state = client.create_agent( - name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools, + name=agent_uuid, + llm_config=llm_config, + embedding_config=embedding_config, + memory=memory, + tool_ids=tool_ids, + tool_rules=tool_rules, + include_base_tools=include_base_tools, ) return agent_state @@ -105,12 +111,13 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet agent_state = setup_agent(client, filename) full_agent_state = client.get_agent(agent_state.id) + messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state, actor=client.user) agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) response = create( llm_config=agent_state.llm_config, user_id=str(uuid.UUID(int=1)), # dummy user_id - messages=agent._messages, + messages=messages, functions=[t.json_schema for t in agent.agent_state.tools], ) diff --git a/tests/test_managers.py b/tests/test_managers.py index a7b385ea02..388d477c60 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -519,14 +519,14 @@ def test_create_agent_passed_in_initial_messages(server: SyncServer, default_use create_agent_request, actor=default_user, ) - assert server.agent_manager.get_total_message_count(agent_state.id, default_user) == 2 + assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 2 init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) # Check that the system appears in the first initial message assert create_agent_request.system in init_messages[0].text assert create_agent_request.memory_blocks[0].value in init_messages[0].text # Check that the second message is the passed in initial message seq assert create_agent_request.initial_message_sequence[0].role == init_messages[1].role - assert create_agent_request.initial_message_sequence[0].text == init_messages[1].text + assert create_agent_request.initial_message_sequence[0].text in init_messages[1].text def test_create_agent_default_initial_message(server: SyncServer, default_user, default_block): @@ -544,7 +544,7 @@ def test_create_agent_default_initial_message(server: SyncServer, default_user, create_agent_request, actor=default_user, ) - assert server.agent_manager.get_total_message_count(agent_state.id, default_user) == 4 + assert server.message_manager.size(agent_id=agent_state.id, actor=default_user) == 4 init_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) # Check that the system appears in the first initial message assert create_agent_request.system in init_messages[0].text diff --git a/tests/test_server.py b/tests/test_server.py index 2003c68880..b4631293e8 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -507,76 +507,9 @@ def test_get_archival_memory(server, user_id, agent_id): assert len(passage_none) == 0 -def test_agent_rethink_rewrite_retry(server, user_id, agent_id): - """Test the /rethink, /rewrite, and /retry commands in the CLI - - - "rethink" replaces the inner thoughts of the last assistant message - - "rewrite" replaces the text of the last assistant message - - "retry" retries the last assistant message - """ - actor = server.user_manager.get_user_or_default(user_id) - - # Send an initial message - server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?") - - # Grab the raw Agent object - letta_agent = server.load_agent(agent_id=agent_id, actor=actor) - assert letta_agent._messages[-1].role == MessageRole.tool - assert letta_agent._messages[-2].role == MessageRole.assistant - last_agent_message = letta_agent._messages[-2] - - # Try "rethink" - new_thought = "I am thinking about the meaning of life, the universe, and everything. Bananas?" - assert last_agent_message.text is not None and last_agent_message.text != new_thought - server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought, actor=actor) - - # Grab the agent object again (make sure it's live) - letta_agent = server.load_agent(agent_id=agent_id, actor=actor) - assert letta_agent._messages[-1].role == MessageRole.tool - assert letta_agent._messages[-2].role == MessageRole.assistant - last_agent_message = letta_agent._messages[-2] - assert last_agent_message.text == new_thought - - # Try "rewrite" - assert last_agent_message.tool_calls is not None - assert last_agent_message.tool_calls[0].function.name == "send_message" - assert last_agent_message.tool_calls[0].function.arguments is not None - args_json = json.loads(last_agent_message.tool_calls[0].function.arguments) - assert "message" in args_json and args_json["message"] is not None and args_json["message"] != "" - - new_text = "Why hello there my good friend! Is 42 what you're looking for? Bananas?" - server.rewrite_agent_message(agent_id=agent_id, new_text=new_text, actor=actor) - - # Grab the agent object again (make sure it's live) - letta_agent = server.load_agent(agent_id=agent_id, actor=actor) - assert letta_agent._messages[-1].role == MessageRole.tool - assert letta_agent._messages[-2].role == MessageRole.assistant - last_agent_message = letta_agent._messages[-2] - args_json = json.loads(last_agent_message.tool_calls[0].function.arguments) - assert "message" in args_json and args_json["message"] is not None and args_json["message"] == new_text - - # Try retry - server.retry_agent_message(agent_id=agent_id, actor=actor) - - # Grab the agent object again (make sure it's live) - letta_agent = server.load_agent(agent_id=agent_id, actor=actor) - assert letta_agent._messages[-1].role == MessageRole.tool - assert letta_agent._messages[-2].role == MessageRole.assistant - last_agent_message = letta_agent._messages[-2] - - # Make sure the inner thoughts changed - assert last_agent_message.text is not None and last_agent_message.text != new_thought - - # Make sure the message changed - args_json = json.loads(last_agent_message.tool_calls[0].function.arguments) - print(args_json) - assert "message" in args_json and args_json["message"] is not None and args_json["message"] != new_text - - def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str): """Test that the context window overview fetch works""" - - overview = server.get_agent_context_window(user_id=user_id, agent_id=agent_id) + overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id)) assert overview is not None # Run some basic checks @@ -1142,10 +1075,10 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to # Add all the base tools request.tool_ids = [b.id for b in base_tools] - agent_state = server.update_agent(agent_state.id, request=request, actor=actor) + agent_state = server.agent_manager.update_agent(agent_state.id, request=request, actor=actor) assert len(agent_state.tools) == len(base_tools) # Remove one base tool request.tool_ids = [b.id for b in base_tools[:-2]] - agent_state = server.update_agent(agent_state.id, request=request, actor=actor) + agent_state = server.agent_manager.update_agent(agent_state.id, request=request, actor=actor) assert len(agent_state.tools) == len(base_tools) - 2 From 16265f119c802a2a84ede2b6ff2c527fbfe01409 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 12:04:22 -0800 Subject: [PATCH 04/11] Fix server test --- tests/test_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index b4631293e8..fcd4566022 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1075,10 +1075,10 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to # Add all the base tools request.tool_ids = [b.id for b in base_tools] - agent_state = server.agent_manager.update_agent(agent_state.id, request=request, actor=actor) + agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor) assert len(agent_state.tools) == len(base_tools) # Remove one base tool request.tool_ids = [b.id for b in base_tools[:-2]] - agent_state = server.agent_manager.update_agent(agent_state.id, request=request, actor=actor) + agent_state = server.agent_manager.update_agent(agent_state.id, agent_update=request, actor=actor) assert len(agent_state.tools) == len(base_tools) - 2 From 0a3e3e74b44766652c7b783bf4a8427c6968b3bf Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 12:16:12 -0800 Subject: [PATCH 05/11] Fix some local client tests --- letta/client/client.py | 4 ++-- tests/helpers/endpoints_helper.py | 2 +- tests/test_client_legacy.py | 17 ++++++----------- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/letta/client/client.py b/letta/client/client.py index e00701851a..e575979dc0 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2291,7 +2291,7 @@ def remove_tool_from_agent(self, agent_id: str, tool_id: str): agent_state (AgentState): State of the updated agent """ self.interface.clear() - agent_state = self.server.agent_manager.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, actor=self.user) + agent_state = self.server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user) return agent_state def rename_agent(self, agent_id: str, new_name: str): @@ -2426,7 +2426,7 @@ def get_in_context_messages(self, agent_id: str) -> List[Message]: Returns: messages (List[Message]): List of in-context messages """ - return self.server.get_in_context_messages(agent_id=agent_id, actor=self.user) + return self.server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=self.user) # agent interactions diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 5b0a30980f..87997aaf7a 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -107,7 +107,7 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet agent_state = setup_agent(client, filename) full_agent_state = client.get_agent(agent_state.id) - messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state, actor=client.user) + messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user) agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) response = create( diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 16dc1cd6b4..3d907fa373 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -18,20 +18,22 @@ from letta.schemas.enums import MessageRole, MessageStreamStatus from letta.schemas.letta_message import ( AssistantMessage, - ToolCallMessage, - ToolReturnMessage, - ReasoningMessage, LettaMessage, + ReasoningMessage, SystemMessage, + ToolCallMessage, + ToolReturnMessage, UserMessage, ) from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.message import MessageCreate from letta.schemas.usage import LettaUsageStatistics +from letta.services.helpers.agent_manager_helper import initialize_message_sequence from letta.services.organization_manager import OrganizationManager from letta.services.user_manager import UserManager from letta.settings import model_settings +from letta.utils import get_utc_time from tests.helpers.client_helper import upload_file_using_client # from tests.utils import create_config @@ -602,18 +604,11 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: If we pass in a non-empty list, we should get that sequence If we pass in an empty list, we should get an empty sequence """ - from letta.agent import initialize_message_sequence - from letta.utils import get_utc_time - # The reference initial message sequence: reference_init_messages = initialize_message_sequence( - model=agent.llm_config.model, - system=agent.system, - agent_id=agent.id, - memory=agent.memory, + agent_state=agent, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True, - actor=default_user, ) # system, login message, send_message test, send_message receipt From a80f74b644d299f87aee5eb971863f5e26dc26be Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 13:03:50 -0800 Subject: [PATCH 06/11] wip fixing recall memory --- tests/test_local_client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_local_client.py b/tests/test_local_client.py index ea5d04e063..da5e533c5a 100644 --- a/tests/test_local_client.py +++ b/tests/test_local_client.py @@ -259,10 +259,10 @@ def test_recall_memory(client: LocalClient, agent: AgentState): assert exists # get in-context messages - messages = client.get_in_context_messages(agent.id) + in_context_messages = client.get_in_context_messages(agent.id) exists = False - for m in messages: - if message_str in str(m): + for m in in_context_messages: + if message_str in m.text: exists = True assert exists From 3556fd19cc519ff10f0e28a17be82b6a9b2a031e Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 14:22:35 -0800 Subject: [PATCH 07/11] Fix out of order error --- letta/agent.py | 13 ++++++++----- letta/server/server.py | 2 +- letta/services/message_manager.py | 11 +++++++++-- tests/test_summarize.py | 6 +++--- 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index f326d826a4..e5beb3a353 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -205,7 +205,7 @@ def update_memory_if_change(self, new_memory: Memory) -> bool: # NOTE: don't do this since re-buildin the memory is handled at the start of the step # rebuild memory - this records the last edited timestamp of the memory # TODO: pass in update timestamp from block edit time - self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user) + self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user) return True return False @@ -565,7 +565,7 @@ def _handle_ai_response( # rebuild memory # TODO: @charles please check this - self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user) + self.agent_state = self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user) # Update ToolRulesSolver state with last called function self.tool_rules_solver.update_tool_usage(function_name) @@ -597,6 +597,7 @@ def step( messages=next_input_message, **kwargs, ) + heartbeat_request = step_response.heartbeat_request function_failed = step_response.function_failed token_warning = step_response.in_context_memory_warning @@ -748,7 +749,9 @@ def inner_step( f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.agent_state.llm_config.context_window)}" ) - self.agent_manager.append_to_in_context_messages(all_new_messages, agent_id=self.agent_state.id, actor=self.user) + self.agent_state = self.agent_manager.append_to_in_context_messages( + all_new_messages, agent_id=self.agent_state.id, actor=self.user + ) return AgentStepResponse( messages=all_new_messages, @@ -917,9 +920,9 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, printd(f"Packaged into message: {summary_message}") prior_len = len(in_context_messages_openai) - self.agent_manager.trim_older_in_context_messages(cutoff, agent_id=self.agent_state.id, actor=self.user) + self.agent_state = self.agent_manager.trim_older_in_context_messages(cutoff, agent_id=self.agent_state.id, actor=self.user) packed_summary_message = {"role": "user", "content": summary_message} - self.agent_manager.prepend_to_in_context_messages( + self.agent_state = self.agent_manager.prepend_to_in_context_messages( messages=[ Message.dict_to_message( agent_id=self.agent_state.id, diff --git a/letta/server/server.py b/letta/server/server.py index 7dd709fedf..9d5dc28dbb 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -432,7 +432,7 @@ def _step( ) # save agent after step - save_agent(letta_agent) + # save_agent(letta_agent) except Exception as e: logger.error(f"Error in server._step: {e}") diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 7eedf95f09..c369aa2c75 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -30,11 +30,18 @@ def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[Py @enforce_types def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: - """Fetch a message by ID.""" + """Fetch messages by ID and return them in the requested order.""" with self.session_maker() as session: results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id) - return [msg.to_pydantic() for msg in results] + if len(results) != len(message_ids): + raise NoResultFound( + f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}" + ) + + # Sort results directly based on message_ids + result_dict = {msg.id: msg.to_pydantic() for msg in results} + return [result_dict[msg_id] for msg_id in message_ids] @enforce_types def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage: diff --git a/tests/test_summarize.py b/tests/test_summarize.py index 8996841347..d798ff866c 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -121,9 +121,9 @@ def summarize_message_exists(messages: List[Message]) -> bool: # check if the summarize message is inside the messages assert isinstance(client, LocalClient), "Test only works with LocalClient" - agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - print("SUMMARY", summarize_message_exists(agent_obj._messages)) - if summarize_message_exists(agent_obj._messages): + in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=client.user) + print("SUMMARY", summarize_message_exists(in_context_messages)) + if summarize_message_exists(in_context_messages): break if message_count > MAX_ATTEMPTS: From 0b5bebcf4db66d2d421a33baf4afff80ba60b56a Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 14:39:31 -0800 Subject: [PATCH 08/11] Fix summarizer tests --- .github/workflows/tests.yml | 1 - letta/services/agent_manager.py | 2 +- tests/integration_test_summarizer.py | 110 +++++++++++++++++++++- tests/test_summarize.py | 133 --------------------------- 4 files changed, 110 insertions(+), 136 deletions(-) delete mode 100644 tests/test_summarize.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 43e66727ab..e4c46c5e85 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,6 @@ jobs: - "test_memory.py" - "test_utils.py" - "test_stream_buffer_readers.py" - - "test_summarize.py" services: qdrant: image: qdrant/qdrant diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index fd0ac5bf31..711dbcbaa6 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -374,7 +374,7 @@ def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message - return self.set_in_context_messages(agent_id=agent_id, message_ids=[m.id for m in new_messages], actor=actor) + return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) @enforce_types def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 9131797cc5..b4de0043b4 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -1,13 +1,16 @@ import json import os import uuid +from typing import List import pytest from letta import create_client from letta.agent import Agent +from letta.client.client import LocalClient from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message from letta.streaming_interface import StreamingRefreshCLIInterface from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH from tests.helpers.utils import cleanup @@ -16,6 +19,110 @@ LLM_CONFIG_DIR = "tests/configs/llm_model_configs" SUMMARY_KEY_PHRASE = "The following is a summary" +test_agent_name = f"test_client_{str(uuid.uuid4())}" + +# TODO: these tests should include looping through LLM providers, since behavior may vary across providers +# TODO: these tests should add function calls into the summarized message sequence:W + + +@pytest.fixture(scope="module") +def client(): + client = create_client() + # client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) + client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) + client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) + + yield client + + +@pytest.fixture(scope="module") +def agent_state(client): + # Generate uuid for agent name for this example + agent_state = client.create_agent(name=test_agent_name) + yield agent_state + + client.delete_agent(agent_state.id) + + +def test_summarize_messages_inplace(client, agent_state, mock_e2b_api_key_none): + """Test summarization via sending the summarize CLI command or via a direct call to the agent object""" + # First send a few messages (5) + response = client.user_message( + agent_id=agent_state.id, + message="Hey, how's it going? What do you think about this whole shindig", + ).messages + assert response is not None and len(response) > 0 + print(f"test_summarize: response={response}") + + response = client.user_message( + agent_id=agent_state.id, + message="Any thoughts on the meaning of life?", + ).messages + assert response is not None and len(response) > 0 + print(f"test_summarize: response={response}") + + response = client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?").messages + assert response is not None and len(response) > 0 + print(f"test_summarize: response={response}") + + response = client.user_message( + agent_id=agent_state.id, + message="Would you be surprised to learn that you're actually conversing with an AI right now?", + ).messages + assert response is not None and len(response) > 0 + print(f"test_summarize: response={response}") + + # reload agent object + agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) + + agent_obj.summarize_messages_inplace() + + +def test_auto_summarize(client, mock_e2b_api_key_none): + """Test that the summarizer triggers by itself""" + small_context_llm_config = LLMConfig.default_config("gpt-4o-mini") + small_context_llm_config.context_window = 4000 + + small_agent_state = client.create_agent( + name="small_context_agent", + llm_config=small_context_llm_config, + ) + + try: + + def summarize_message_exists(messages: List[Message]) -> bool: + for message in messages: + if message.text and "The following is a summary of the previous" in message.text: + print(f"Summarize message found after {message_count} messages: \n {message.text}") + return True + return False + + MAX_ATTEMPTS = 10 + message_count = 0 + while True: + + # send a message + response = client.user_message( + agent_id=small_agent_state.id, + message="What is the meaning of life?", + ) + message_count += 1 + + print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------") + + # check if the summarize message is inside the messages + assert isinstance(client, LocalClient), "Test only works with LocalClient" + in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=small_agent_state.id, actor=client.user) + print("SUMMARY", summarize_message_exists(in_context_messages)) + if summarize_message_exists(in_context_messages): + break + + if message_count > MAX_ATTEMPTS: + raise Exception(f"Summarize message not found after {message_count} messages") + + finally: + client.delete_agent(small_agent_state.id) + @pytest.mark.parametrize( "config_filename", @@ -69,4 +176,5 @@ def test_summarizer(config_filename): # Invoke a summarize letta_agent.summarize_messages_inplace(preserve_last_N_messages=False) - assert SUMMARY_KEY_PHRASE in letta_agent.messages[1]["content"], f"Test failed for config: {config_filename}" + in_context_messages = client.get_in_context_messages(agent_state.id) + assert SUMMARY_KEY_PHRASE in in_context_messages[1].text, f"Test failed for config: {config_filename}" diff --git a/tests/test_summarize.py b/tests/test_summarize.py deleted file mode 100644 index d798ff866c..0000000000 --- a/tests/test_summarize.py +++ /dev/null @@ -1,133 +0,0 @@ -import uuid -from typing import List - -from letta import create_client -from letta.client.client import LocalClient -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message - -from .utils import wipe_config - -# test_agent_id = "test_agent" -test_agent_name = f"test_client_{str(uuid.uuid4())}" -client = None -agent_obj = None - -# TODO: these tests should include looping through LLM providers, since behavior may vary across providers -# TODO: these tests should add function calls into the summarized message sequence:W - - -def create_test_agent(): - """Create a test agent that we can call functions on""" - wipe_config() - - global client - client = create_client() - - client.set_default_llm_config(LLMConfig.default_config("gpt-4")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - agent_state = client.create_agent( - name=test_agent_name, - ) - - global agent_obj - agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - - -def test_summarize_messages_inplace(mock_e2b_api_key_none): - """Test summarization via sending the summarize CLI command or via a direct call to the agent object""" - global client - global agent_obj - - if agent_obj is None: - create_test_agent() - - assert agent_obj is not None, "Run create_agent test first" - assert client is not None, "Run create_agent test first" - - # First send a few messages (5) - response = client.user_message( - agent_id=agent_obj.agent_state.id, - message="Hey, how's it going? What do you think about this whole shindig", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message( - agent_id=agent_obj.agent_state.id, - message="Any thoughts on the meaning of life?", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message(agent_id=agent_obj.agent_state.id, message="Does the number 42 ring a bell?").messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message( - agent_id=agent_obj.agent_state.id, - message="Would you be surprised to learn that you're actually conversing with an AI right now?", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - # reload agent object - agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id, actor=client.user) - - agent_obj.summarize_messages_inplace() - print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}") - # response = client.run_command(agent_id=agent_obj.agent_state.id, command="summarize") - - -def test_auto_summarize(mock_e2b_api_key_none): - """Test that the summarizer triggers by itself""" - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - small_context_llm_config = LLMConfig.default_config("gpt-4") - # default system prompt + funcs lead to ~2300 tokens, after one message it's at 2523 tokens - SMALL_CONTEXT_WINDOW = 4000 - small_context_llm_config.context_window = SMALL_CONTEXT_WINDOW - - agent_state = client.create_agent( - name="small_context_agent", - llm_config=small_context_llm_config, - ) - - try: - - def summarize_message_exists(messages: List[Message]) -> bool: - for message in messages: - if message.text and "The following is a summary of the previous" in message.text: - print(f"Summarize message found after {message_count} messages: \n {message.text}") - return True - return False - - MAX_ATTEMPTS = 5 - message_count = 0 - while True: - - # send a message - response = client.user_message( - agent_id=agent_state.id, - message="What is the meaning of life?", - ) - message_count += 1 - - print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------") - - # check if the summarize message is inside the messages - assert isinstance(client, LocalClient), "Test only works with LocalClient" - in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=client.user) - print("SUMMARY", summarize_message_exists(in_context_messages)) - if summarize_message_exists(in_context_messages): - break - - if message_count > MAX_ATTEMPTS: - raise Exception(f"Summarize message not found after {message_count} messages") - - finally: - client.delete_agent(agent_state.id) From 4a20ce4d8840b0f3fd875a99b0acfe37e9d0cdd5 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 14:50:21 -0800 Subject: [PATCH 09/11] wip fix offline --- letta/offline_memory_agent.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/letta/offline_memory_agent.py b/letta/offline_memory_agent.py index f4eeec8a83..076e2dc07b 100644 --- a/letta/offline_memory_agent.py +++ b/letta/offline_memory_agent.py @@ -129,9 +129,8 @@ def __init__( # extras first_message_verify_mono: bool = False, max_memory_rethinks: int = 10, - initial_message_sequence: Optional[List[Message]] = None, ): - super().__init__(interface, agent_state, user, initial_message_sequence=initial_message_sequence) + super().__init__(interface, agent_state, user) self.first_message_verify_mono = first_message_verify_mono self.max_memory_rethinks = max_memory_rethinks From ea886172e4427ad0f55735cd0112d1fc64315dfb Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 15:02:58 -0800 Subject: [PATCH 10/11] Fix offline tests --- letta/chat_only_agent.py | 4 ++-- letta/services/agent_manager.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/letta/chat_only_agent.py b/letta/chat_only_agent.py index e340673eba..e5f431c550 100644 --- a/letta/chat_only_agent.py +++ b/letta/chat_only_agent.py @@ -63,8 +63,8 @@ def generate_offline_memory_agent(): conversation_persona_block_new = Block( name="chat_agent_persona_new", label="chat_agent_persona_new", value=conversation_persona_block.value, limit=2000 ) - - recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit :] + in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) + recent_convo = "".join([str(message) for message in in_context_messages[3:]])[-self.recent_convo_limit :] conversation_messages_block = Block( name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit ) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 711dbcbaa6..1ed4750eaa 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -331,7 +331,9 @@ def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, curr_memory_str = agent_state.memory.compile() if curr_memory_str in curr_system_message_openai["content"] and not force: # NOTE: could this cause issues if a block is removed? (substring match would still work) - logger.info(f"Memory hasn't changed, skipping system prompt rebuild") + logger.info( + f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild" + ) return agent_state # If the memory didn't update, we probably don't want to update the timestamp inside From d8c9d3bb990f9cfe22d7a0ee05c454ab849aa340 Mon Sep 17 00:00:00 2001 From: Matt Zhou Date: Fri, 20 Dec 2024 15:12:54 -0800 Subject: [PATCH 11/11] Fix tool rules validation --- letta/agent.py | 18 +++++------------- letta/services/agent_manager.py | 5 +++++ letta/services/helpers/agent_manager_helper.py | 13 ++++++++++++- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index e5beb3a353..4f636938a6 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -17,7 +17,6 @@ MESSAGE_SUMMARY_WARNING_FRAC, O1_BASE_TOOLS, REQ_HEARTBEAT_MESSAGE, - STRUCTURED_OUTPUT_MODELS, ) from letta.errors import ContextWindowExceededError from letta.helpers import ToolRulesSolver @@ -47,7 +46,10 @@ from letta.schemas.user import User as PydanticUser from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager -from letta.services.helpers.agent_manager_helper import compile_memory_metadata_block +from letta.services.helpers.agent_manager_helper import ( + check_supports_structured_output, + compile_memory_metadata_block, +) from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager @@ -121,7 +123,7 @@ def __init__( # gpt-4, gpt-3.5-turbo, ... self.model = self.agent_state.llm_config.model - self.check_tool_rules() + self.supports_structured_output = check_supports_structured_output(model=self.model, tool_rules=agent_state.tool_rules) # state managers self.block_manager = BlockManager() @@ -152,16 +154,6 @@ def __init__( # Load last function response from message history self.last_function_response = self.load_last_function_response() - def check_tool_rules(self): - if self.model not in STRUCTURED_OUTPUT_MODELS: - if len(self.tool_rules_solver.init_tool_rules) > 1: - raise ValueError( - "Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule." - ) - self.supports_structured_output = False - else: - self.supports_structured_output = True - def load_last_function_response(self): """Load the last function response from message history""" in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 1ed4750eaa..4e6b80ec5b 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -29,6 +29,7 @@ from letta.services.helpers.agent_manager_helper import ( _process_relationship, _process_tags, + check_supports_structured_output, compile_system_message, derive_system_message, initialize_message_sequence, @@ -70,6 +71,10 @@ def create_agent( if not agent_create.llm_config or not agent_create.embedding_config: raise ValueError("llm_config and embedding_config are required") + # Check tool rules are valid + if agent_create.tool_rules: + check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules) + # create blocks (note: cannot be linked into the agent_id is created) block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original for create_block in agent_create.memory_blocks: diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 375dea69f8..2d7ac2805e 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -2,7 +2,8 @@ from typing import List, Literal, Optional from letta import system -from letta.constants import IN_CONTEXT_MEMORY_KEYWORD +from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS +from letta.helpers import ToolRulesSolver from letta.orm.agent import Agent as AgentModel from letta.orm.agents_tags import AgentsTags from letta.orm.errors import NoResultFound @@ -11,6 +12,7 @@ from letta.schemas.enums import MessageRole from letta.schemas.memory import Memory from letta.schemas.message import Message, MessageCreate +from letta.schemas.tool_rule import ToolRule from letta.schemas.user import User from letta.system import get_initial_boot_messages, get_login_event from letta.utils import get_local_time @@ -247,3 +249,12 @@ def package_initial_message_sequence( Message(role=message_create.role, text=packed_message, organization_id=actor.organization_id, agent_id=agent_id, model=model) ) return init_messages + + +def check_supports_structured_output(model: str, tool_rules: List[ToolRule]) -> bool: + if model not in STRUCTURED_OUTPUT_MODELS: + if len(ToolRulesSolver(tool_rules=tool_rules).init_tool_rules) > 1: + raise ValueError("Multiple initial tools are not supported for non-structured models. Please use only one initial tool rule.") + return False + else: + return True