diff --git a/chroma_migrate/import_chromadb.py b/chroma_migrate/import_chromadb.py index 6d86e4d..b01d163 100644 --- a/chroma_migrate/import_chromadb.py +++ b/chroma_migrate/import_chromadb.py @@ -1,15 +1,23 @@ from typing import Dict -from chromadb.api import API + +try: + from chromadb.api import API +except ImportError: + from chromadb.api import ServerAPI as API + from chromadb.api.models.Collection import Collection from tqdm import tqdm from more_itertools import chunked -from chroma_migrate.utils import migrate_embedding_metadata, validate_collection_metadata +from chroma_migrate.utils import ( + migrate_embedding_metadata, + validate_collection_metadata, +) CHUNK_SIZE = 1000 + def migrate_from_remote_chroma(from_api: API, to_api: API): - print("Loading Existing Collections...") from_collections = from_api.list_collections() @@ -20,16 +28,20 @@ def migrate_from_remote_chroma(from_api: API, to_api: API): print("Validating collection metadata...") from_collection_to_metadata = {} for collection in from_collections: - from_collection_to_metadata[collection.name] = validate_collection_metadata(collection.metadata) + from_collection_to_metadata[collection.name] = validate_collection_metadata( + collection.metadata + ) print("Migrating existing collections...") from_collection_to_to_collection: Dict[str, Collection] = {} total_embeddings = 0 for from_collection in from_collections: - to_collection = to_api.get_or_create_collection(from_collection.name, from_collection_to_metadata[from_collection.name]) + to_collection = to_api.get_or_create_collection( + from_collection.name, from_collection_to_metadata[from_collection.name] + ) total_embeddings += from_collection.count() from_collection_to_to_collection[from_collection.name] = to_collection - + print("Migrating existing embeddings...") with tqdm(total=total_embeddings) as pbar: for from_collection in from_collections: @@ -39,15 +51,23 @@ def migrate_from_remote_chroma(from_api: API, to_api: API): absolute_position = 0 for chunk in chunked(data["ids"], chunk_size): ids = chunk - embeddings = data["embeddings"][absolute_position:absolute_position+chunk_size] - metadatas = data["metadatas"][absolute_position:absolute_position+chunk_size] + embeddings = data["embeddings"][ + absolute_position : absolute_position + chunk_size + ] + metadatas = data["metadatas"][ + absolute_position : absolute_position + chunk_size + ] for i, metadata in enumerate(metadatas): metadatas[i] = migrate_embedding_metadata(metadata) - documents = data["documents"][absolute_position:absolute_position+chunk_size] + documents = data["documents"][ + absolute_position : absolute_position + chunk_size + ] to_collection.add(ids, embeddings, metadatas, documents) pbar.update(len(chunk)) absolute_position += len(chunk) - - print(f"Migrated {len(from_collections)} collections and {total_embeddings} embeddings") - return True \ No newline at end of file + + print( + f"Migrated {len(from_collections)} collections and {total_embeddings} embeddings" + ) + return True diff --git a/chroma_migrate/import_clickhouse.py b/chroma_migrate/import_clickhouse.py index dd58981..bc746c4 100644 --- a/chroma_migrate/import_clickhouse.py +++ b/chroma_migrate/import_clickhouse.py @@ -1,18 +1,26 @@ from typing import Dict import clickhouse_connect -from chromadb.api import API + +try: + from chromadb.api import API +except ImportError: + from chromadb.api import ServerAPI as API + from chromadb.api.models.Collection import Collection from tqdm import tqdm import json -from chroma_migrate.utils import migrate_embedding_metadata, validate_collection_metadata +from chroma_migrate.utils import ( + migrate_embedding_metadata, + validate_collection_metadata, +) def migrate_from_clickhouse(api: API, host: str, port: int): conn = clickhouse_connect.get_client( - host=host, - port=port, - ) + host=host, + port=port, + ) print("Loading existing collections...") # Read the collections from clickhouse @@ -26,7 +34,9 @@ def migrate_from_clickhouse(api: API, host: str, port: int): from_collection_to_metadata = {} for collection in collections: metadata = json.loads(collection[2]) - from_collection_to_metadata[collection[1]] = validate_collection_metadata(metadata) + from_collection_to_metadata[collection[1]] = validate_collection_metadata( + metadata + ) # Create the collections in chromadb print("Migrating existing collections...") @@ -39,11 +49,15 @@ def migrate_from_clickhouse(api: API, host: str, port: int): collection_uuid_to_chroma_collection[uuid] = coll # ------------------------------------- - + # Add the embeddings to the collections print("Migrating existing embeddings...") - with tqdm(total=conn.query("SELECT count(*) FROM embeddings").result_rows[0][0]) as pbar: - with conn.query_row_block_stream('SELECT uuid, collection_uuid, id, embedding, document, metadata FROM embeddings') as stream: + with tqdm( + total=conn.query("SELECT count(*) FROM embeddings").result_rows[0][0] + ) as pbar: + with conn.query_row_block_stream( + "SELECT uuid, collection_uuid, id, embedding, document, metadata FROM embeddings" + ) as stream: for block in stream: for record in block: uuid = record[0] @@ -56,6 +70,6 @@ def migrate_from_clickhouse(api: API, host: str, port: int): collection = collection_uuid_to_chroma_collection[collection_uuid] collection.add(id, embedding, metadata, document) pbar.update(1) - + print(f"Migrated {len(collections)} collections and {pbar.n} embeddings") - return True \ No newline at end of file + return True diff --git a/chroma_migrate/import_duckdb.py b/chroma_migrate/import_duckdb.py index c8871eb..d866796 100644 --- a/chroma_migrate/import_duckdb.py +++ b/chroma_migrate/import_duckdb.py @@ -1,12 +1,20 @@ import os from typing import Dict import duckdb -from chromadb.api import API + +try: + from chromadb.api import API +except ImportError: + from chromadb.api import ServerAPI as API from chromadb.api.models.Collection import Collection from tqdm import tqdm import json -from chroma_migrate.utils import migrate_embedding_metadata, validate_collection_metadata +from chroma_migrate.utils import ( + migrate_embedding_metadata, + validate_collection_metadata, +) + def migrate_from_duckdb(api: API, persist_directory: str): # Load all the collections from the parquet files @@ -17,16 +25,20 @@ def migrate_from_duckdb(api: API, persist_directory: str): print("Loading Existing Collections...") # Load the collections into duckdb - collections_parquet_path = os.path.join(persist_directory, "chroma-collections.parquet") + collections_parquet_path = os.path.join( + persist_directory, "chroma-collections.parquet" + ) conn.execute( - "CREATE TABLE collections (uuid STRING, name STRING, metadata STRING);" - ) + "CREATE TABLE collections (uuid STRING, name STRING, metadata STRING);" + ) conn.execute( f"INSERT INTO collections SELECT * FROM read_parquet('{collections_parquet_path}');" ) # Read the collections from duckdb - collections = conn.execute("SELECT uuid, name, metadata FROM collections").fetchall() + collections = conn.execute( + "SELECT uuid, name, metadata FROM collections" + ).fetchall() if len(collections) == 0: print("No collections found, exiting...") @@ -48,10 +60,12 @@ def migrate_from_duckdb(api: API, persist_directory: str): collection_uuid_to_chroma_collection[uuid] = coll # ------------------------------------- - + # Load the embeddings into duckdb print("Migrating existing embeddings...") - embeddings_parquet_path = os.path.join(persist_directory, "chroma-embeddings.parquet") + embeddings_parquet_path = os.path.join( + persist_directory, "chroma-embeddings.parquet" + ) conn.execute( "CREATE TABLE embeddings (collection_uuid STRING, uuid STRING, embedding DOUBLE[], document STRING, id STRING, metadata STRING);" ) @@ -60,7 +74,9 @@ def migrate_from_duckdb(api: API, persist_directory: str): ) # Read the embeddings from duckdb - embeddings = conn.execute("SELECT uuid, collection_uuid, id, embedding, document, metadata FROM embeddings").fetch_df() + embeddings = conn.execute( + "SELECT uuid, collection_uuid, id, embedding, document, metadata FROM embeddings" + ).fetch_df() # Add the embeddings to the collections for record in tqdm(embeddings.itertuples(index=False), total=embeddings.shape[0]): @@ -69,7 +85,9 @@ def migrate_from_duckdb(api: API, persist_directory: str): try: metadata = json.loads(metadata) except Exception as e: - print(f"Failed to load metadata for embedding {id} in collection {collection_uuid}. Malformed JSON") + print( + f"Failed to load metadata for embedding {id} in collection {collection_uuid}. Malformed JSON" + ) else: metadata = None if not isinstance(document, str): @@ -78,7 +96,8 @@ def migrate_from_duckdb(api: API, persist_directory: str): collection = collection_uuid_to_chroma_collection[collection_uuid] collection.add(id, embedding, metadata, document) - print(f"Migrated {len(collections)} collections and {embeddings.shape[0]} embeddings") + print( + f"Migrated {len(collections)} collections and {embeddings.shape[0]} embeddings" + ) return True - diff --git a/test_scripts/generate_test_clickhouse.py b/test_scripts/generate_test_clickhouse.py index a1e06e0..2712e30 100644 --- a/test_scripts/generate_test_clickhouse.py +++ b/test_scripts/generate_test_clickhouse.py @@ -2,13 +2,24 @@ import random import chromadb -from chromadb.api import API + +try: + from chromadb.api import API +except ImportError: + from chromadb.api import ServerAPI as API from chromadb.config import Settings + # Run this to populate a test clickhouse database -def gen(): +def gen(): # Create a new API - api: API = chromadb.Client(Settings(chroma_api_impl="rest", chroma_server_host="localhost", chroma_server_http_port="8000")) + api: API = chromadb.Client( + Settings( + chroma_api_impl="rest", + chroma_server_host="localhost", + chroma_server_http_port="8000", + ) + ) # Create a random set of collections for i in range(10): @@ -22,6 +33,7 @@ def gen(): ids = [f"id_{i}" for i in range(N)] collection.add(ids, embeddings, metadata, documents) print(f"Added {N} embeddings to collection {collection.name}") - + + if __name__ == "__main__": - gen() \ No newline at end of file + gen() diff --git a/test_scripts/generate_test_duckdb.py b/test_scripts/generate_test_duckdb.py index a9444c0..0ec853c 100644 --- a/test_scripts/generate_test_duckdb.py +++ b/test_scripts/generate_test_duckdb.py @@ -2,15 +2,22 @@ import random import chromadb -from chromadb.api import API + +try: + from chromadb.api import API +except ImportError: + from chromadb.api import ServerAPI as API from chromadb.config import Settings + # Run this to generate a test duckdb database def gen(): persist_directory_a = "./test_data_duckdb" - + # Create a new API - api: API = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet", persist_directory=persist_directory_a)) + api: API = chromadb.Client( + Settings(chroma_db_impl="duckdb+parquet", persist_directory=persist_directory_a) + ) # Create a random set of collections for i in range(10): @@ -24,8 +31,9 @@ def gen(): ids = [f"id_{i}" for i in range(N)] collection.add(ids, embeddings, metadata, documents) print(f"Added {N} embeddings to collection {collection.name}") - + api.persist() + if __name__ == "__main__": - gen() \ No newline at end of file + gen()