Skip to content

Commit

Permalink
merge main
Browse files Browse the repository at this point in the history
  • Loading branch information
HammadB committed Oct 23, 2023
2 parents a4509e5 + bc69e08 commit 9e52cd5
Show file tree
Hide file tree
Showing 35 changed files with 808 additions and 217 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ index_data
# Default configuration for persist_directory in chromadb/config.py
# Currently it's located in "./chroma/"
chroma/
chroma_test_data
chroma_test_data/
server.htpasswd

.venv
Expand Down
8 changes: 5 additions & 3 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
"GetResult",
]


logger = logging.getLogger(__name__)

__settings = Settings()
Expand Down Expand Up @@ -77,8 +76,11 @@
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
else:
raise RuntimeError(
"\033[91mYour system has an unsupported version of sqlite3. Chroma requires sqlite3 >= 3.35.0.\033[0m\n"
"\033[94mPlease visit https://docs.trychroma.com/troubleshooting#sqlite to learn how to upgrade.\033[0m"
"\033[91mYour system has an unsupported version of sqlite3. Chroma \
requires sqlite3 >= 3.35.0.\033[0m\n"
"\033[94mPlease visit \
https://docs.trychroma.com/troubleshooting#sqlite to learn how \
to upgrade.\033[0m"
)


Expand Down
8 changes: 4 additions & 4 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
QueryResult,
)
from chromadb.config import Settings, System
from chromadb.telemetry import Telemetry
from chromadb.telemetry.events import ClientStartEvent
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.api.models.Collection import Collection
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.telemetry.product.events import ClientStartEvent
from chromadb.types import Database, Tenant, Where, WhereDocument
import chromadb.utils.embedding_functions as ef

Expand All @@ -43,7 +43,7 @@ def _create_system_if_not_exists(
new_system = System(settings)
cls._identifer_to_system[identifier] = new_system

new_system.instance(Telemetry)
new_system.instance(ProductTelemetryClient)
new_system.instance(ServerAPI)

new_system.start()
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(
self._server = self._system.instance(ServerAPI)

# Submit event for a client start
telemetry_client = self._system.instance(Telemetry)
telemetry_client = self._system.instance(ProductTelemetryClient)
telemetry_client.capture(ClientStartEvent())

@classmethod
Expand Down
38 changes: 35 additions & 3 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from chromadb.auth.providers import RequestsClientAuthProtocolAdapter
from chromadb.auth.registry import resolve_provider
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.telemetry import Telemetry
from chromadb.telemetry.opentelemetry import (
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from urllib.parse import urlparse, urlunparse, quote

logger = logging.getLogger(__name__)
Expand All @@ -52,7 +57,8 @@ def _validate_host(host: str) -> None:
if "/" in host and (not host.startswith("http")):
raise ValueError(
"Invalid URL. "
"Seems that you are trying to pass URL as a host but without specifying the protocol. "
"Seems that you are trying to pass URL as a host but without \
specifying the protocol. "
"Please add http:// or https:// to the host."
)

Expand Down Expand Up @@ -93,7 +99,8 @@ def __init__(self, system: System):
system.settings.require("chroma_server_host")
system.settings.require("chroma_server_http_port")

self._telemetry_client = self.require(Telemetry)
self._opentelemetry_client = self.require(OpenTelemetryClient)
self._product_telemetry_client = self.require(ProductTelemetryClient)
self._settings = system.settings

self._api_url = FastAPI.resolve_url(
Expand Down Expand Up @@ -128,13 +135,15 @@ def __init__(self, system: System):
if self._header is not None:
self._session.headers.update(self._header)

@trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
@override
def heartbeat(self) -> int:
"""Returns the current server time in nanoseconds to check if the server is alive"""
resp = self._session.get(self._api_url)
raise_chroma_error(resp)
return int(resp.json()["nanosecond heartbeat"])

@trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
def create_database(
self,
Expand All @@ -149,6 +158,7 @@ def create_database(
)
raise_chroma_error(resp)

@trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
@override
def get_database(
self,
Expand All @@ -166,6 +176,7 @@ def get_database(
id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
)

@trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
resp = self._session.post(
Expand All @@ -174,6 +185,7 @@ def create_tenant(self, name: str) -> None:
)
raise_chroma_error(resp)

@trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> Tenant:
resp = self._session.get(
Expand All @@ -183,6 +195,7 @@ def get_tenant(self, name: str) -> Tenant:
resp_json = resp.json()
return Tenant(name=resp_json["name"])

@trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
@override
def list_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
Expand All @@ -200,6 +213,7 @@ def list_collections(

return collections

@trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
@override
def create_collection(
self,
Expand Down Expand Up @@ -232,6 +246,7 @@ def create_collection(
metadata=resp_json["metadata"],
)

@trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
@override
def get_collection(
self,
Expand All @@ -255,6 +270,9 @@ def get_collection(
metadata=resp_json["metadata"],
)

@trace_method(
"FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
)
@override
def get_or_create_collection(
self,
Expand All @@ -273,6 +291,7 @@ def get_or_create_collection(
database=database,
)

@trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
@override
def _modify(
self,
Expand All @@ -287,6 +306,7 @@ def _modify(
)
raise_chroma_error(resp)

@trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
@override
def delete_collection(
self,
Expand All @@ -301,6 +321,7 @@ def delete_collection(
)
raise_chroma_error(resp)

@trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
@override
def _count(
self,
Expand All @@ -313,6 +334,7 @@ def _count(
raise_chroma_error(resp)
return cast(int, resp.json())

@trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
@override
def _peek(
self,
Expand All @@ -325,6 +347,7 @@ def _peek(
include=["embeddings", "documents", "metadatas"],
)

@trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
@override
def _get(
self,
Expand Down Expand Up @@ -367,6 +390,7 @@ def _get(
documents=body.get("documents", None),
)

@trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
@override
def _delete(
self,
Expand All @@ -386,6 +410,7 @@ def _delete(
raise_chroma_error(resp)
return cast(IDs, resp.json())

@trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
def _submit_batch(
self,
batch: Tuple[
Expand All @@ -409,6 +434,7 @@ def _submit_batch(
)
return resp

@trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
@override
def _add(
self,
Expand All @@ -428,6 +454,7 @@ def _add(
raise_chroma_error(resp)
return True

@trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
@override
def _update(
self,
Expand All @@ -449,6 +476,7 @@ def _update(
resp.raise_for_status()
return True

@trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
@override
def _upsert(
self,
Expand All @@ -470,6 +498,7 @@ def _upsert(
resp.raise_for_status()
return True

@trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
@override
def _query(
self,
Expand Down Expand Up @@ -505,13 +534,15 @@ def _query(
documents=body.get("documents", None),
)

@trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
@override
def reset(self) -> bool:
"""Resets the database"""
resp = self._session.post(self._api_url + "/reset")
raise_chroma_error(resp)
return cast(bool, resp.json())

@trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
@override
def get_version(self) -> str:
"""Returns the version of the server"""
Expand All @@ -525,6 +556,7 @@ def get_settings(self) -> Settings:
return self._settings

@property
@trace_method("FastAPI.max_batch_size", OpenTelemetryGranularity.OPERATION)
@override
def max_batch_size(self) -> int:
if self._max_batch_size == -1:
Expand Down
Loading

0 comments on commit 9e52cd5

Please sign in to comment.