Skip to content

Commit

Permalink
Add cascading deletes
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 committed Dec 21, 2024
1 parent c8352b5 commit 6bb64af
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 68 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Add cascading deletes for sources to agents
Revision ID: e78b4e82db30
Revises: d6632deac81d
Create Date: 2024-12-20 16:30:17.095888
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "e78b4e82db30"
down_revision: Union[str, None] = "d6632deac81d"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Recreate both ``sources_agents`` foreign keys with ``ON DELETE CASCADE``.

    The existing constraints are dropped and re-added so that deleting a row in
    ``sources`` or ``agents`` also removes the matching ``sources_agents`` rows.
    """
    # ### commands auto generated by Alembic - adjusted to use explicit constraint names ###
    op.drop_constraint("sources_agents_agent_id_fkey", "sources_agents", type_="foreignkey")
    op.drop_constraint("sources_agents_source_id_fkey", "sources_agents", type_="foreignkey")
    # Name the new constraints explicitly (reusing the names just dropped) instead of
    # passing None: a constraint can only be dropped later if its name is known, and
    # an unnamed constraint's name is left to the database backend.
    op.create_foreign_key(
        "sources_agents_source_id_fkey",
        "sources_agents",
        "sources",
        ["source_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "sources_agents_agent_id_fkey",
        "sources_agents",
        "agents",
        ["agent_id"],
        ["id"],
        ondelete="CASCADE",
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    """Restore the ``sources_agents`` foreign keys without ``ON DELETE CASCADE``.

    Drops the two foreign keys on ``sources_agents`` and recreates them under the
    same names with no ``ondelete`` rule.
    """
    # ### commands auto generated by Alembic - adjusted to use explicit constraint names ###
    # drop_constraint needs a real constraint name; None cannot be rendered into
    # ALTER TABLE ... DROP CONSTRAINT. The names below are the ones the
    # constraints carried before this revision.
    # NOTE(review): they also match PostgreSQL's default "<table>_<column>_fkey"
    # naming for unnamed foreign keys - confirm if another backend or a custom
    # naming convention is in use.
    op.drop_constraint("sources_agents_agent_id_fkey", "sources_agents", type_="foreignkey")
    op.drop_constraint("sources_agents_source_id_fkey", "sources_agents", type_="foreignkey")
    op.create_foreign_key("sources_agents_source_id_fkey", "sources_agents", "sources", ["source_id"], ["id"])
    op.create_foreign_key("sources_agents_agent_id_fkey", "sources_agents", "agents", ["agent_id"], ["id"])
    # ### end Alembic commands ###
28 changes: 0 additions & 28 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
Expand All @@ -52,7 +51,6 @@
)
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import (
Expand Down Expand Up @@ -947,32 +945,6 @@ def migrate_embedding(self, embedding_config: EmbeddingConfig):
# TODO: recall memory
raise NotImplementedError()

def attach_source(
    self,
    user: PydanticUser,
    source_id: str,
    source_manager: SourceManager,
    agent_manager: AgentManager,
):
    """Link an existing source to this agent on behalf of ``user``.

    Args:
        user: User performing the action; passed as the actor to both managers
        source_id: ID of the source to attach
        source_manager: SourceManager used to look the source up
        agent_manager: AgentManager used to create the agent-source relationship

    Raises:
        AssertionError: if the lookup returns ``None`` (only when assertions are enabled)
    """
    # Look the source up as `user`; a missing result fails the assertion below
    found_source = source_manager.get_source_by_id(source_id=source_id, actor=user)
    assert found_source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})"

    # The relationship itself is created by the agent manager
    agent_manager.attach_source(
        agent_id=self.agent_state.id,
        source_id=source_id,
        actor=user,
    )

    attached_message = f"Attached data source {found_source.name} to agent {self.agent_state.name}."
    printd(attached_message)

def get_context_window(self) -> ContextWindowOverview:
"""Get the context window of the agent"""

Expand Down
13 changes: 10 additions & 3 deletions letta/orm/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from letta.schemas.source import Source as PydanticSource

if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.agent import Agent
from letta.orm.file import FileMetadata
from letta.orm.organization import Organization
from letta.orm.passage import SourcePassage
from letta.orm.agent import Agent


class Source(SqlalchemyBase, OrganizationMixin):
Expand All @@ -32,4 +32,11 @@ class Source(SqlalchemyBase, OrganizationMixin):
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
passages: Mapped[List["SourcePassage"]] = relationship("SourcePassage", back_populates="source", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
agents: Mapped[List["Agent"]] = relationship(
    "Agent",
    secondary="sources_agents",
    back_populates="sources",
    lazy="selectin",
    # Keep the default cascade. On a many-to-many relationship, "all, delete"
    # would also DELETE the related Agent rows when a Source is deleted, not
    # just the sources_agents link rows.
    cascade="save-update, merge",
    # Leave removal of the sources_agents link rows to the database, which
    # relies on the ON DELETE CASCADE foreign keys on that table.
    passive_deletes=True,
)
4 changes: 2 additions & 2 deletions letta/orm/sources_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ class SourcesAgents(Base):

__tablename__ = "sources_agents"

agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id"), primary_key=True)
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id", ondelete="CASCADE"), primary_key=True)
69 changes: 34 additions & 35 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,10 @@ def other_agent_id(server, user_id, base_tools):
server.agent_manager.delete_agent(agent_state.id, actor=actor)


def test_error_on_nonexistent_agent(server, user_id, agent_id):
def test_error_on_nonexistent_agent(server, user, agent_id):
try:
fake_agent_id = str(uuid.uuid4())
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
raise Exception("user_message call should have failed")
except (KeyError, ValueError) as e:
# Error is expected
Expand All @@ -375,17 +375,17 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id):


@pytest.mark.order(1)
def test_user_message_memory(server, user_id, agent_id):
def test_user_message_memory(server, user, agent_id):
try:
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
server.user_message(user_id=user.id, agent_id=agent_id, message="/memory")
raise Exception("user_message call should have failed")
except ValueError as e:
# Error is expected
print(e)
except:
raise

server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
server.run_command(user_id=user.id, agent_id=agent_id, command="/memory")


@pytest.mark.order(3)
Expand Down Expand Up @@ -423,31 +423,30 @@ def test_save_archival_memory(server, user_id, agent_id):


@pytest.mark.order(4)
def test_user_message(server, user_id, agent_id):
def test_user_message(server, user, agent_id):
# add data into recall memory
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")


@pytest.mark.order(5)
def test_get_recall_memory(server, org_id, user_id, agent_id):
def test_get_recall_memory(server, org_id, user, agent_id):
# test recall memory cursor pagination
actor = server.user_manager.get_user_or_default(user_id=user_id)
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
actor = user
messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=2)
cursor1 = messages_1[-1].id
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, after=cursor1, limit=1000)
messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=1000)
messages_3[-1].id
assert messages_3[-1].created_at >= messages_3[0].created_at
assert len(messages_3) == len(messages_1) + len(messages_2)
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, reverse=True, before=cursor1)
assert len(messages_4) == 1

# test in-context message ids
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids

message_ids = [m.id for m in messages_3]
Expand All @@ -456,13 +455,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):


@pytest.mark.order(6)
def test_get_archival_memory(server, user_id, agent_id):
def test_get_archival_memory(server, user, agent_id):
# test archival memory cursor pagination
user = server.user_manager.get_user_by_id(user_id=user_id)
actor = user

# List latest 2 passages
passages_1 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
limit=2,
Expand All @@ -472,7 +471,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
cursor=cursor1,
Expand All @@ -481,7 +480,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
Expand All @@ -494,20 +493,20 @@ def test_get_archival_memory(server, user_id, agent_id):
earliest = passages_2[-1]

# test archival memory
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0


def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
def test_get_context_window_overview(server: SyncServer, user, agent_id):
"""Test that the context window overview fetch works"""
overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id))
overview = server.get_agent_context_window(agent_id=agent_id, actor=user)
assert overview is not None

# Run some basic checks
Expand Down Expand Up @@ -544,15 +543,15 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
)


def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
agent_state = server.create_agent(
request=CreateAgent(
name="nonexistent_tools_agent",
memory_blocks=[],
llm="openai/gpt-4",
embedding="openai/text-embedding-ada-002",
),
actor=server.user_manager.get_user_or_default(user_id),
actor=user,
)

# create another user in the same org
Expand All @@ -564,14 +563,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):

def _test_get_messages_letta_format(
server,
user_id,
user,
agent_id,
reverse=False,
):
"""Test mapping between messages and letta_messages with reverse=False."""

messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
Expand All @@ -580,7 +579,7 @@ def _test_get_messages_letta_format(
assert all(isinstance(m, Message) for m in messages)

letta_messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
Expand Down Expand Up @@ -673,10 +672,10 @@ def _test_get_messages_letta_format(
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")


def test_get_messages_letta_format(server, user_id, agent_id):
def test_get_messages_letta_format(server, user, agent_id):
# for reverse in [False, True]:
for reverse in [False]:
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
_test_get_messages_letta_format(server, user, agent_id, reverse=reverse)


EXAMPLE_TOOL_SOURCE = '''
Expand Down Expand Up @@ -823,9 +822,9 @@ def test_composio_client_simple(server):
assert len(actions) > 0


def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools):
def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)
actor = user
# create agent
agent_state = server.create_agent(
request=CreateAgent(
Expand All @@ -846,7 +845,7 @@ def count_system_messages_in_recall() -> Tuple[int, List[LettaMessage]]:

# At this stage, there should only be 1 system message inside of recall storage
letta_messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_state.id,
limit=1000,
# reverse=reverse,
Expand All @@ -868,7 +867,7 @@ def count_system_messages_in_recall() -> Tuple[int, List[LettaMessage]]:
assert num_system_messages == 1, (num_system_messages, all_messages)

# Assuming core memory append actually ran correctly, at this point there should be 2 messages
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Append 'banana' to your core memory")

# At this stage, there should be 2 system message inside of recall storage
num_system_messages, all_messages = count_system_messages_in_recall()
Expand Down

0 comments on commit 6bb64af

Please sign in to comment.