From 668b95866e221ac58de5577cb056d280f880f10e Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Wed, 20 Sep 2023 11:01:44 +0300 Subject: [PATCH] feat: Use decorators for cleaner code --- chromadb/api/segment.py | 121 +++++++++++---------------------- chromadb/telemetry/__init__.py | 37 +++++++++- 2 files changed, 75 insertions(+), 83 deletions(-) diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index e62a37e1f04..faabd0c4944 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -28,15 +28,7 @@ validate_where_document, validate_batch, ) -from chromadb.telemetry.events import ( - CollectionAddEvent, - CollectionDeleteEvent, - CollectionGetEvent, - CollectionUpdateEvent, - CollectionQueryEvent, - ClientCreateCollectionEvent, -) - +from chromadb.telemetry import telemetry_class_decorator import chromadb.types as t from typing import Optional, Sequence, Generator, List, cast, Set, Dict @@ -71,6 +63,7 @@ def check_index_name(index_name: str) -> None: raise ValueError(msg) +@telemetry_class_decorator() class SegmentAPI(API): """API implementation utilizing the new segment-based internal architecture""" @@ -79,6 +72,7 @@ class SegmentAPI(API): _manager: SegmentManager _producer: Producer # TODO: fire telemetry events + # TODO there is probably a better way to handle this than polluting the code here _telemetry_client: Telemetry _tenant_id: str _topic_ns: str @@ -89,7 +83,7 @@ def __init__(self, system: System): self._settings = system.settings self._sysdb = self.require(SysDB) self._manager = self.require(SegmentManager) - self._telemetry_client = self.require(Telemetry) + self._telemetry_client = self.require(Telemetry) # TODO see above self._producer = self.require(Producer) self._tenant_id = system.settings.tenant_id self._topic_ns = system.settings.topic_namespace @@ -107,7 +101,8 @@ def create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction( + ), get_or_create: bool = False, ) -> Collection: existing = self._sysdb.get_collections(name=name) @@ -119,7 +114,8 @@ def create_collection( if get_or_create: if metadata and existing[0]["metadata"] != metadata: self._modify(id=existing[0]["id"], new_metadata=metadata) - existing = self._sysdb.get_collections(id=existing[0]["id"]) + existing = self._sysdb.get_collections( + id=existing[0]["id"]) return Collection( client=self, id=existing[0]["id"], @@ -143,13 +139,6 @@ def create_collection( for segment in segments: self._sysdb.create_segment(segment) - self._telemetry_client.capture( - ClientCreateCollectionEvent( - collection_uuid=str(id), - embedding_function=embedding_function.__class__.__name__, - ) - ) - return Collection( client=self, id=id, @@ -163,7 +152,8 @@ def get_or_create_collection( self, name: str, metadata: Optional[CollectionMetadata] = None, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction( + ), ) -> Collection: return self.create_collection( name=name, @@ -179,7 +169,8 @@ def get_or_create_collection( def get_collection( self, name: str, - embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(), + embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction( + ), ) -> Collection: existing = self._sysdb.get_collections(name=name) @@ -226,7 +217,8 @@ def _modify( # TODO eventually we'll want to use OptionalArgument and Unspecified in the # signature of `_modify` but not changing the API right now. if new_name and new_metadata: - self._sysdb.update_collection(id, name=new_name, metadata=new_metadata) + self._sysdb.update_collection( + id, name=new_name, metadata=new_metadata) elif new_name: self._sysdb.update_collection(id, name=new_name) elif new_metadata: @@ -273,14 +265,6 @@ def _add( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( - CollectionAddEvent( - collection_uuid=str(collection_id), - add_amount=len(ids), - with_metadata=len(ids) if metadatas is not None else 0, - with_documents=len(ids) if documents is not None else 0, - ) - ) return True @override @@ -310,16 +294,6 @@ def _update( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( - CollectionUpdateEvent( - collection_uuid=str(collection_id), - update_amount=len(ids), - with_embeddings=len(embeddings) if embeddings else 0, - with_metadata=len(metadatas) if metadatas else 0, - with_documents=len(documents) if documents else 0, - ) - ) - return True @override @@ -365,14 +339,16 @@ def _get( where_document: Optional[WhereDocument] = {}, include: Include = ["embeddings", "metadatas", "documents"], ) -> GetResult: - where = validate_where(where) if where is not None and len(where) > 0 else None + where = validate_where(where) if where is not None and len( + where) > 0 else None where_document = ( validate_where_document(where_document) if where_document is not None and len(where_document) > 0 else None ) - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) + metadata_segment = self._manager.get_segment( + collection_id, MetadataReader) if sort is not None: raise NotImplementedError("Sorting is not yet supported") @@ -392,7 +368,8 @@ def _get( vectors: Sequence[t.VectorEmbeddingRecord] = [] if "embeddings" in include: vector_ids = [r["id"] for r in records] - vector_segment = self._manager.get_segment(collection_id, VectorReader) + vector_segment = self._manager.get_segment( + collection_id, VectorReader) vectors = vector_segment.get_vectors(ids=vector_ids) # TODO: Fix type so we don't need to ignore @@ -404,22 +381,13 @@ def _get( if "documents" in include: documents = [_doc(m) for m in metadatas] - self._telemetry_client.capture( - CollectionGetEvent( - collection_uuid=str(collection_id), - ids_count=len(ids) if ids else 0, - limit=limit if limit else 0, - include_metadata="metadatas" in include, - include_documents="documents" in include, - ) - ) - return GetResult( ids=[r["id"] for r in records], embeddings=[r["embedding"] for r in vectors] if "embeddings" in include else None, - metadatas=_clean_metadatas(metadatas) if "metadatas" in include else None, # type: ignore + metadatas=_clean_metadatas( + metadatas) if "metadatas" in include else None, # type: ignore documents=documents if "documents" in include else None, # type: ignore ) @@ -431,7 +399,8 @@ def _delete( where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, ) -> IDs: - where = validate_where(where) if where is not None and len(where) > 0 else None + where = validate_where(where) if where is not None and len( + where) > 0 else None where_document = ( validate_where_document(where_document) if where_document is not None and len(where_document) > 0 @@ -460,7 +429,8 @@ def _delete( self._manager.hint_use_collection(collection_id, t.Operation.DELETE) if (where or where_document) or not ids: - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) + metadata_segment = self._manager.get_segment( + collection_id, MetadataReader) records = metadata_segment.get_metadata( where=where, where_document=where_document, ids=ids ) @@ -477,16 +447,12 @@ def _delete( records_to_submit.append(r) self._producer.submit_embeddings(coll["topic"], records_to_submit) - self._telemetry_client.capture( - CollectionDeleteEvent( - collection_uuid=str(collection_id), delete_amount=len(ids_to_delete) - ) - ) return ids_to_delete @override def _count(self, collection_id: UUID) -> int: - metadata_segment = self._manager.get_segment(collection_id, MetadataReader) + metadata_segment = self._manager.get_segment( + collection_id, MetadataReader) return metadata_segment.count() @override @@ -499,7 +465,8 @@ def _query( where_document: WhereDocument = {}, include: Include = ["documents", "metadatas", "distances"], ) -> QueryResult: - where = validate_where(where) if where is not None and len(where) > 0 else where + where = validate_where(where) if where is not None and len( + where) > 0 else where where_document = ( validate_where_document(where_document) if where_document is not None and len(where_document) > 0 @@ -512,7 +479,8 @@ def _query( for embedding in query_embeddings: self._validate_dimension(coll, len(embedding), update=False) - metadata_reader = self._manager.get_segment(collection_id, MetadataReader) + metadata_reader = self._manager.get_segment( + collection_id, MetadataReader) if where or where_document: records = metadata_reader.get_metadata( @@ -542,7 +510,8 @@ def _query( if "distances" in include: distances.append([r["distance"] for r in result]) if "embeddings" in include: - embeddings.append([cast(Embedding, r["embedding"]) for r in result]) + embeddings.append([cast(Embedding, r["embedding"]) + for r in result]) if "documents" in include or "metadatas" in include: all_ids: Set[str] = set() @@ -560,26 +529,15 @@ def _query( # queries the metadata segment. The metadata segment does not have # the record. In this case we choose to return potentially # incorrect data in the form of None. - metadata_list = [metadata_by_id.get(id, None) for id in id_list] + metadata_list = [metadata_by_id.get( + id, None) for id in id_list] if "metadatas" in include: - metadatas.append(_clean_metadatas(metadata_list)) # type: ignore + metadatas.append(_clean_metadatas( + metadata_list)) # type: ignore if "documents" in include: doc_list = [_doc(m) for m in metadata_list] documents.append(doc_list) # type: ignore - self._telemetry_client.capture( - CollectionQueryEvent( - collection_uuid=str(collection_id), - query_amount=len(query_embeddings), - n_results=n_results, - with_metadata_filter=where is not None, - with_document_filter=where_document is not None, - include_metadatas="metadatas" in include, - include_documents="documents" in include, - include_distances="distances" in include, - ) - ) - return QueryResult( ids=ids, distances=distances if distances else None, @@ -624,7 +582,8 @@ def _validate_embedding_record( ) -> None: """Validate the dimension of an embedding record before submitting it to the system.""" if record["embedding"]: - self._validate_dimension(collection, len(record["embedding"]), update=True) + self._validate_dimension(collection, len( + record["embedding"]), update=True) def _validate_dimension( self, collection: t.Collection, dim: int, update: bool diff --git a/chromadb/telemetry/__init__.py b/chromadb/telemetry/__init__.py index d20b8e5d71c..2c7f1001867 100644 --- a/chromadb/telemetry/__init__.py +++ b/chromadb/telemetry/__init__.py @@ -1,6 +1,14 @@ +from chromadb.telemetry.events import ( + CollectionAddEvent, + CollectionDeleteEvent, + CollectionGetEvent, + CollectionUpdateEvent, + CollectionQueryEvent, + ClientCreateCollectionEvent, +) from abc import abstractmethod import os -from typing import Callable, ClassVar, Dict, Any +from typing import Callable, ClassVar, Dict, Any, Tuple, Type import uuid import time from threading import Event, Thread @@ -84,7 +92,8 @@ def capture(self, event: TelemetryEvent) -> None: def schedule_event_function( self, event_function: Callable[..., TelemetryEvent], every_seconds: int ) -> None: - RepeatedTelemetry(every_seconds, lambda: self.capture(event_function())) + RepeatedTelemetry( + every_seconds, lambda: self.capture(event_function())) @property def context(self) -> Dict[str, Any]: @@ -120,3 +129,27 @@ def user_id(self) -> str: except Exception: self._curr_user_id = self.UNKNOWN_USER_ID return self._curr_user_id + + +METHOD_EVENT_MAP: Dict[str, Callable[..., Any]] = { + "create_collection": ClientCreateCollectionEvent, + "get_or_create_collection": ClientCreateCollectionEvent, + "get_collection": CollectionGetEvent, +} + + +def telemetry_class_decorator() -> Callable[[type], Type[Any]]: + def _decorator(cls: type) -> Type[Any]: + for target_method in METHOD_EVENT_MAP.keys(): + original_method = getattr(cls, target_method) + + def wrapped(self: Any, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) \ + -> Any: + + self.telemetry_client.capture( + METHOD_EVENT_MAP[target_method](*args, **kwargs)) + return original_method(self, *args, **kwargs) + + setattr(cls, target_method, wrapped) + return cls + return _decorator