Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Remove in-memory _messages field on Agent #2295

Merged
merged 12 commits into from
Dec 20, 2024
1 change: 0 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ jobs:
- "test_memory.py"
- "test_utils.py"
- "test_stream_buffer_readers.py"
- "test_summarize.py"
services:
qdrant:
image: qdrant/qdrant
Expand Down
708 changes: 73 additions & 635 deletions letta/agent.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions letta/chat_only_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def generate_offline_memory_agent():
conversation_persona_block_new = Block(
name="chat_agent_persona_new", label="chat_agent_persona_new", value=conversation_persona_block.value, limit=2000
)

recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit :]
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
recent_convo = "".join([str(message) for message in in_context_messages[3:]])[-self.recent_convo_limit :]
conversation_messages_block = Block(
name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit
)
Expand Down
12 changes: 6 additions & 6 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2234,7 +2234,7 @@ def update_agent(
"""
# TODO: add the abilitty to reset linked block_ids
self.interface.clear()
agent_state = self.server.update_agent(
agent_state = self.server.agent_manager.update_agent(
agent_id,
UpdateAgent(
name=name,
Expand Down Expand Up @@ -2262,7 +2262,7 @@ def get_tools_from_agent(self, agent_id: str) -> List[Tool]:
List[Tool]: A list of Tool objs
"""
self.interface.clear()
return self.server.get_tools_from_agent(agent_id=agent_id, user_id=self.user_id)
return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user).tools

def add_tool_to_agent(self, agent_id: str, tool_id: str):
"""
Expand All @@ -2276,7 +2276,7 @@ def add_tool_to_agent(self, agent_id: str, tool_id: str):
agent_state (AgentState): State of the updated agent
"""
self.interface.clear()
agent_state = self.server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
agent_state = self.server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
return agent_state

def remove_tool_from_agent(self, agent_id: str, tool_id: str):
Expand All @@ -2291,7 +2291,7 @@ def remove_tool_from_agent(self, agent_id: str, tool_id: str):
agent_state (AgentState): State of the updated agent
"""
self.interface.clear()
agent_state = self.server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=self.user_id)
agent_state = self.server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user)
return agent_state

def rename_agent(self, agent_id: str, new_name: str):
Expand Down Expand Up @@ -2426,7 +2426,7 @@ def get_in_context_messages(self, agent_id: str) -> List[Message]:
Returns:
messages (List[Message]): List of in-context messages
"""
return self.server.get_in_context_messages(agent_id=agent_id, actor=self.user)
return self.server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=self.user)

# agent interactions

