diff --git a/chromadb/__init__.py b/chromadb/__init__.py index 0ff5244a80f..ad7d3d4f70b 100644 --- a/chromadb/__init__.py +++ b/chromadb/__init__.py @@ -2,8 +2,6 @@ import logging import sqlite3 import chromadb.config -from chromadb.telemetry.events import ClientStartEvent -from chromadb.telemetry import Telemetry from chromadb.config import Settings, System from chromadb.api import API from chromadb.api.models.Collection import Collection @@ -38,6 +36,8 @@ "QueryResult", "GetResult", ] +from chromadb.telemetry.events import ClientStartEvent +from chromadb.telemetry import Telemetry logger = logging.getLogger(__name__) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index 00002f46d27..fd2f08ec63b 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -29,7 +29,14 @@ validate_where_document, validate_batch, ) -from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent +from chromadb.telemetry.events import ( + CollectionAddEvent, + CollectionDeleteEvent, + CollectionGetEvent, + CollectionUpdateEvent, + CollectionQueryEvent, + ClientCreateCollectionEvent, +) import chromadb.types as t @@ -140,6 +147,13 @@ def create_collection( for segment in segments: self._sysdb.create_segment(segment) + self._telemetry_client.capture( + ClientCreateCollectionEvent( + collection_uuid=str(id), + embedding_function=embedding_function.__class__.__name__, + ) + ) + return Collection( client=self, id=id, @@ -263,7 +277,14 @@ def _add( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids))) + self._telemetry_client.capture( + CollectionAddEvent( + collection_uuid=str(collection_id), + add_amount=len(ids), + with_metadata=len(ids) if metadatas is not None else 0, + with_documents=len(ids) if documents is not None else 0, + ) + ) return True @override @@ -293,6 +314,16 @@ def _update( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) + self._telemetry_client.capture( + CollectionUpdateEvent( + collection_uuid=str(collection_id), + update_amount=len(ids), + with_embeddings=len(embeddings) if embeddings else 0, + with_metadata=len(metadatas) if metadatas else 0, + with_documents=len(documents) if documents else 0, + ) + ) + return True @override @@ -377,6 +408,16 @@ def _get( if "documents" in include: documents = [_doc(m) for m in metadatas] + self._telemetry_client.capture( + CollectionGetEvent( + collection_uuid=str(collection_id), + ids_count=len(ids) if ids else 0, + limit=limit if limit else 0, + include_metadata="metadatas" in include, + include_documents="documents" in include, + ) + ) + return GetResult( ids=[r["id"] for r in records], embeddings=[r["embedding"] for r in vectors] @@ -441,7 +482,9 @@ def _delete( self._producer.submit_embeddings(coll["topic"], records_to_submit) self._telemetry_client.capture( - CollectionDeleteEvent(str(collection_id), len(ids_to_delete)) + CollectionDeleteEvent( + collection_uuid=str(collection_id), delete_amount=len(ids_to_delete) + ) ) return ids_to_delete @@ -528,6 +571,19 @@ def _query( doc_list = [_doc(m) for m in metadata_list] documents.append(doc_list) # type: ignore + self._telemetry_client.capture( + CollectionQueryEvent( + collection_uuid=str(collection_id), + query_amount=len(query_embeddings), + n_results=n_results, + with_metadata_filter=where is not None, + with_document_filter=where_document is not None, + include_metadatas="metadatas" in 
include, + include_documents="documents" in include, + include_distances="distances" in include, + ) + ) + return QueryResult( ids=ids, distances=distances if distances else None, diff --git a/chromadb/telemetry/__init__.py b/chromadb/telemetry/__init__.py index db962549267..d20b8e5d71c 100644 --- a/chromadb/telemetry/__init__.py +++ b/chromadb/telemetry/__init__.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from dataclasses import asdict, dataclass import os from typing import Callable, ClassVar, Dict, Any import uuid @@ -22,13 +21,30 @@ class ServerContext(Enum): FASTAPI = "FastAPI" -@dataclass class TelemetryEvent: - name: ClassVar[str] + max_batch_size: ClassVar[int] = 1 + batch_size: int + + def __init__(self, batch_size: int = 1): + self.batch_size = batch_size @property def properties(self) -> Dict[str, Any]: - return asdict(self) + return self.__dict__ + + @property + def name(self) -> str: + return self.__class__.__name__ + + # A batch key is used to determine whether two events can be batched together. + # If a TelemetryEvent's max_batch_size > 1, batch_key() and batch() MUST be implemented. + # Otherwise they are ignored. + @property + def batch_key(self) -> str: + return self.name + + def batch(self, other: "TelemetryEvent") -> "TelemetryEvent": + raise NotImplementedError class RepeatedTelemetry: diff --git a/chromadb/telemetry/events.py b/chromadb/telemetry/events.py index 64c77574f9f..34c6264fcc9 100644 --- a/chromadb/telemetry/events.py +++ b/chromadb/telemetry/events.py @@ -1,27 +1,153 @@ -from dataclasses import dataclass -from typing import ClassVar +from typing import cast, ClassVar from chromadb.telemetry import TelemetryEvent +from chromadb.utils.embedding_functions import get_builtins -@dataclass class ClientStartEvent(TelemetryEvent): - name: ClassVar[str] = "client_start" + def __init__(self) -> None: + super().__init__() -@dataclass -class ServerStartEvent(TelemetryEvent): - name: ClassVar[str] = "server_start" +class ClientCreateCollectionEvent(TelemetryEvent): + collection_uuid: str + embedding_function: str + + def __init__(self, collection_uuid: str, embedding_function: str): + super().__init__() + self.collection_uuid = collection_uuid + + embedding_function_names = get_builtins() + + self.embedding_function = ( + embedding_function + if embedding_function in embedding_function_names + else "custom" + ) -@dataclass class CollectionAddEvent(TelemetryEvent): - name: ClassVar[str] = "collection_add" + max_batch_size: ClassVar[int] = 20 collection_uuid: str add_amount: int + with_documents: int + with_metadata: int + + def __init__( + self, + collection_uuid: str, + add_amount: int, + with_documents: int, + with_metadata: int, + batch_size: int = 1, + ): + super().__init__() + self.collection_uuid = collection_uuid + self.add_amount = add_amount + self.with_documents = with_documents + self.with_metadata = with_metadata + self.batch_size = batch_size + + @property + def batch_key(self) -> str: + return self.collection_uuid + self.name + + def batch(self, other: "TelemetryEvent") -> "CollectionAddEvent": + if not self.batch_key == other.batch_key: + raise ValueError("Cannot batch events") + other = cast(CollectionAddEvent, other) + total_amount = self.add_amount + other.add_amount + return CollectionAddEvent( + collection_uuid=self.collection_uuid, + add_amount=total_amount, + with_documents=self.with_documents + other.with_documents, + with_metadata=self.with_metadata + other.with_metadata, + batch_size=self.batch_size + other.batch_size, + ) + + +class 
CollectionUpdateEvent(TelemetryEvent): + collection_uuid: str + update_amount: int + with_embeddings: int + with_metadata: int + with_documents: int + + def __init__( + self, + collection_uuid: str, + update_amount: int, + with_embeddings: int, + with_metadata: int, + with_documents: int, + ): + super().__init__() + self.collection_uuid = collection_uuid + self.update_amount = update_amount + self.with_embeddings = with_embeddings + self.with_metadata = with_metadata + self.with_documents = with_documents + + +class CollectionQueryEvent(TelemetryEvent): + collection_uuid: str + query_amount: int + with_metadata_filter: bool + with_document_filter: bool + n_results: int + include_metadatas: bool + include_documents: bool + include_distances: bool + + def __init__( + self, + collection_uuid: str, + query_amount: int, + with_metadata_filter: bool, + with_document_filter: bool, + n_results: int, + include_metadatas: bool, + include_documents: bool, + include_distances: bool, + ): + super().__init__() + self.collection_uuid = collection_uuid + self.query_amount = query_amount + self.with_metadata_filter = with_metadata_filter + self.with_document_filter = with_document_filter + self.n_results = n_results + self.include_metadatas = include_metadatas + self.include_documents = include_documents + self.include_distances = include_distances + + +class CollectionGetEvent(TelemetryEvent): + collection_uuid: str + ids_count: int + limit: int + include_metadata: bool + include_documents: bool + + def __init__( + self, + collection_uuid: str, + ids_count: int, + limit: int, + include_metadata: bool, + include_documents: bool, + ): + super().__init__() + self.collection_uuid = collection_uuid + self.ids_count = ids_count + self.limit = limit + self.include_metadata = include_metadata + self.include_documents = include_documents -@dataclass class CollectionDeleteEvent(TelemetryEvent): - name: ClassVar[str] = "collection_delete" collection_uuid: str delete_amount: int + + def __init__(self, collection_uuid: str, delete_amount: int): + super().__init__() + self.collection_uuid = collection_uuid + self.delete_amount = delete_amount diff --git a/chromadb/telemetry/posthog.py b/chromadb/telemetry/posthog.py index a20e20dd257..184904531ef 100644 --- a/chromadb/telemetry/posthog.py +++ b/chromadb/telemetry/posthog.py @@ -1,6 +1,7 @@ import posthog import logging import sys +from typing import Any, Dict, Set from chromadb.config import System from chromadb.telemetry import Telemetry, TelemetryEvent from overrides import override @@ -21,10 +22,29 @@ def __init__(self, system: System): posthog_logger = logging.getLogger("posthog") # Silence posthog's logging posthog_logger.disabled = True + + self.batched_events: Dict[str, TelemetryEvent] = {} + self.seen_event_types: Set[Any] = set() + super().__init__(system) @override def capture(self, event: TelemetryEvent) -> None: + if event.max_batch_size == 1 or event.batch_key not in self.seen_event_types: + self.seen_event_types.add(event.batch_key) + self._direct_capture(event) + return + batch_key = event.batch_key + if batch_key not in self.batched_events: + self.batched_events[batch_key] = event + return + batched_event = self.batched_events[batch_key].batch(event) + self.batched_events[batch_key] = batched_event + if batched_event.batch_size >= batched_event.max_batch_size: + self._direct_capture(batched_event) + del self.batched_events[batch_key] + + def _direct_capture(self, event: TelemetryEvent) -> None: try: posthog.capture( self.user_id, diff --git 
a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 124213c365b..aaef53c01e2 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -9,6 +9,8 @@ import numpy as np import numpy.typing as npt import importlib +import inspect +import sys from typing import Optional try: @@ -43,7 +45,7 @@ def __init__( self._normalize_embeddings = normalize_embeddings def __call__(self, texts: Documents) -> Embeddings: - return self._model.encode( + return self._model.encode( # type: ignore list(texts), convert_to_numpy=True, normalize_embeddings=self._normalize_embeddings, @@ -224,10 +226,10 @@ def __init__( def __call__(self, texts: Documents) -> Embeddings: if self._instruction is None: - return self._model.encode(texts).tolist() + return self._model.encode(texts).tolist() # type: ignore texts_with_instructions = [[self._instruction, text] for text in texts] - return self._model.encode(texts_with_instructions).tolist() + return self._model.encode(texts_with_instructions).tolist() # type: ignore # In order to remove dependencies on sentence-transformers, which in turn depends on @@ -302,12 +304,12 @@ def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: # Use pytorches default epsilon for division by zero # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html - def _normalize(self, v: npt.NDArray) -> npt.NDArray: + def _normalize(self, v: npt.NDArray) -> npt.NDArray: # type: ignore norm = np.linalg.norm(v, axis=1) norm[norm == 0] = 1e-12 - return v / norm[:, np.newaxis] + return v / norm[:, np.newaxis] # type: ignore - def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray: + def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray: # type: ignore # We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values self.tokenizer = cast(self.Tokenizer, self.tokenizer) # type: ignore self.model = cast(self.ort.InferenceSession, self.model) # type: ignore @@ -475,3 +477,15 @@ def __call__(self, texts: Documents) -> Embeddings: embeddings.append(response["predictions"]["embeddings"]["values"]) return embeddings + + +# List of all classes in this module +_classes = [ + name + for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) + if obj.__module__ == __name__ +] + + +def get_builtins() -> List[str]: + return _classes
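
Reviewer note (outside the patch): a minimal sketch of the batching contract that TelemetryEvent and Posthog.capture() now implement, assuming a chromadb build that includes this change; the collection UUID and counter values below are made up for illustration. CollectionAddEvent is the only event in this patch with max_batch_size > 1, so it is the one that exercises batch_key and batch().

    from chromadb.telemetry.events import CollectionAddEvent

    first = CollectionAddEvent(
        collection_uuid="c0", add_amount=10, with_documents=10, with_metadata=0
    )
    second = CollectionAddEvent(
        collection_uuid="c0", add_amount=5, with_documents=0, with_metadata=5
    )

    # Events of the same type for the same collection share a batch key
    # (batch_key == collection_uuid + class name).
    assert first.batch_key == second.batch_key == "c0CollectionAddEvent"

    # batch() folds two events into one: the counters are summed and batch_size
    # records how many raw events the merged event represents.
    merged = first.batch(second)
    assert (merged.add_amount, merged.with_documents, merged.with_metadata) == (15, 10, 5)
    assert merged.batch_size == 2

    # Posthog.capture() sends the first occurrence of a batch key immediately
    # (tracked in seen_event_types), then accumulates later events in
    # batched_events and flushes the merged event once batch_size reaches
    # CollectionAddEvent.max_batch_size (20).

Design note: capturing the first occurrence of each batch key directly before batching kicks in means low-volume clients still report at least one event per collection, while high-volume add loops are coalesced roughly 20:1.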