Skip to content

Commit

Permalink
[CLN] Remove ChromaResponse in favor of endpoint-specific responses (#…
Browse files Browse the repository at this point in the history
…1767)

## Description of changes

*Summarize the changes made by this PR.*
 - Improvements & Bug fixes
- Remove `ChromaResponse` in favor of endpoint-specific response protos.
This conforms to best practices and will allow us to add response
information in the future.
 - New functionality
	 - ...

## Test plan
*How are these changes tested?*

- [ ] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
beggers authored Feb 24, 2024
1 parent 3742499 commit 3908b7b
Show file tree
Hide file tree
Showing 20 changed files with 1,561 additions and 1,066 deletions.
63 changes: 36 additions & 27 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
CreateCollectionRequest,
CreateCollectionResponse,
CreateDatabaseRequest,
CreateDatabaseResponse,
CreateSegmentRequest,
CreateSegmentResponse,
CreateTenantRequest,
CreateTenantResponse,
DeleteCollectionRequest,
DeleteCollectionResponse,
DeleteSegmentRequest,
DeleteSegmentResponse,
GetCollectionsRequest,
GetCollectionsResponse,
GetDatabaseRequest,
Expand All @@ -28,8 +34,11 @@
GetSegmentsResponse,
GetTenantRequest,
GetTenantResponse,
ResetStateResponse,
UpdateCollectionRequest,
UpdateCollectionResponse,
UpdateSegmentRequest,
UpdateSegmentResponse
)
from chromadb.proto.coordinator_pb2_grpc import (
SysDBServicer,
Expand Down Expand Up @@ -85,22 +94,22 @@ def reset_state(self) -> None:
@overrides(check_signature=False)
def CreateDatabase(
self, request: CreateDatabaseRequest, context: grpc.ServicerContext
) -> proto.ChromaResponse:
) -> CreateDatabaseResponse:
tenant = request.tenant
database = request.name
if tenant not in self._tenants_to_databases_to_collections:
return proto.ChromaResponse(
return CreateDatabaseResponse(
status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
)
if database in self._tenants_to_databases_to_collections[tenant]:
return proto.ChromaResponse(
return CreateDatabaseResponse(
status=proto.Status(
code=409, reason=f"Database {database} already exists"
)
)
self._tenants_to_databases_to_collections[tenant][database] = {}
self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id)
return proto.ChromaResponse(status=proto.Status(code=200))
return CreateDatabaseResponse(status=proto.Status(code=200))

@overrides(check_signature=False)
def GetDatabase(
Expand All @@ -124,16 +133,16 @@ def GetDatabase(

@overrides(check_signature=False)
def CreateTenant(
self, request: CreateDatabaseRequest, context: grpc.ServicerContext
) -> proto.ChromaResponse:
self, request: CreateTenantRequest, context: grpc.ServicerContext
) -> CreateTenantResponse:
tenant = request.name
if tenant in self._tenants_to_databases_to_collections:
return proto.ChromaResponse(
return CreateTenantResponse(
status=proto.Status(code=409, reason=f"Tenant {tenant} already exists")
)
self._tenants_to_databases_to_collections[tenant] = {}
self._tenants_to_database_to_id[tenant] = {}
return proto.ChromaResponse(status=proto.Status(code=200))
return CreateTenantResponse(status=proto.Status(code=200))

@overrides(check_signature=False)
def GetTenant(
Expand All @@ -155,29 +164,29 @@ def GetTenant(
@overrides(check_signature=False)
def CreateSegment(
self, request: CreateSegmentRequest, context: grpc.ServicerContext
) -> proto.ChromaResponse:
) -> CreateSegmentResponse:
segment = from_proto_segment(request.segment)
if segment["id"].hex in self._segments:
return proto.ChromaResponse(
return CreateSegmentResponse(
status=proto.Status(
code=409, reason=f"Segment {segment['id']} already exists"
)
)
self._segments[segment["id"].hex] = segment
return proto.ChromaResponse(
return CreateSegmentResponse(
status=proto.Status(code=200)
) # TODO: how are these codes used? Need to determine the standards for the code and reason.

@overrides(check_signature=False)
def DeleteSegment(
self, request: DeleteSegmentRequest, context: grpc.ServicerContext
) -> proto.ChromaResponse:
) -> DeleteSegmentResponse:
id_to_delete = request.id
if id_to_delete in self._segments:
del self._segments[id_to_delete]
return proto.ChromaResponse(status=proto.Status(code=200))
return DeleteSegmentResponse(status=proto.Status(code=200))
else:
return proto.ChromaResponse(
return DeleteSegmentResponse(
status=proto.Status(
code=404, reason=f"Segment {id_to_delete} not found"
)
Expand Down Expand Up @@ -219,10 +228,10 @@ def GetSegments(
@overrides(check_signature=False)
def UpdateSegment(
self, request: UpdateSegmentRequest, context: grpc.ServicerContext
) -> proto.ChromaResponse:
) -> UpdateSegmentResponse:
id_to_update = UUID(request.id)
if id_to_update.hex not in self._segments:
return proto.ChromaResponse(
return UpdateSegmentResponse(
status=proto.Status(
code=404, reason=f"Segment {id_to_update} not found"
)
Expand All @@ -244,7 +253,7 @@ def UpdateSegment(
self._merge_metadata(target, request.metadata)
if request.HasField("reset_metadata") and request.reset_metadata:
segment["metadata"] = {}
return proto.ChromaResponse(status=proto.Status(code=200))
return UpdateSegmentResponse(status=proto.Status(code=200))

@overrides(check_signature=False)
def CreateCollection(
Expand Down Expand Up @@ -331,24 +340,24 @@ def CreateCollection(
@overrides(check_signature=False)
def DeleteCollection(
self, request: DeleteCollectionRequest, context: grpc.ServicerContext
) -> proto.ChromaResponse:
) -> DeleteCollectionResponse:
collection_id = request.id
tenant = request.tenant
database = request.database
if tenant not in self._tenants_to_databases_to_collections:
return proto.ChromaResponse(
return DeleteCollectionResponse(
status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
)
if database not in self._tenants_to_databases_to_collections[tenant]:
return proto.ChromaResponse(
return DeleteCollectionResponse(
status=proto.Status(code=404, reason=f"Database {database} not found")
)
collections = self._tenants_to_databases_to_collections[tenant][database]
if collection_id in collections:
del collections[collection_id]
return proto.ChromaResponse(status=proto.Status(code=200))
return DeleteCollectionResponse(status=proto.Status(code=200))
else:
return proto.ChromaResponse(
return DeleteCollectionResponse(
status=proto.Status(
code=404, reason=f"Collection {collection_id} not found"
)
Expand Down Expand Up @@ -392,7 +401,7 @@ def GetCollections(
@overrides(check_signature=False)
def UpdateCollection(
self, request: UpdateCollectionRequest, context: grpc.ServicerContext
) -> proto.ChromaResponse:
) -> UpdateCollectionResponse:
id_to_update = UUID(request.id)
# Find the collection with this id
collections = {}
Expand All @@ -402,7 +411,7 @@ def UpdateCollection(
collections = maybe_collections

if id_to_update.hex not in collections:
return proto.ChromaResponse(
return UpdateCollectionResponse(
status=proto.Status(
code=404, reason=f"Collection {id_to_update} not found"
)
Expand Down Expand Up @@ -433,14 +442,14 @@ def UpdateCollection(
if request.reset_metadata:
collection["metadata"] = {}

return proto.ChromaResponse(status=proto.Status(code=200))
return UpdateCollectionResponse(status=proto.Status(code=200))

@overrides(check_signature=False)
def ResetState(
self, request: Empty, context: grpc.ServicerContext
) -> proto.ChromaResponse:
) -> ResetStateResponse:
self.reset_state()
return proto.ChromaResponse(status=proto.Status(code=200))
return ResetStateResponse(status=proto.Status(code=200))

def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
target_metadata = cast(Dict[str, Any], target)
Expand Down
Loading

0 comments on commit 3908b7b

Please sign in to comment.