Skip to content

Commit

Permalink
Merge branch 'main' into feature/aws-terraform-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov authored Sep 23, 2023
2 parents ffddd63 + c7a0414 commit 4cb6cc4
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 26 deletions.
4 changes: 2 additions & 2 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import logging
import sqlite3
import chromadb.config
from chromadb.telemetry.events import ClientStartEvent
from chromadb.telemetry import Telemetry
from chromadb.config import Settings, System
from chromadb.api import API
from chromadb.api.models.Collection import Collection
Expand Down Expand Up @@ -38,6 +36,8 @@
"QueryResult",
"GetResult",
]
from chromadb.telemetry.events import ClientStartEvent
from chromadb.telemetry import Telemetry


logger = logging.getLogger(__name__)
Expand Down
62 changes: 59 additions & 3 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
validate_where_document,
validate_batch,
)
from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent
from chromadb.telemetry.events import (
CollectionAddEvent,
CollectionDeleteEvent,
CollectionGetEvent,
CollectionUpdateEvent,
CollectionQueryEvent,
ClientCreateCollectionEvent,
)

import chromadb.types as t

Expand Down Expand Up @@ -140,6 +147,13 @@ def create_collection(
for segment in segments:
self._sysdb.create_segment(segment)

self._telemetry_client.capture(
ClientCreateCollectionEvent(
collection_uuid=str(id),
embedding_function=embedding_function.__class__.__name__,
)
)

return Collection(
client=self,
id=id,
Expand Down Expand Up @@ -263,7 +277,14 @@ def _add(
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)

self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids)))
self._telemetry_client.capture(
CollectionAddEvent(
collection_uuid=str(collection_id),
add_amount=len(ids),
with_metadata=len(ids) if metadatas is not None else 0,
with_documents=len(ids) if documents is not None else 0,
)
)
return True

@override
Expand Down Expand Up @@ -293,6 +314,16 @@ def _update(
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)

self._telemetry_client.capture(
CollectionUpdateEvent(
collection_uuid=str(collection_id),
update_amount=len(ids),
with_embeddings=len(embeddings) if embeddings else 0,
with_metadata=len(metadatas) if metadatas else 0,
with_documents=len(documents) if documents else 0,
)
)

return True

