Skip to content

Commit

Permalink
fix: Formatting issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
tazarov committed Oct 22, 2023
1 parent 63fd6df commit 4f9b59f
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 687 deletions.
16 changes: 6 additions & 10 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
get_or_create: bool = False,
) -> Collection:
"""Create a new collection with the given name and metadata.
Expand Down Expand Up @@ -88,8 +87,7 @@ def create_collection(
def get_collection(
self,
name: str,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
"""Get a collection with the given name.
Args:
Expand All @@ -116,8 +114,7 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
"""Get or create a collection with the given name and metadata.
Args:
Expand Down Expand Up @@ -264,11 +261,11 @@ def _dimensions(self, collection_id: UUID) -> int:
specified by UUID.
Args:
collection_id: The UUID of the collection to get the dimensions of the
collection_id: The UUID of the collection to get the dimensions of the
embeddings in.
Returns:
int: The number of dimensions of the embeddings in the collection. If the
int: The number of dimensions of the embeddings in the collection. If the
collection has no embeddings, returns -1.
"""
Expand Down Expand Up @@ -351,8 +348,7 @@ def _query(
n_results: int = 10,
where: Where = {},
where_document: WhereDocument = {},
include: Include = ["embeddings",
"metadatas", "documents", "distances"],
include: Include = ["embeddings", "metadatas", "documents", "distances"],
) -> QueryResult:
"""[Internal] Performs a nearest neighbors query on a collection specified by UUID.
Expand Down
33 changes: 11 additions & 22 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def resolve_url(
scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
net_loc = parsed.netloc or parsed.hostname or chroma_server_host
port = (
":" +
str(parsed.port or chroma_server_http_port) if not _skip_port else ""
":" + str(parsed.port or chroma_server_http_port) if not _skip_port else ""
)
path = parsed.path or default_api_path

Expand All @@ -89,8 +88,7 @@ def resolve_url(
if not path.endswith(default_api_path or ""):
path = path + default_api_path if default_api_path else ""
full_url = urlunparse(
(scheme, f"{net_loc}{port}", quote(
path.replace("//", "/")), "", "", "")
(scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
)

return full_url
Expand All @@ -106,8 +104,7 @@ def __init__(self, system: System):

self._api_url = FastAPI.resolve_url(
chroma_server_host=str(system.settings.chroma_server_host),
chroma_server_http_port=int(
str(system.settings.chroma_server_http_port)),
chroma_server_http_port=int(str(system.settings.chroma_server_http_port)),
chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
default_api_path=system.settings.chroma_server_api_default_path,
)
Expand Down Expand Up @@ -164,16 +161,14 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
get_or_create: bool = False,
) -> Collection:
"""Creates a collection"""
resp = self._session.post(
self._api_url + "/collections",
data=json.dumps(
{"name": name, "metadata": metadata,
"get_or_create": get_or_create}
{"name": name, "metadata": metadata, "get_or_create": get_or_create}
),
)
raise_chroma_error(resp)
Expand All @@ -191,8 +186,7 @@ def create_collection(
def get_collection(
self,
name: str,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
"""Returns a collection"""
resp = self._session.get(self._api_url + "/collections/" + name)
Expand All @@ -214,8 +208,7 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
return self.create_collection(
name, metadata, embedding_function, get_or_create=True
Expand All @@ -232,8 +225,7 @@ def _modify(
"""Updates a collection"""
resp = self._session.put(
self._api_url + "/collections/" + str(id),
data=json.dumps(
{"new_metadata": new_metadata, "new_name": new_name}),
data=json.dumps({"new_metadata": new_metadata, "new_name": new_name}),
)
raise_chroma_error(resp)

Expand All @@ -259,8 +251,7 @@ def _count(self, collection_id: UUID) -> int:
def _dimensions(self, collection_id: UUID) -> int:
"""Returns the dimensionality of the embeddings in the collection"""
resp = self._session.get(
self._api_url + "/collections/" +
str(collection_id) + "/dimensions"
self._api_url + "/collections/" + str(collection_id) + "/dimensions"
)
raise_chroma_error(resp)
return cast(int, resp.json())
Expand Down Expand Up @@ -376,8 +367,7 @@ def _add(
"""
batch = (ids, embeddings, metadatas, documents)
validate_batch(batch, {"max_batch_size": self.max_batch_size})
resp = self._submit_batch(
batch, "/collections/" + str(collection_id) + "/add")
resp = self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
raise_chroma_error(resp)
return True

