Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[BUG]: Added ServerAPI imports if API import fails #12

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions chroma_migrate/import_chromadb.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from typing import Dict
from chromadb.api import API

try:
from chromadb.api import API
except ImportError:
from chromadb.api import ServerAPI as API

from chromadb.api.models.Collection import Collection
from tqdm import tqdm
from more_itertools import chunked

from chroma_migrate.utils import migrate_embedding_metadata, validate_collection_metadata
from chroma_migrate.utils import (
migrate_embedding_metadata,
validate_collection_metadata,
)

CHUNK_SIZE = 1000


def migrate_from_remote_chroma(from_api: API, to_api: API):

print("Loading Existing Collections...")
from_collections = from_api.list_collections()

Expand All @@ -20,16 +28,20 @@ def migrate_from_remote_chroma(from_api: API, to_api: API):
print("Validating collection metadata...")
from_collection_to_metadata = {}
for collection in from_collections:
from_collection_to_metadata[collection.name] = validate_collection_metadata(collection.metadata)
from_collection_to_metadata[collection.name] = validate_collection_metadata(
collection.metadata
)

print("Migrating existing collections...")
from_collection_to_to_collection: Dict[str, Collection] = {}
total_embeddings = 0
for from_collection in from_collections:
to_collection = to_api.get_or_create_collection(from_collection.name, from_collection_to_metadata[from_collection.name])
to_collection = to_api.get_or_create_collection(
from_collection.name, from_collection_to_metadata[from_collection.name]
)
total_embeddings += from_collection.count()
from_collection_to_to_collection[from_collection.name] = to_collection

print("Migrating existing embeddings...")
with tqdm(total=total_embeddings) as pbar:
for from_collection in from_collections:
Expand All @@ -39,15 +51,23 @@ def migrate_from_remote_chroma(from_api: API, to_api: API):
absolute_position = 0
for chunk in chunked(data["ids"], chunk_size):
ids = chunk
embeddings = data["embeddings"][absolute_position:absolute_position+chunk_size]
metadatas = data["metadatas"][absolute_position:absolute_position+chunk_size]
embeddings = data["embeddings"][
absolute_position : absolute_position + chunk_size
]
metadatas = data["metadatas"][
absolute_position : absolute_position + chunk_size
]
for i, metadata in enumerate(metadatas):
metadatas[i] = migrate_embedding_metadata(metadata)

documents = data["documents"][absolute_position:absolute_position+chunk_size]
documents = data["documents"][
absolute_position : absolute_position + chunk_size
]
to_collection.add(ids, embeddings, metadatas, documents)
pbar.update(len(chunk))
absolute_position += len(chunk)

print(f"Migrated {len(from_collections)} collections and {total_embeddings} embeddings")
return True

print(
f"Migrated {len(from_collections)} collections and {total_embeddings} embeddings"
)
return True
36 changes: 25 additions & 11 deletions chroma_migrate/import_clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
from typing import Dict
import clickhouse_connect
from chromadb.api import API

try:
from chromadb.api import API
except ImportError:
from chromadb.api import ServerAPI as API

from chromadb.api.models.Collection import Collection
from tqdm import tqdm
import json

from chroma_migrate.utils import migrate_embedding_metadata, validate_collection_metadata
from chroma_migrate.utils import (
migrate_embedding_metadata,
validate_collection_metadata,
)


def migrate_from_clickhouse(api: API, host: str, port: int):
conn = clickhouse_connect.get_client(
host=host,
port=port,
)
host=host,
port=port,
)

print("Loading existing collections...")
# Read the collections from clickhouse
Expand All @@ -26,7 +34,9 @@ def migrate_from_clickhouse(api: API, host: str, port: int):
from_collection_to_metadata = {}
for collection in collections:
metadata = json.loads(collection[2])
from_collection_to_metadata[collection[1]] = validate_collection_metadata(metadata)
from_collection_to_metadata[collection[1]] = validate_collection_metadata(
metadata
)

# Create the collections in chromadb
print("Migrating existing collections...")
Expand All @@ -39,11 +49,15 @@ def migrate_from_clickhouse(api: API, host: str, port: int):
collection_uuid_to_chroma_collection[uuid] = coll

# -------------------------------------

# Add the embeddings to the collections
print("Migrating existing embeddings...")
with tqdm(total=conn.query("SELECT count(*) FROM embeddings").result_rows[0][0]) as pbar:
with conn.query_row_block_stream('SELECT uuid, collection_uuid, id, embedding, document, metadata FROM embeddings') as stream:
with tqdm(
total=conn.query("SELECT count(*) FROM embeddings").result_rows[0][0]
) as pbar:
with conn.query_row_block_stream(
"SELECT uuid, collection_uuid, id, embedding, document, metadata FROM embeddings"
) as stream:
for block in stream:
for record in block:
uuid = record[0]
Expand All @@ -56,6 +70,6 @@ def migrate_from_clickhouse(api: API, host: str, port: int):
collection = collection_uuid_to_chroma_collection[collection_uuid]
collection.add(id, embedding, metadata, document)
pbar.update(1)

print(f"Migrated {len(collections)} collections and {pbar.n} embeddings")
return True
return True
43 changes: 31 additions & 12 deletions chroma_migrate/import_duckdb.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import os
from typing import Dict
import duckdb
from chromadb.api import API

