Skip to content

Commit

Permalink
feat: Use decorators for cleaner code
Browse files Browse the repository at this point in the history
  • Loading branch information
tazarov committed Sep 20, 2023
1 parent 80c6230 commit 668b958
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 83 deletions.
121 changes: 40 additions & 81 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,7 @@
validate_where_document,
validate_batch,
)
from chromadb.telemetry.events import (
CollectionAddEvent,
CollectionDeleteEvent,
CollectionGetEvent,
CollectionUpdateEvent,
CollectionQueryEvent,
ClientCreateCollectionEvent,
)

from chromadb.telemetry import telemetry_class_decorator
import chromadb.types as t

from typing import Optional, Sequence, Generator, List, cast, Set, Dict
Expand Down Expand Up @@ -71,6 +63,7 @@ def check_index_name(index_name: str) -> None:
raise ValueError(msg)


@telemetry_class_decorator()
class SegmentAPI(API):
"""API implementation utilizing the new segment-based internal architecture"""

Expand All @@ -79,6 +72,7 @@ class SegmentAPI(API):
_manager: SegmentManager
_producer: Producer
# TODO: fire telemetry events
# TODO there is probably a better way to handle this than polluting the code here
_telemetry_client: Telemetry
_tenant_id: str
_topic_ns: str
Expand All @@ -89,7 +83,7 @@ def __init__(self, system: System):
self._settings = system.settings
self._sysdb = self.require(SysDB)
self._manager = self.require(SegmentManager)
self._telemetry_client = self.require(Telemetry)
self._telemetry_client = self.require(Telemetry) # TODO see above
self._producer = self.require(Producer)
self._tenant_id = system.settings.tenant_id
self._topic_ns = system.settings.topic_namespace
Expand All @@ -107,7 +101,8 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
get_or_create: bool = False,
) -> Collection:
existing = self._sysdb.get_collections(name=name)
Expand All @@ -119,7 +114,8 @@ def create_collection(
if get_or_create:
if metadata and existing[0]["metadata"] != metadata:
self._modify(id=existing[0]["id"], new_metadata=metadata)
existing = self._sysdb.get_collections(id=existing[0]["id"])
existing = self._sysdb.get_collections(
id=existing[0]["id"])
return Collection(
client=self,
id=existing[0]["id"],
Expand All @@ -143,13 +139,6 @@ def create_collection(
for segment in segments:
self._sysdb.create_segment(segment)

self._telemetry_client.capture(
ClientCreateCollectionEvent(
collection_uuid=str(id),
embedding_function=embedding_function.__class__.__name__,
)
)

return Collection(
client=self,
id=id,
Expand All @@ -163,7 +152,8 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
) -> Collection:
return self.create_collection(
name=name,
Expand All @@ -179,7 +169,8 @@ def get_or_create_collection(
def get_collection(
self,
name: str,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
) -> Collection:
existing = self._sysdb.get_collections(name=name)

Expand Down Expand Up @@ -226,7 +217,8 @@ def _modify(
# TODO eventually we'll want to use OptionalArgument and Unspecified in the
# signature of `_modify` but not changing the API right now.
if new_name and new_metadata:
self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
self._sysdb.update_collection(
id, name=new_name, metadata=new_metadata)
elif new_name:
self._sysdb.update_collection(id, name=new_name)
elif new_metadata:
Expand Down Expand Up @@ -273,14 +265,6 @@ def _add(
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)

self._telemetry_client.capture(
CollectionAddEvent(
collection_uuid=str(collection_id),
add_amount=len(ids),
with_metadata=len(ids) if metadatas is not None else 0,
with_documents=len(ids) if documents is not None else 0,
)
)
return True

@override
Expand Down Expand Up @@ -310,16 +294,6 @@ def _update(
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)

self._telemetry_client.capture(
CollectionUpdateEvent(
collection_uuid=str(collection_id),
update_amount=len(ids),
with_embeddings=len(embeddings) if embeddings else 0,
with_metadata=len(metadatas) if metadatas else 0,
with_documents=len(documents) if documents else 0,
)
)

return True

@override
Expand Down Expand Up @@ -365,14 +339,16 @@ def _get(
where_document: Optional[WhereDocument] = {},
include: Include = ["embeddings", "metadatas", "documents"],
) -> GetResult:
where = validate_where(where) if where is not None and len(where) > 0 else None
where = validate_where(where) if where is not None and len(
where) > 0 else None
where_document = (
validate_where_document(where_document)
if where_document is not None and len(where_document) > 0
else None
)

metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
metadata_segment = self._manager.get_segment(
collection_id, MetadataReader)

if sort is not None:
raise NotImplementedError("Sorting is not yet supported")
Expand All @@ -392,7 +368,8 @@ def _get(
vectors: Sequence[t.VectorEmbeddingRecord] = []
if "embeddings" in include:
vector_ids = [r["id"] for r in records]
vector_segment = self._manager.get_segment(collection_id, VectorReader)
vector_segment = self._manager.get_segment(
collection_id, VectorReader)
vectors = vector_segment.get_vectors(ids=vector_ids)

# TODO: Fix type so we don't need to ignore
Expand All @@ -404,22 +381,13 @@ def _get(
if "documents" in include:
documents = [_doc(m) for m in metadatas]

self._telemetry_client.capture(
CollectionGetEvent(
collection_uuid=str(collection_id),
ids_count=len(ids) if ids else 0,
limit=limit if limit else 0,
include_metadata="metadatas" in include,
include_documents="documents" in include,
)
)

return GetResult(
ids=[r["id"] for r in records],
embeddings=[r["embedding"] for r in vectors]
if "embeddings" in include
else None,
metadatas=_clean_metadatas(metadatas) if "metadatas" in include else None, # type: ignore
metadatas=_clean_metadatas(
metadatas) if "metadatas" in include else None, # type: ignore
documents=documents if "documents" in include else None, # type: ignore
)

Expand All @@ -431,7 +399,8 @@ def _delete(
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
) -> IDs:
where = validate_where(where) if where is not None and len(where) > 0 else None
where = validate_where(where) if where is not None and len(
where) > 0 else None
where_document = (
validate_where_document(where_document)
if where_document is not None and len(where_document) > 0
Expand Down Expand Up @@ -460,7 +429,8 @@ def _delete(
self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

if (where or where_document) or not ids:
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
metadata_segment = self._manager.get_segment(
collection_id, MetadataReader)
records = metadata_segment.get_metadata(
where=where, where_document=where_document, ids=ids
)
Expand All @@ -477,16 +447,12 @@ def _delete(
records_to_submit.append(r)
self._producer.submit_embeddings(coll["topic"], records_to_submit)

self._telemetry_client.capture(
CollectionDeleteEvent(
collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
)
)
return ids_to_delete

@override
def _count(self, collection_id: UUID) -> int:
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
metadata_segment = self._manager.get_segment(
collection_id, MetadataReader)
return metadata_segment.count()

@override
Expand All @@ -499,7 +465,8 @@ def _query(
where_document: WhereDocument = {},
include: Include = ["documents", "metadatas", "distances"],
) -> QueryResult:
where = validate_where(where) if where is not None and len(where) > 0 else where
where = validate_where(where) if where is not None and len(
where) > 0 else where
where_document = (
validate_where_document(where_document)
if where_document is not None and len(where_document) > 0
Expand All @@ -512,7 +479,8 @@ def _query(
for embedding in query_embeddings:
self._validate_dimension(coll, len(embedding), update=False)

metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
metadata_reader = self._manager.get_segment(
collection_id, MetadataReader)

if where or where_document:
records = metadata_reader.get_metadata(
Expand Down Expand Up @@ -542,7 +510,8 @@ def _query(
if "distances" in include:
distances.append([r["distance"] for r in result])
if "embeddings" in include:
embeddings.append([cast(Embedding, r["embedding"]) for r in result])
embeddings.append([cast(Embedding, r["embedding"])
for r in result])

if "documents" in include or "metadatas" in include:
all_ids: Set[str] = set()
Expand All @@ -560,26 +529,15 @@ def _query(
# queries the metadata segment. The metadata segment does not have
# the record. In this case we choose to return potentially
# incorrect data in the form of None.
metadata_list = [metadata_by_id.get(id, None) for id in id_list]
metadata_list = [metadata_by_id.get(
id, None) for id in id_list]
if "metadatas" in include:
metadatas.append(_clean_metadatas(metadata_list)) # type: ignore
metadatas.append(_clean_metadatas(
metadata_list)) # type: ignore
if "documents" in include:
doc_list = [_doc(m) for m in metadata_list]
documents.append(doc_list) # type: ignore

self._telemetry_client.capture(
CollectionQueryEvent(
collection_uuid=str(collection_id),
query_amount=len(query_embeddings),
n_results=n_results,
with_metadata_filter=where is not None,
with_document_filter=where_document is not None,
include_metadatas="metadatas" in include,
include_documents="documents" in include,
include_distances="distances" in include,
)
)

return QueryResult(
ids=ids,
distances=distances if distances else None,
Expand Down Expand Up @@ -624,7 +582,8 @@ def _validate_embedding_record(
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
if record["embedding"]:
self._validate_dimension(collection, len(record["embedding"]), update=True)
self._validate_dimension(collection, len(
record["embedding"]), update=True)

def _validate_dimension(
self, collection: t.Collection, dim: int, update: bool
Expand Down
37 changes: 35 additions & 2 deletions chromadb/telemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from chromadb.telemetry.events import (
CollectionAddEvent,
CollectionDeleteEvent,
CollectionGetEvent,
CollectionUpdateEvent,
CollectionQueryEvent,
ClientCreateCollectionEvent,
)
from abc import abstractmethod
import os
from typing import Callable, ClassVar, Dict, Any
from typing import Callable, ClassVar, Dict, Any, Tuple, Type
import uuid
import time
from threading import Event, Thread
Expand Down Expand Up @@ -84,7 +92,8 @@ def capture(self, event: TelemetryEvent) -> None:
def schedule_event_function(
self, event_function: Callable[..., TelemetryEvent], every_seconds: int
) -> None:
RepeatedTelemetry(every_seconds, lambda: self.capture(event_function()))
RepeatedTelemetry(
every_seconds, lambda: self.capture(event_function()))

@property
def context(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -120,3 +129,27 @@ def user_id(self) -> str:
except Exception:
self._curr_user_id = self.UNKNOWN_USER_ID
return self._curr_user_id


METHOD_EVENT_MAP: Dict[str, Callable[..., Any]] = {
"create_collection": ClientCreateCollectionEvent,
"get_or_create_collection": ClientCreateCollectionEvent,
"get_collection": CollectionGetEvent,
}


def telemetry_class_decorator() -> Callable[[type], Type[Any]]:
def _decorator(cls: type) -> Type[Any]:
for target_method in METHOD_EVENT_MAP.keys():
original_method = getattr(cls, target_method)

def wrapped(self: Any, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) \
-> Any:

self.telemetry_client.capture(
METHOD_EVENT_MAP[target_method](*args, **kwargs))
return original_method(self, *args, **kwargs)

setattr(cls, target_method, wrapped)
return cls
return _decorator

0 comments on commit 668b958

Please sign in to comment.