Expand Down Expand Up @@ -503,8 +493,7 @@ def raise_chroma_error(resp: requests.Response) -> None:
body = resp.json()
if "error" in body:
if body["error"] in errors.error_types:
chroma_error = errors.error_types[body["error"]](
body["message"])
chroma_error = errors.error_types[body["error"]](body["message"])

except BaseException:
pass
Expand Down
12 changes: 4 additions & 8 deletions chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def __init__(
client: "API",
name: str,
id: UUID,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
metadata: Optional[CollectionMetadata] = None,
):
super().__init__(name=name, metadata=metadata, id=id)
Expand All @@ -66,8 +65,7 @@ def count(self) -> int:
int: The total number of embeddings added to the database
"""
return self._client._count(collection_id=self.id)\

return self._client._count(collection_id=self.id)

def dimensions(self) -> int:
"""The number of dimensions of the embeddings
Expand Down Expand Up @@ -199,8 +197,7 @@ def query(
else None
)
query_texts = (
maybe_cast_one_to_many(
query_texts) if query_texts is not None else None
maybe_cast_one_to_many(query_texts) if query_texts is not None else None
)
include = validate_include(include, allow_distances=True)
n_results = validate_n_results(n_results)
Expand Down Expand Up @@ -367,8 +364,7 @@ def _validate_embedding_set(
if metadatas is not None
else None
)
documents = maybe_cast_one_to_many(
documents) if documents is not None else None
documents = maybe_cast_one_to_many(documents) if documents is not None else None

# Check that one of embeddings or documents is provided
if require_embeddings_or_documents:
Expand Down
45 changes: 17 additions & 28 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,12 @@ def create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
get_or_create: bool = False,
) -> Collection:
if metadata is not None:
validate_metadata(metadata)

# TODO: remove backwards compatibility in naming requirements
check_index_name(name)