@override
Expand Down Expand Up @@ -377,6 +408,16 @@ def _get(
if "documents" in include:
documents = [_doc(m) for m in metadatas]

self._telemetry_client.capture(
CollectionGetEvent(
collection_uuid=str(collection_id),
ids_count=len(ids) if ids else 0,
limit=limit if limit else 0,
include_metadata="metadatas" in include,
include_documents="documents" in include,
)
)

return GetResult(
ids=[r["id"] for r in records],
embeddings=[r["embedding"] for r in vectors]
Expand Down Expand Up @@ -441,7 +482,9 @@ def _delete(
self._producer.submit_embeddings(coll["topic"], records_to_submit)

self._telemetry_client.capture(
CollectionDeleteEvent(str(collection_id), len(ids_to_delete))
CollectionDeleteEvent(
collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
)
)
return ids_to_delete

Expand Down Expand Up @@ -528,6 +571,19 @@ def _query(
doc_list = [_doc(m) for m in metadata_list]
documents.append(doc_list) # type: ignore

self._telemetry_client.capture(
CollectionQueryEvent(
collection_uuid=str(collection_id),
query_amount=len(query_embeddings),
n_results=n_results,
with_metadata_filter=where is not None,
with_document_filter=where_document is not None,
include_metadatas="metadatas" in include,
include_documents="documents" in include,
include_distances="distances" in include,
)
)

return QueryResult(
ids=ids,
distances=distances if distances else None,
Expand Down
24 changes: 20 additions & 4 deletions chromadb/telemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import abstractmethod
from dataclasses import asdict, dataclass
import os
from typing import Callable, ClassVar, Dict, Any
import uuid
Expand All @@ -22,13 +21,30 @@ class ServerContext(Enum):
FASTAPI = "FastAPI"


@dataclass
class TelemetryEvent:
name: ClassVar[str]
max_batch_size: ClassVar[int] = 1
batch_size: int

def __init__(self, batch_size: int = 1):
self.batch_size = batch_size

@property
def properties(self) -> Dict[str, Any]:
return asdict(self)
return self.__dict__

@property
def name(self) -> str:
return self.__class__.__name__

# A batch key is used to determine whether two events can be batched together.
# If a TelemetryEvent's max_batch_size > 1, batch_key() and batch() MUST be implemented.
# Otherwise they are ignored.
@property
def batch_key(self) -> str:
return self.name

def batch(self, other: "TelemetryEvent") -> "TelemetryEvent":
raise NotImplementedError


class RepeatedTelemetry:
Expand Down
148 changes: 137 additions & 11 deletions chromadb/telemetry/events.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,153 @@
from dataclasses import dataclass
from typing import ClassVar
from typing import cast, ClassVar
from chromadb.telemetry import TelemetryEvent
from chromadb.utils.embedding_functions import get_builtins


@dataclass
class ClientStartEvent(TelemetryEvent):
name: ClassVar[str] = "client_start"
def __init__(self) -> None:
super().__init__()


@dataclass
class ServerStartEvent(TelemetryEvent):
name: ClassVar[str] = "server_start"
class ClientCreateCollectionEvent(TelemetryEvent):
collection_uuid: str
embedding_function: str

def __init__(self, collection_uuid: str, embedding_function: str):
super().__init__()
self.collection_uuid = collection_uuid

embedding_function_names = get_builtins()

self.embedding_function = (
embedding_function
if embedding_function in embedding_function_names
else "custom"
)


@dataclass
class CollectionAddEvent(TelemetryEvent):
name: ClassVar[str] = "collection_add"
max_batch_size: ClassVar[int] = 20
collection_uuid: str
add_amount: int
with_documents: int
with_metadata: int

def __init__(
self,
collection_uuid: str,
add_amount: int,
with_documents: int,
with_metadata: int,
batch_size: int = 1,
):
super().__init__()
self.collection_uuid = collection_uuid
self.add_amount = add_amount
self.with_documents = with_documents
self.with_metadata = with_metadata
self.batch_size = batch_size

@property
def batch_key(self) -> str:
return self.collection_uuid + self.name

def batch(self, other: "TelemetryEvent") -> "CollectionAddEvent":
if not self.batch_key == other.batch_key:
raise ValueError("Cannot batch events")
other = cast(CollectionAddEvent, other)
total_amount = self.add_amount + other.add_amount
return CollectionAddEvent(
collection_uuid=self.collection_uuid,
add_amount=total_amount,
with_documents=self.with_documents + other.with_documents,
with_metadata=self.with_metadata + other.with_metadata,
batch_size=self.batch_size + other.batch_size,
)


class CollectionUpdateEvent(TelemetryEvent):
collection_uuid: str
update_amount: int
with_embeddings: int
with_metadata: int
with_documents: int

def __init__(
self,
collection_uuid: str,
update_amount: int,
with_embeddings: int,
with_metadata: int,
with_documents: int,
):
super().__init__()
self.collection_uuid = collection_uuid
self.update_amount = update_amount
self.with_embeddings = with_embeddings
self.with_metadata = with_metadata
self.with_documents = with_documents


class CollectionQueryEvent(TelemetryEvent):
collection_uuid: str
query_amount: int
with_metadata_filter: bool
with_document_filter: bool
n_results: int
include_metadatas: bool
include_documents: bool
include_distances: bool

def __init__(
self,
collection_uuid: str,
query_amount: int,
with_metadata_filter: bool,
with_document_filter: bool,
n_results: int,
include_metadatas: bool,
include_documents: bool,
include_distances: bool,
):
super().__init__()
self.collection_uuid = collection_uuid
self.query_amount = query_amount
self.with_metadata_filter = with_metadata_filter
self.with_document_filter = with_document_filter
self.n_results = n_results
self.include_metadatas = include_metadatas
self.include_documents = include_documents
self.include_distances = include_distances


class CollectionGetEvent(TelemetryEvent):
collection_uuid: str
ids_count: int
limit: int
include_metadata: bool
include_documents: bool

def __init__(
self,
collection_uuid: str,
ids_count: int,
limit: int,
include_metadata: bool,
include_documents: bool,
):
super().__init__()
self.collection_uuid = collection_uuid
self.ids_count = ids_count
self.limit = limit
self.include_metadata = include_metadata
self.include_documents = include_documents


@dataclass
class CollectionDeleteEvent(TelemetryEvent):
name: ClassVar[str] = "collection_delete"
collection_uuid: str
delete_amount: int

def __init__(self, collection_uuid: str, delete_amount: int):
super().__init__()
self.collection_uuid = collection_uuid
self.delete_amount = delete_amount
20 changes: 20 additions & 0 deletions chromadb/telemetry/posthog.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import posthog
import logging
import sys
from typing import Any, Dict, Set
from chromadb.config import System
from chromadb.telemetry import Telemetry, TelemetryEvent
from overrides import override
Expand All @@ -21,10 +22,29 @@ def __init__(self, system: System):
posthog_logger = logging.getLogger("posthog")
# Silence posthog's logging
posthog_logger.disabled = True

self.batched_events: Dict[str, TelemetryEvent] = {}
self.seen_event_types: Set[Any] = set()

super().__init__(system)

@override
def capture(self, event: TelemetryEvent) -> None:
if event.max_batch_size == 1 or event.batch_key not in self.seen_event_types:
self.seen_event_types.add(event.batch_key)
self._direct_capture(event)
return
batch_key = event.batch_key
if batch_key not in self.batched_events:
self.batched_events[batch_key] = event
return
batched_event = self.batched_events[batch_key].batch(event)
self.batched_events[batch_key] = batched_event
if batched_event.batch_size >= batched_event.max_batch_size:
self._direct_capture(batched_event)
del self.batched_events[batch_key]

def _direct_capture(self, event: TelemetryEvent) -> None:
try:
posthog.capture(
self.user_id,
Expand Down
Loading

0 comments on commit 4cb6cc4

Please sign in to comment.