Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Migrate FileMetadata to ORM #2028

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions alembic/versions/c85a3d07c028_move_files_to_orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Move files to orm

Revision ID: c85a3d07c028
Revises: cda66b6cb0d6
Create Date: 2024-11-12 13:58:57.221081

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "c85a3d07c028"
down_revision: Union[str, None] = "cda66b6cb0d6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Move the ``files`` table onto the ORM-managed schema.

    Adds the bookkeeping columns (``updated_at``, ``is_deleted``,
    ``_created_by_id``, ``_last_updated_by_id``), replaces the ``user_id``
    column with an ``organization_id`` backfilled from ``users``, and adds
    foreign keys from ``files`` to ``organizations`` and ``sources``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("files", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("files", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("files", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("files", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    # Added as nullable first so that existing rows can be backfilled before
    # the NOT NULL constraint is applied below.
    op.add_column("files", sa.Column("organization_id", sa.String(), nullable=True))
    # Populate `organization_id` based on `user_id`
    # Use a raw SQL query to update the organization_id
    op.execute(
        """
        UPDATE files
        SET organization_id = users.organization_id
        FROM users
        WHERE files.user_id = users.id
        """
    )
    # NOTE(review): rows whose user_id matches no row in `users` are not touched
    # by the UPDATE above and keep a NULL organization_id, which makes the
    # NOT NULL change fail - confirm there are no orphaned files before running.
    op.alter_column("files", "organization_id", nullable=False)
    # Name the constraints explicitly instead of passing None. The names are the
    # ones PostgreSQL generates for unnamed single-column foreign keys, so the
    # resulting schema is the same while the constraints stay addressable by
    # name in later migrations.
    op.create_foreign_key("files_organization_id_fkey", "files", "organizations", ["organization_id"], ["id"])
    op.create_foreign_key("files_source_id_fkey", "files", "sources", ["source_id"], ["id"])
    op.drop_column("files", "user_id")
    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert the ``files`` table to its pre-ORM layout.

    Re-adds ``user_id``, drops the two foreign keys added by ``upgrade()``,
    and removes ``organization_id`` together with the ORM bookkeeping columns.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # NOTE(review): the original user_id values were dropped by upgrade() and
    # are not restored here; adding a NOT NULL column without a default is
    # expected to fail when `files` already contains rows - confirm the
    # intended handling for populated tables.
    op.add_column("files", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
    # A constraint cannot be dropped without a name (the autogenerated
    # `drop_constraint(None, ...)` is not renderable). upgrade() creates both
    # foreign keys unnamed, so they carry the backend-generated names; these are
    # the PostgreSQL defaults (`<table>_<column>_fkey`).
    # NOTE(review): assumes PostgreSQL and no MetaData naming_convention that
    # would have produced different names - verify against the live schema.
    op.drop_constraint("files_source_id_fkey", "files", type_="foreignkey")
    op.drop_constraint("files_organization_id_fkey", "files", type_="foreignkey")
    op.drop_column("files", "organization_id")
    op.drop_column("files", "_last_updated_by_id")
    op.drop_column("files", "_created_by_id")
    op.drop_column("files", "is_deleted")
    op.drop_column("files", "updated_at")
    # ### end Alembic commands ###
3 changes: 2 additions & 1 deletion letta/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
from letta.agent_store.storage import StorageConnector, TableType
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.metadata import EmbeddingConfigColumn, FileMetadataModel, ToolCallColumn
from letta.metadata import EmbeddingConfigColumn, ToolCallColumn
from letta.orm.base import Base
from letta.orm.file import FileMetadata as FileMetadataModel

# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
from letta.schemas.message import Message
Expand Down
4 changes: 2 additions & 2 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2440,7 +2440,7 @@ def load_file_to_source(self, filename: str, source_id: str, blocking=True):
return job

def delete_file_from_source(self, source_id: str, file_id: str):
self.server.delete_file_from_source(source_id, file_id, user_id=self.user_id)
self.server.source_manager.delete_file(file_id, actor=self.user)

def get_job(self, job_id: str):
return self.server.get_job(job_id=job_id)
Expand Down Expand Up @@ -2561,7 +2561,7 @@ def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Opti
Returns:
files (List[FileMetadata]): List of files
"""
return self.server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
return self.server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=self.user)

def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
"""
Expand Down
11 changes: 3 additions & 8 deletions letta/data_sources/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from letta.schemas.file import FileMetadata
from letta.schemas.passage import Passage
from letta.schemas.source import Source
from letta.services.source_manager import SourceManager
from letta.utils import create_uuid_from_string


Expand Down Expand Up @@ -41,12 +42,7 @@ def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Itera
"""


def load_data(
connector: DataConnector,
source: Source,
passage_store: StorageConnector,
file_metadata_store: StorageConnector,
):
def load_data(connector: DataConnector, source: Source, passage_store: StorageConnector, source_manager: SourceManager, actor: "User"):
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
embedding_config = source.embedding_config

Expand All @@ -60,7 +56,7 @@ def load_data(
file_count = 0
for file_metadata in connector.find_files(source):
file_count += 1
file_metadata_store.insert(file_metadata)
source_manager.create_file(file_metadata, actor)

# generate passages
for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):
Expand Down Expand Up @@ -155,7 +151,6 @@ def find_files(self, source: Source) -> Iterator[FileMetadata]:

for metadata in extract_metadata_from_files(files):
yield FileMetadata(
user_id=source.created_by_id,
source_id=source.id,
file_name=metadata.get("file_name"),
file_path=metadata.get("file_path"),
Expand Down
73 changes: 0 additions & 73 deletions letta/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Column,
DateTime,
Index,
Integer,
String,
TypeDecorator,
)
Expand All @@ -24,7 +23,6 @@
from letta.schemas.block import Block, Human, Persona
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
Expand All @@ -40,41 +38,6 @@
from letta.utils import enforce_types, get_utc_time, printd


class FileMetadataModel(Base):
    """Declarative model for one row of the ``files`` table."""

    __tablename__ = "files"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True, nullable=False)
    user_id = Column(String, nullable=False)
    # TODO: Investigate why this breaks during table creation due to FK
    # source_id = Column(String, ForeignKey("sources.id"), nullable=False)
    source_id = Column(String, nullable=False)
    file_name = Column(String, nullable=True)
    file_path = Column(String, nullable=True)
    file_type = Column(String, nullable=True)
    file_size = Column(Integer, nullable=True)
    file_creation_date = Column(String, nullable=True)
    file_last_modified_date = Column(String, nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    def __repr__(self):
        """Return a short debug representation (id, source id, file name)."""
        return f"<FileMetadata(id='{self.id}', source_id='{self.source_id}', file_name='{self.file_name}')>"

    def to_record(self):
        """Copy this row's column values into a ``FileMetadata`` schema object."""
        copied_fields = (
            "id",
            "user_id",
            "source_id",
            "file_name",
            "file_path",
            "file_type",
            "file_size",
            "file_creation_date",
            "file_last_modified_date",
            "created_at",
        )
        # Same keyword arguments, in the same order, as spelling each one out.
        values = {field: getattr(self, field) for field in copied_fields}
        return FileMetadata(**values)


class LLMConfigColumn(TypeDecorator):
"""Custom type for storing LLMConfig as JSON"""

Expand Down Expand Up @@ -510,21 +473,6 @@ def update_or_create_block(self, block: Block):
session.add(BlockModel(**vars(block)))
session.commit()

@enforce_types
def delete_file_from_source(self, source_id: str, file_id: str, user_id: Optional[str]):
    """Delete the file row matching the given source, file id and user.

    Returns the row that was deleted, or ``None`` when no row matched
    (in which case nothing is deleted or committed).
    """
    with self.session_maker() as session:
        lookup = session.query(FileMetadataModel).filter(
            FileMetadataModel.source_id == source_id,
            FileMetadataModel.id == file_id,
            FileMetadataModel.user_id == user_id,
        )
        target = lookup.first()

        # Nothing matched: hand back the falsy lookup result unchanged.
        if not target:
            return target

        session.delete(target)
        session.commit()
        return target

@enforce_types
def delete_block(self, block_id: str):
with self.session_maker() as session:
Expand Down Expand Up @@ -653,27 +601,6 @@ def create_job(self, job: Job):
session.add(JobModel(**vars(job)))
session.commit()

@enforce_types
def list_files_from_source(self, source_id: str, limit: int, cursor: Optional[str]):
    """Return one page of ``FileMetadata`` records for a source, ordered by id.

    When ``cursor`` is given, only rows whose id sorts after it are returned;
    at most ``limit`` rows are fetched.
    """
    with self.session_maker() as session:
        # Restrict to the requested source.
        page_query = session.query(FileMetadataModel).filter(FileMetadataModel.source_id == source_id)

        # The cursor is treated as the id of the last file of the previous page.
        if cursor:
            page_query = page_query.filter(FileMetadataModel.id > cursor)

        # Stable id ordering keeps pagination consistent; then cap the page size.
        rows = page_query.order_by(FileMetadataModel.id).limit(limit).all()

        # Convert ORM rows into schema objects before handing them back.
        return [row.to_record() for row in rows]

def delete_job(self, job_id: str):
with self.session_maker() as session:
session.query(JobModel).filter(JobModel.id == job_id).delete()
Expand Down
1 change: 1 addition & 0 deletions letta/orm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from letta.orm.base import Base
from letta.orm.file import FileMetadata
from letta.orm.organization import Organization
from letta.orm.source import Source
from letta.orm.tool import Tool
Expand Down
29 changes: 29 additions & 0 deletions letta/orm/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import TYPE_CHECKING, Optional

from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin, SourceMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.file import FileMetadata as PydanticFileMetadata

if TYPE_CHECKING:
    # Imported for annotations only, so no import happens at runtime.
    from letta.orm.organization import Organization
    from letta.orm.source import Source  # referenced by the `source` annotation below


class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
    """Represents metadata for an uploaded file."""

    __tablename__ = "files"
    __pydantic_model__ = PydanticFileMetadata

    file_name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The name of the file.")
    file_path: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The file path on the system.")
    file_type: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The type of the file.")
    file_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, doc="The size of the file in bytes.")
    file_creation_date: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The creation date of the file.")
    file_last_modified_date: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The last modified date of the file.")

    # relationships
    # Both sides are loaded with a follow-up SELECT ... IN query (lazy="selectin").
    organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
    source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
8 changes: 8 additions & 0 deletions letta/orm/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,11 @@ class UserMixin(Base):
__abstract__ = True

user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))


class SourceMixin(Base):
    """Abstract mixin that gives a model a ``source_id`` foreign key to ``sources.id``.

    Intended for models (e.g. file) that belong to a source.
    """

    # Abstract: this class contributes a column but is not mapped to a table itself.
    __abstract__ = True

    source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"))
4 changes: 2 additions & 2 deletions letta/orm/organization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, List

from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.file import FileMetadata
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.organization import Organization as PydanticOrganization

Expand All @@ -18,14 +18,14 @@ class Organization(SqlalchemyBase):
__tablename__ = "organizations"
__pydantic_model__ = PydanticOrganization

id: Mapped[str] = mapped_column(String, primary_key=True)
name: Mapped[str] = mapped_column(doc="The display name of the organization.")

# relationships
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
# TODO: Map these relationships later when we actually make these models
# below is just a suggestion
# agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
Expand Down
3 changes: 2 additions & 1 deletion letta/orm/source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship
Expand Down Expand Up @@ -47,4 +47,5 @@ class Source(SqlalchemyBase, OrganizationMixin):

# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
# A source owns many files; the element type is FileMetadata (the relationship target), not Source.
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
# agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
1 change: 0 additions & 1 deletion letta/orm/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class Tool(SqlalchemyBase, OrganizationMixin):
# An organization should not have multiple tools with the same name
__table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)

id: Mapped[str] = mapped_column(String, primary_key=True)
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
Expand Down
2 changes: 0 additions & 2 deletions letta/orm/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING

from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.mixins import OrganizationMixin
Expand All @@ -17,7 +16,6 @@ class User(SqlalchemyBase, OrganizationMixin):
__tablename__ = "users"
__pydantic_model__ = PydanticUser

id: Mapped[str] = mapped_column(String, primary_key=True)
name: Mapped[str] = mapped_column(nullable=False, doc="The display name of the user.")

# relationships
Expand Down
10 changes: 5 additions & 5 deletions letta/schemas/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from pydantic import Field

from letta.schemas.letta_base import LettaBase
from letta.utils import get_utc_time


class FileMetadataBase(LettaBase):
Expand All @@ -17,15 +16,16 @@ class FileMetadata(FileMetadataBase):
"""Representation of a single FileMetadata"""

id: str = FileMetadataBase.generate_id_field()
user_id: str = Field(description="The unique identifier of the user associated with the document.")
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the document.")
source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
file_name: Optional[str] = Field(None, description="The name of the file.")
file_path: Optional[str] = Field(None, description="The path to the file.")
file_type: Optional[str] = Field(None, description="The type of the file (MIME type).")
file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
file_creation_date: Optional[str] = Field(None, description="The creation date of the file.")
file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.")
created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.")

class Config:
extra = "allow"
# orm metadata, optional fields
created_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The creation date of the file.")
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, description="The update date of the file.")
is_deleted: bool = Field(False, description="Whether this file is deleted or not.")
6 changes: 4 additions & 2 deletions letta/server/rest_api/routers/v1/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,13 @@ def list_files_from_source(
limit: int = Query(1000, description="Number of files to return"),
cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
List paginated files associated with a data source.
"""
return server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
actor = server.get_user_or_default(user_id=user_id)
return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor)


# it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action.
Expand All @@ -219,7 +221,7 @@ def delete_file_from_source(
"""
actor = server.get_user_or_default(user_id=user_id)

deleted_file = server.delete_file_from_source(source_id=source_id, file_id=file_id, user_id=actor.id)
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
if deleted_file is None:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")

Expand Down
Loading
Loading