Expand Down Expand Up @@ -163,8 +162,7 @@ def get_or_create_collection(
self,
name: str,
metadata: Optional[CollectionMetadata] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
return self.create_collection( # type: ignore
name=name,
Expand All @@ -181,8 +179,7 @@ def get_or_create_collection(
def get_collection(
self,
name: str,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(
),
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
existing = self._sysdb.get_collections(name=name)

Expand Down Expand Up @@ -231,8 +228,7 @@ def _modify(
# TODO eventually we'll want to use OptionalArgument and Unspecified in the
# signature of `_modify` but not changing the API right now.
if new_name and new_metadata:
self._sysdb.update_collection(
id, name=new_name, metadata=new_metadata)
self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
elif new_name:
self._sysdb.update_collection(id, name=new_name)
elif new_metadata:
Expand Down Expand Up @@ -389,8 +385,7 @@ def _get(
else None
)

metadata_segment = self._manager.get_segment(
collection_id, MetadataReader)
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)

if sort is not None:
raise NotImplementedError("Sorting is not yet supported")
Expand All @@ -410,8 +405,7 @@ def _get(
vectors: Sequence[t.VectorEmbeddingRecord] = []
if "embeddings" in include:
vector_ids = [r["id"] for r in records]
vector_segment = self._manager.get_segment(
collection_id, VectorReader)
vector_segment = self._manager.get_segment(collection_id, VectorReader)
vectors = vector_segment.get_vectors(ids=vector_ids)

# TODO: Fix type so we don't need to ignore
Expand Down Expand Up @@ -439,8 +433,9 @@ def _get(
embeddings=[r["embedding"] for r in vectors]
if "embeddings" in include
else None,
metadatas=_clean_metadatas(
metadatas) if "metadatas" in include else None, # type: ignore
metadatas=_clean_metadatas(metadatas)
if "metadatas" in include
else None, # type: ignore
documents=documents if "documents" in include else None, # type: ignore
)

Expand Down Expand Up @@ -489,8 +484,7 @@ def _delete(
self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

if (where or where_document) or not ids:
metadata_segment = self._manager.get_segment(
collection_id, MetadataReader)
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
records = metadata_segment.get_metadata(
where=where, where_document=where_document, ids=ids
)
Expand Down Expand Up @@ -525,7 +519,7 @@ def _count(self, collection_id: UUID) -> int:
@override
def _dimensions(self, collection_id: UUID) -> int:
coll = self._get_collection(collection_id)
return coll["dimension"] if coll["dimension"] is not None else -1
return cast(int, coll["dimension"]) if coll["dimension"] is not None else -1

@override
def _query(
Expand Down Expand Up @@ -557,8 +551,7 @@ def _query(
for embedding in query_embeddings:
self._validate_dimension(coll, len(embedding), update=False)

metadata_reader = self._manager.get_segment(
collection_id, MetadataReader)
metadata_reader = self._manager.get_segment(collection_id, MetadataReader)

if where or where_document:
records = metadata_reader.get_metadata(
Expand Down Expand Up @@ -588,8 +581,7 @@ def _query(
if "distances" in include:
distances.append([r["distance"] for r in result])
if "embeddings" in include:
embeddings.append([cast(Embedding, r["embedding"])
for r in result])
embeddings.append([cast(Embedding, r["embedding"]) for r in result])

if "documents" in include or "metadatas" in include:
all_ids: Set[str] = set()
Expand All @@ -607,11 +599,9 @@ def _query(
# queries the metadata segment. The metadata segment does not have
# the record. In this case we choose to return potentially
# incorrect data in the form of None.
metadata_list = [metadata_by_id.get(
id, None) for id in id_list]
metadata_list = [metadata_by_id.get(id, None) for id in id_list]
if "metadatas" in include:
metadatas.append(_clean_metadatas(
metadata_list)) # type: ignore
metadatas.append(_clean_metadatas(metadata_list)) # type: ignore
if "documents" in include:
doc_list = [_doc(m) for m in metadata_list]
documents.append(doc_list) # type: ignore
Expand Down Expand Up @@ -677,8 +667,7 @@ def _validate_embedding_record(
"""Validate the dimension of an embedding record before submitting it to the system."""
add_attributes_to_current_span({"collection_id": str(collection["id"])})
if record["embedding"]:
self._validate_dimension(collection, len(
record["embedding"]), update=True)
self._validate_dimension(collection, len(record["embedding"]), update=True)

@trace_method("SegmentAPI._validate_dimension", OpenTelemetryGranularity.ALL)
def _validate_dimension(
Expand Down
12 changes: 4 additions & 8 deletions chromadb/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ def __init__(self, settings: Settings):
allow_methods=["*"],
)
if settings.chroma_server_auth_provider:
self._auth_middleware = self._api.require(
FastAPIChromaAuthMiddleware)
self._auth_middleware = self._api.require(FastAPIChromaAuthMiddleware)
self._app.add_middleware(
FastAPIChromaAuthMiddlewareWrapper,
auth_middleware=self._auth_middleware,
Expand All @@ -130,12 +129,9 @@ def __init__(self, settings: Settings):
self.router = ChromaAPIRouter()

self.router.add_api_route("/api/v1", self.root, methods=["GET"])
self.router.add_api_route(
"/api/v1/reset", self.reset, methods=["POST"])
self.router.add_api_route(
"/api/v1/version", self.version, methods=["GET"])
self.router.add_api_route(
"/api/v1/heartbeat", self.heartbeat, methods=["GET"])
self.router.add_api_route("/api/v1/reset", self.reset, methods=["POST"])
self.router.add_api_route("/api/v1/version", self.version, methods=["GET"])
self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"])
self.router.add_api_route(
"/api/v1/pre-flight-checks", self.pre_flight_checks, methods=["GET"]
)
Expand Down
Loading

0 comments on commit 4f9b59f

Please sign in to comment.