Expand Down Expand Up @@ -3075,7 +3075,7 @@ def delete_archival_memory(self, agent_id: str, memory_id: str):
agent_id (str): ID of the agent
memory_id (str): ID of the memory
"""
self.server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=self.user)
self.server.delete_archival_memory(memory_id=memory_id, actor=self.user)

def get_archival_memory(
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
Expand Down
76 changes: 0 additions & 76 deletions letta/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,46 +194,6 @@ def run_agent_loop(
print(f"Current model: {letta_agent.agent_state.llm_config.model}")
continue

elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "):
# Check if there's an additional argument that's an integer
command = user_input.strip().split()
pop_amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 3
try:
popped_messages = letta_agent.pop_message(count=pop_amount)
except ValueError as e:
print(f"Error popping messages: {e}")
continue

elif user_input.lower() == "/retry":
print(f"Retrying for another answer...")
try:
letta_agent.retry_message()
except Exception as e:
print(f"Error retrying message: {e}")
continue

elif user_input.lower() == "/rethink" or user_input.lower().startswith("/rethink "):
if len(user_input) < len("/rethink "):
print("Missing text after the command")
continue
try:
letta_agent.rethink_message(new_thought=user_input[len("/rethink ") :].strip())
except Exception as e:
print(f"Error rethinking message: {e}")
continue

elif user_input.lower() == "/rewrite" or user_input.lower().startswith("/rewrite "):
if len(user_input) < len("/rewrite "):
print("Missing text after the command")
continue

text = user_input[len("/rewrite ") :].strip()
try:
letta_agent.rewrite_message(new_text=text)
except Exception as e:
print(f"Error rewriting message: {e}")
continue

elif user_input.lower() == "/summarize":
try:
letta_agent.summarize_messages_inplace()
Expand Down Expand Up @@ -319,42 +279,6 @@ def run_agent_loop(
questionary.print(cmd, "bold")
questionary.print(f" {desc}")
continue

elif user_input.lower().startswith("/systemswap"):
if len(user_input) < len("/systemswap "):
print("Missing new system prompt after the command")
continue
old_system_prompt = letta_agent.system
new_system_prompt = user_input[len("/systemswap ") :].strip()

# Show warning and prompts to user
typer.secho(
"\nWARNING: You are about to change the system prompt.",
# fg=typer.colors.BRIGHT_YELLOW,
bold=True,
)
typer.secho(
f"\nOld system prompt:\n{old_system_prompt}",
fg=typer.colors.RED,
bold=True,
)
typer.secho(
f"\nNew system prompt:\n{new_system_prompt}",
fg=typer.colors.GREEN,
bold=True,
)

# Ask for confirmation
confirm = questionary.confirm("Do you want to proceed with the swap?").ask()

if confirm:
letta_agent.update_system_prompt(new_system_prompt=new_system_prompt)
print("System prompt updated successfully.")
else:
print("System prompt swap cancelled.")

continue

else:
print(f"Unrecognized command: {user_input}")
continue
Expand Down
3 changes: 1 addition & 2 deletions letta/offline_memory_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,8 @@ def __init__(
# extras
first_message_verify_mono: bool = False,
max_memory_rethinks: int = 10,
initial_message_sequence: Optional[List[Message]] = None,
):
super().__init__(interface, agent_state, user, initial_message_sequence=initial_message_sequence)
super().__init__(interface, agent_state, user)
self.first_message_verify_mono = first_message_verify_mono
self.max_memory_rethinks = max_memory_rethinks

Expand Down
17 changes: 9 additions & 8 deletions letta/server/rest_api/routers/v1/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_agent_context_window(
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)

return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
return server.get_agent_context_window(agent_id=agent_id, actor=actor)


class CreateAgentRequest(CreateAgent):
Expand Down Expand Up @@ -138,7 +138,7 @@ def update_agent(
):
"""Update an exsiting agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.update_agent(agent_id, update_agent, actor=actor)
return server.agent_manager.update_agent(agent_id=agent_id, agent_update=update_agent, actor=actor)


@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent")
Expand All @@ -149,7 +149,7 @@ def get_tools_from_agent(
):
"""Get tools from an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id)
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).tools


@router.patch("/{agent_id}/add-tool/{tool_id}", response_model=AgentState, operation_id="add_tool_to_agent")
Expand All @@ -161,7 +161,7 @@ def add_tool_to_agent(
):
"""Add tools to an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, user_id=actor)


@router.patch("/{agent_id}/remove-tool/{tool_id}", response_model=AgentState, operation_id="remove_tool_from_agent")
Expand All @@ -173,7 +173,7 @@ def remove_tool_from_agent(
):
"""Add tools to an existing agent"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)


@router.get("/{agent_id}", response_model=AgentState, operation_id="get_agent")
Expand Down Expand Up @@ -232,7 +232,7 @@ def get_agent_in_context_messages(
Retrieve the messages in the context of a specific agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_in_context_messages(agent_id=agent_id, actor=actor)
return server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)


# TODO: remove? can also get with agent blocks
Expand Down Expand Up @@ -429,7 +429,7 @@ def delete_agent_archival_memory(
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)

server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=actor)
server.delete_archival_memory(memory_id=memory_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})


Expand Down Expand Up @@ -479,8 +479,9 @@ def update_message(
"""
Update the details of a message associated with an agent.
"""
# TODO: Get rid of agent_id here, it's not really relevant
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request, actor=actor)
return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)


@router.post(
Expand Down
Loading
Loading