try:
from chromadb.api import API
except ImportError:
from chromadb.api import ServerAPI as API
from chromadb.api.models.Collection import Collection
from tqdm import tqdm
import json

from chroma_migrate.utils import migrate_embedding_metadata, validate_collection_metadata
from chroma_migrate.utils import (
migrate_embedding_metadata,
validate_collection_metadata,
)


def migrate_from_duckdb(api: API, persist_directory: str):
# Load all the collections from the parquet files
Expand All @@ -17,16 +25,20 @@ def migrate_from_duckdb(api: API, persist_directory: str):

print("Loading Existing Collections...")
# Load the collections into duckdb
collections_parquet_path = os.path.join(persist_directory, "chroma-collections.parquet")
collections_parquet_path = os.path.join(
persist_directory, "chroma-collections.parquet"
)
conn.execute(
"CREATE TABLE collections (uuid STRING, name STRING, metadata STRING);"
)
"CREATE TABLE collections (uuid STRING, name STRING, metadata STRING);"
)
conn.execute(
f"INSERT INTO collections SELECT * FROM read_parquet('{collections_parquet_path}');"
)

# Read the collections from duckdb
collections = conn.execute("SELECT uuid, name, metadata FROM collections").fetchall()
collections = conn.execute(
"SELECT uuid, name, metadata FROM collections"
).fetchall()

if len(collections) == 0:
print("No collections found, exiting...")
Expand All @@ -48,10 +60,12 @@ def migrate_from_duckdb(api: API, persist_directory: str):
collection_uuid_to_chroma_collection[uuid] = coll

# -------------------------------------

# Load the embeddings into duckdb
print("Migrating existing embeddings...")
embeddings_parquet_path = os.path.join(persist_directory, "chroma-embeddings.parquet")
embeddings_parquet_path = os.path.join(
persist_directory, "chroma-embeddings.parquet"
)
conn.execute(
"CREATE TABLE embeddings (collection_uuid STRING, uuid STRING, embedding DOUBLE[], document STRING, id STRING, metadata STRING);"
)
Expand All @@ -60,7 +74,9 @@ def migrate_from_duckdb(api: API, persist_directory: str):
)

# Read the embeddings from duckdb
embeddings = conn.execute("SELECT uuid, collection_uuid, id, embedding, document, metadata FROM embeddings").fetch_df()
embeddings = conn.execute(
"SELECT uuid, collection_uuid, id, embedding, document, metadata FROM embeddings"
).fetch_df()

# Add the embeddings to the collections
for record in tqdm(embeddings.itertuples(index=False), total=embeddings.shape[0]):
Expand All @@ -69,7 +85,9 @@ def migrate_from_duckdb(api: API, persist_directory: str):
try:
metadata = json.loads(metadata)
except Exception as e:
print(f"Failed to load metadata for embedding {id} in collection {collection_uuid}. Malformed JSON")
print(
f"Failed to load metadata for embedding {id} in collection {collection_uuid}. Malformed JSON"
)
else:
metadata = None
if not isinstance(document, str):
Expand All @@ -78,7 +96,8 @@ def migrate_from_duckdb(api: API, persist_directory: str):
collection = collection_uuid_to_chroma_collection[collection_uuid]
collection.add(id, embedding, metadata, document)

print(f"Migrated {len(collections)} collections and {embeddings.shape[0]} embeddings")
print(
f"Migrated {len(collections)} collections and {embeddings.shape[0]} embeddings"
)

return True

22 changes: 17 additions & 5 deletions test_scripts/generate_test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@
import random

import chromadb
from chromadb.api import API

try:
from chromadb.api import API
except ImportError:
from chromadb.api import ServerAPI as API
from chromadb.config import Settings


# Run this to populate a test clickhouse database
def gen():
def gen():
# Create a new API
api: API = chromadb.Client(Settings(chroma_api_impl="rest", chroma_server_host="localhost", chroma_server_http_port="8000"))
api: API = chromadb.Client(
Settings(
chroma_api_impl="rest",
chroma_server_host="localhost",
chroma_server_http_port="8000",
)
)

# Create a random set of collections
for i in range(10):
Expand All @@ -22,6 +33,7 @@ def gen():
ids = [f"id_{i}" for i in range(N)]
collection.add(ids, embeddings, metadata, documents)
print(f"Added {N} embeddings to collection {collection.name}")



if __name__ == "__main__":
gen()
gen()
18 changes: 13 additions & 5 deletions test_scripts/generate_test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@
import random

import chromadb
from chromadb.api import API

try:
from chromadb.api import API
except ImportError:
from chromadb.api import ServerAPI as API
from chromadb.config import Settings


# Run this to generate a test duckdb database
def gen():
persist_directory_a = "./test_data_duckdb"

# Create a new API
api: API = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet", persist_directory=persist_directory_a))
api: API = chromadb.Client(
Settings(chroma_db_impl="duckdb+parquet", persist_directory=persist_directory_a)
)

# Create a random set of collections
for i in range(10):
Expand All @@ -24,8 +31,9 @@ def gen():
ids = [f"id_{i}" for i in range(N)]
collection.add(ids, embeddings, metadata, documents)
print(f"Added {N} embeddings to collection {collection.name}")

api.persist()


if __name__ == "__main__":
gen()
gen()