diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 23f8e3413ea..50b3e6c3446 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -347,13 +347,13 @@ def delete_collection( ) if existing: - self._sysdb.delete_collection( - existing[0]["id"], tenant=tenant, database=database - ) for s in self._manager.delete_segments(existing[0]["id"]): self._sysdb.delete_segment(s) if existing and existing[0]["id"] in self._collection_cache: del self._collection_cache[existing[0]["id"]] + self._sysdb.delete_collection( + existing[0]["id"], tenant=tenant, database=database + ) else: raise ValueError(f"Collection {name} does not exist.") diff --git a/chromadb/db/impl/sqlite.py b/chromadb/db/impl/sqlite.py index 8549bc36207..6d21c9ae189 100644 --- a/chromadb/db/impl/sqlite.py +++ b/chromadb/db/impl/sqlite.py @@ -1,4 +1,7 @@ +from pathlib import Path + from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool +from chromadb.db.impl.sqlite_utils import get_drop_order from chromadb.db.migrations import MigratableDB, Migration from chromadb.config import System, Settings import chromadb.db.base as base @@ -34,6 +37,7 @@ def __init__(self, conn_pool: Pool, stack: local): @override def __enter__(self) -> base.Cursor: if len(self._tx_stack.stack) == 0: + self._conn.execute("PRAGMA foreign_keys = ON") self._conn.execute("PRAGMA case_sensitive_like = ON") self._conn.execute("BEGIN;") self._tx_stack.stack.append(self) @@ -139,15 +143,16 @@ def reset_state(self) -> None: "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted." ) with self.tx() as cur: - # Drop all tables cur.execute( """ SELECT name FROM sqlite_master WHERE type='table' """ ) - for row in cur.fetchall(): - cur.execute(f"DROP TABLE IF EXISTS {row[0]}") + drop_statement = "" + for t in get_drop_order(cur): + drop_statement += f"DROP TABLE IF EXISTS {t};\n" + cur.executescript(drop_statement) self._conn_pool.close() self.start() super().reset_state() @@ -218,7 +223,16 @@ def db_migrations(self, dir: Traversable) -> Sequence[Migration]: @override def apply_migration(self, cur: base.Cursor, migration: Migration) -> None: - cur.executescript(migration["sql"]) + if any(item.name == f".{migration['filename']}.disable_fk" + for traversable in self.migration_dirs() + for item in traversable.iterdir() if item.is_file()): + cur.executescript( + "PRAGMA foreign_keys = OFF;\n" + + migration["sql"] + + ";\nPRAGMA foreign_keys = ON;" + ) + else: + cur.executescript(migration["sql"]) cur.execute( """ INSERT INTO migrations (dir, version, filename, sql, hash) diff --git a/chromadb/db/impl/sqlite_utils.py b/chromadb/db/impl/sqlite_utils.py new file mode 100644 index 00000000000..66265f8c8dc --- /dev/null +++ b/chromadb/db/impl/sqlite_utils.py @@ -0,0 +1,44 @@ +from collections import defaultdict, deque +from graphlib import TopologicalSorter +from typing import List, Dict + +from chromadb.db.base import Cursor + + +def fetch_tables(cursor: Cursor) -> List[str]: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + return [row[0] for row in cursor.fetchall()] + + +def fetch_foreign_keys(cursor: Cursor, table_name: str) -> List[str]: + cursor.execute(f"PRAGMA foreign_key_list({table_name});") + return [row[2] for row in cursor.fetchall()] # Table being referenced + + +def build_dependency_graph(tables: List[str], cursor: Cursor) -> Dict[str, List[str]]: + graph = defaultdict(list) + for table in tables: + foreign_keys = fetch_foreign_keys(cursor, table) + for fk_table in foreign_keys: + graph[table].append(fk_table) + if not foreign_keys and table not in graph.keys(): + graph[table] = [] + + return graph + + +def topological_sort(graph: Dict[str, List[str]]) -> List[str]: + ts = TopologicalSorter(graph) + # Reverse the order since TopologicalSorter gives the order of dependencies, + # but we want to drop tables in reverse dependency order + return list(ts.static_order())[::-1] + + +def get_drop_order(cursor: Cursor) -> List[str]: + tables = fetch_tables(cursor) + filtered_tables = [ + table for table in tables if not table.startswith("embedding_fulltext_search_") + ] + graph = build_dependency_graph(filtered_tables, cursor) + drop_order = topological_sort(graph) + return drop_order diff --git a/chromadb/migrations/metadb/00006-em-fk.sqlite.sql b/chromadb/migrations/metadb/00006-em-fk.sqlite.sql new file mode 100644 index 00000000000..0fc9c46cb86 --- /dev/null +++ b/chromadb/migrations/metadb/00006-em-fk.sqlite.sql @@ -0,0 +1,22 @@ +-- Disable foreign key constraints to us to update the segments table +PRAGMA foreign_keys = OFF; + +CREATE TABLE embedding_metadata_temp ( + id INTEGER REFERENCES embeddings(id) ON DELETE CASCADE NOT NULL, + key TEXT NOT NULL, + string_value TEXT, + int_value INTEGER, + float_value REAL, + bool_value INTEGER, + PRIMARY KEY (id, key) +); + +INSERT INTO embedding_metadata_temp +SELECT id, key, string_value, int_value, float_value, bool_value +FROM embedding_metadata; + +DROP TABLE embedding_metadata; + +ALTER TABLE embedding_metadata_temp RENAME TO embedding_metadata; + +PRAGMA foreign_keys = ON; diff --git a/chromadb/migrations/sysdb/.00004-tenants-databases.sqlite.sql.disable_fk b/chromadb/migrations/sysdb/.00004-tenants-databases.sqlite.sql.disable_fk new file mode 100644 index 00000000000..e69de29bb2d diff --git a/chromadb/migrations/sysdb/.00006-segments-fk.sqlite.sql.disable_fk b/chromadb/migrations/sysdb/.00006-segments-fk.sqlite.sql.disable_fk new file mode 100644 index 00000000000..e69de29bb2d diff --git a/chromadb/migrations/sysdb/00006-segments-fk.sqlite.sql b/chromadb/migrations/sysdb/00006-segments-fk.sqlite.sql new file mode 100644 index 00000000000..29a72f845bf --- /dev/null +++ b/chromadb/migrations/sysdb/00006-segments-fk.sqlite.sql @@ -0,0 +1,13 @@ +CREATE TABLE segments_temp ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + scope TEXT NOT NULL, + topic TEXT, + collection TEXT REFERENCES collections(id) +); + +INSERT INTO segments_temp SELECT * FROM segments; + +DROP TABLE segments; + +ALTER TABLE segments_temp RENAME TO segments; diff --git a/chromadb/test/property/test_segment_manager.py b/chromadb/test/property/test_segment_manager.py index f73fedeccf2..f728a48853b 100644 --- a/chromadb/test/property/test_segment_manager.py +++ b/chromadb/test/property/test_segment_manager.py @@ -98,6 +98,10 @@ def initialize(self) -> None: def create_segment( self, coll: strategies.Collection ) -> MultipleResults[strategies.Collection]: + coll.name = f"{coll.name}_{uuid.uuid4()}" + self.sysdb.create_collection( + name=coll.name, id=coll.id, metadata=coll.metadata, dimension=coll.dimension + ) segments = self.segment_manager.create_segments(asdict(coll)) for segment in segments: self.sysdb.create_segment(segment) diff --git a/chromadb/test/test_client.py b/chromadb/test/test_client.py index 34dd2df1412..041eff064e0 100644 --- a/chromadb/test/test_client.py +++ b/chromadb/test/test_client.py @@ -1,3 +1,4 @@ +import shutil from typing import Generator from unittest.mock import patch import chromadb @@ -17,6 +18,7 @@ def ephemeral_api() -> Generator[ClientAPI, None, None]: @pytest.fixture def persistent_api() -> Generator[ClientAPI, None, None]: + shutil.rmtree(tempfile.gettempdir() + "/test_server", ignore_errors=True) client = chromadb.PersistentClient( path=tempfile.gettempdir() + "/test_server", )