From 5e9c7a704dbfed76ab33b89c49141c914c0f6e95 Mon Sep 17 00:00:00 2001
From: nicolasgere
Date: Wed, 28 Feb 2024 13:19:28 -0800
Subject: [PATCH] [ENH]: add rate limiting (#1728)

## Description of changes

*Summarize the changes made by this PR.*
 - New functionality
	 - Add rate limiting service. If no rate limit service is provided, it will not do anything.

## Test plan
*How are these changes tested?*
Unit test on rate limiting service.

---------

Co-authored-by: nicolas
---
 chromadb/config.py                            |  4 +-
 chromadb/rate_limiting/__init__.py            | 83 +++++++++++++++++++
 chromadb/rate_limiting/test_provider.py       | 17 ++++
 chromadb/server/fastapi/__init__.py           | 13 +++
 .../test/rate_limiting/test_rate_limiting.py  | 70 ++++++++++++++++
 5 files changed, 185 insertions(+), 2 deletions(-)
 create mode 100644 chromadb/rate_limiting/__init__.py
 create mode 100644 chromadb/rate_limiting/test_provider.py
 create mode 100644 chromadb/test/rate_limiting/test_rate_limiting.py

diff --git a/chromadb/config.py b/chromadb/config.py
index b4a78d5746c..bc8234bc34d 100644
--- a/chromadb/config.py
+++ b/chromadb/config.py
@@ -76,7 +76,7 @@
     "chromadb.segment.SegmentManager": "chroma_segment_manager_impl",
     "chromadb.segment.distributed.SegmentDirectory": "chroma_segment_directory_impl",
     "chromadb.segment.distributed.MemberlistProvider": "chroma_memberlist_provider_impl",
-
+    "chromadb.rate_limiting.RateLimitingProvider": "chroma_rate_limiting_provider_impl",
 }
 
 DEFAULT_TENANT = "default_tenant"
@@ -102,7 +102,7 @@ class Settings(BaseSettings):  # type: ignore
         "chromadb.segment.impl.manager.local.LocalSegmentManager"
     )
     chroma_quota_provider_impl:Optional[str] = None
-
+    chroma_rate_limiting_provider_impl: Optional[str] = None
     # Distributed architecture specific components
     chroma_segment_directory_impl: str = "chromadb.segment.impl.distributed.segment_directory.RendezvousHashSegmentDirectory"
     chroma_memberlist_provider_impl: str = "chromadb.segment.impl.distributed.segment_directory.CustomResourceMemberlistProvider"
diff --git a/chromadb/rate_limiting/__init__.py b/chromadb/rate_limiting/__init__.py
new file mode 100644
index 00000000000..fb1a955ff12
--- /dev/null
+++ b/chromadb/rate_limiting/__init__.py
@@ -0,0 +1,83 @@
+import inspect
+from abc import abstractmethod
+from functools import wraps
+from typing import Optional, Any, Dict, Callable, cast
+
+from chromadb.config import Component
+from chromadb.quota import QuotaProvider, Resource
+
+
+class RateLimitError(Exception):
+    """Raised when a request is rejected because its rate limit is exhausted."""
+
+    def __init__(self, resource: Resource, quota: int):
+        super().__init__(f"rate limit error. resource: {resource} quota: {quota}")
+        self.quota = quota
+        self.resource = resource
+
+
+class RateLimitingProvider(Component):
+    """Interface of the component deciding whether a request may proceed."""
+
+    @abstractmethod
+    def is_allowed(self, key: str, quota: int, point: Optional[int] = 1) -> bool:
+        """
+        Determines if a request identified by `key` can proceed given the current rate limit.
+
+        :param key: The identifier for the requestor.
+        :param quota: The quota which will be used for bucket size.
+        :param point: The number of tokens required to fulfill the request.
+        :return: True if the request can proceed, False otherwise.
+        """
+        pass
+
+
+def rate_limit(
+    subject: str,
+    resource: str
+) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
+    """
+    Rate limit a method of an object exposing `self._system`.
+
+    :param subject: Name of the argument of the decorated method whose value identifies the requestor.
+    :param resource: Name of the rate limited resource; used to look up the quota and to build the key.
+    :raises ValueError: At decoration time, if `subject` is not an argument of the decorated method.
+    """
+
+    def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
+        args_name = inspect.getfullargspec(f).args
+        if subject not in args_name:
+            raise ValueError(f'rate_limit decorator have unknown subject "{subject}", available {args_name}')
+        key_index = args_name.index(subject)
+
+        @wraps(f)
+        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+            # If no rate limiting provider is configured, just run and return the function.
+            if self._system.settings.chroma_rate_limiting_provider_impl is None:
+                return f(self, *args, **kwargs)
+
+            if subject in kwargs:
+                subject_value = kwargs[subject]
+            else:
+                # `key_index` counts `self`, which is not part of `args`.
+                if len(args) < key_index:
+                    return f(self, *args, **kwargs)
+                subject_value = args[key_index - 1]
+            # Coerce to str so that a non-string subject value does not break the key.
+            subject_key = str(subject_value)
+            key_value = resource + "-" + subject_key
+            quota_provider = self._system.require(QuotaProvider)
+            rate_limiter = self._system.require(RateLimitingProvider)
+            # Look the quota up for the caller's value, not for the argument name.
+            quota = quota_provider.get_for_subject(resource=resource, subject=subject_key)
+            # No quota returned for this subject: nothing to enforce.
+            if quota is None:
+                return f(self, *args, **kwargs)
+            is_allowed = rate_limiter.is_allowed(key_value, quota)
+            if is_allowed is False:
+                raise RateLimitError(resource=resource, quota=quota)
+            return f(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
diff --git a/chromadb/rate_limiting/test_provider.py b/chromadb/rate_limiting/test_provider.py
new file mode 100644
index 00000000000..6b97db3dad3
--- /dev/null
+++ b/chromadb/rate_limiting/test_provider.py
@@ -0,0 +1,17 @@
+from typing import Optional, Dict
+
+from overrides import overrides
+
+from chromadb.config import System
+from chromadb.rate_limiting import RateLimitingProvider
+
+
+class RateLimitingTestProvider(RateLimitingProvider):
+    """Provider for tests: allows every request unless `is_allowed` is patched."""
+
+    def __init__(self, system: System):
+        super().__init__(system)
+
+    @overrides
+    def is_allowed(self, key: str, quota: int, point: Optional[int] = 1) -> bool:
+        return True
diff --git a/chromadb/server/fastapi/__init__.py b/chromadb/server/fastapi/__init__.py
index 5eccc54d819..292f5038dea 100644
--- a/chromadb/server/fastapi/__init__.py
+++ b/chromadb/server/fastapi/__init__.py
@@ -36,6 +36,7 @@
     InvalidHTTPVersion,
 )
 from chromadb.quota import QuotaError
+from chromadb.rate_limiting import RateLimitError
 from chromadb.server.fastapi.types import (
     AddEmbedding,
     CreateDatabase,
@@ -142,6 +143,7 @@ def __init__(self, settings: Settings):
             allow_methods=["*"],
         )
         self._app.add_exception_handler(QuotaError, self.quota_exception_handler)
+        self._app.add_exception_handler(RateLimitError, self.rate_limit_exception_handler)
 
         self._app.on_event("shutdown")(self.shutdown)
 
@@ -290,6 +292,17 @@ def shutdown(self) -> None:
     def app(self) -> fastapi.FastAPI:
         return self._app
 
+    async def rate_limit_exception_handler(
+        self, request: Request, exc: RateLimitError
+    ) -> JSONResponse:
+        """Turn a RateLimitError into an HTTP 429 (Too Many Requests) response."""
+        return JSONResponse(
+            status_code=429,
+            content={
+                "message": f"rate limit. resource: {exc.resource} quota: {exc.quota}"
+            },
+        )
+
     def root(self) -> Dict[str, int]:
         return {"nanosecond heartbeat": self._api.heartbeat()}
 
diff --git a/chromadb/test/rate_limiting/test_rate_limiting.py b/chromadb/test/rate_limiting/test_rate_limiting.py
new file mode 100644
index 00000000000..9f7e8c677e7
--- /dev/null
+++ b/chromadb/test/rate_limiting/test_rate_limiting.py
@@ -0,0 +1,70 @@
+from typing import List, Optional
+from unittest.mock import patch
+
+from chromadb.config import System, Settings, Component
+from chromadb.quota import QuotaEnforcer, Resource
+import pytest
+
+from chromadb.rate_limiting import RateLimitError, rate_limit
+
+
+class RateLimitingGym(Component):
+    def __init__(self, system: System):
+        super().__init__(system)
+        self.system = system
+
+    @rate_limit(subject="bar", resource="FAKE_RESOURCE")
+    def bench(self, foo: str, bar: str) -> str:
+        return foo
+
+
+def mock_get_for_subject(
+    self, resource: Resource, subject: Optional[str] = "", tier: Optional[str] = ""
+) -> Optional[int]:
+    """Mock function to simulate quota retrieval."""
+    return 10
+
+
+@pytest.fixture(scope="module")
+def rate_limiting_gym() -> RateLimitingGym:
+    settings = Settings(
+        chroma_quota_provider_impl="chromadb.quota.test_provider.QuotaProviderForTest",
+        chroma_rate_limiting_provider_impl="chromadb.rate_limiting.test_provider.RateLimitingTestProvider",
+    )
+    system = System(settings)
+    return RateLimitingGym(system)
+
+
+@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
+@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: False)
+def test_rate_limiting_should_raise(rate_limiting_gym: RateLimitingGym) -> None:
+    with pytest.raises(RateLimitError) as exc_info:
+        rate_limiting_gym.bench("foo", "bar")
+    assert "FAKE_RESOURCE" in str(exc_info.value.resource)
+
+
+@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
+@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: True)
+def test_rate_limiting_should_not_raise_with_keyword_args(rate_limiting_gym: RateLimitingGym) -> None:
+    assert rate_limiting_gym.bench(foo="foo", bar="bar") == "foo"
+
+
+@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject)
+@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: True)
+def test_rate_limiting_should_not_raise_with_positional_args(rate_limiting_gym: RateLimitingGym) -> None:
+    assert rate_limiting_gym.bench("foo", "bar") == "foo"
+
+
+@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: True)
+def test_quota_is_looked_up_for_subject_value(rate_limiting_gym: RateLimitingGym) -> None:
+    seen_subjects: List[Optional[str]] = []
+
+    def recording_get_for_subject(
+        self, resource: Resource, subject: Optional[str] = "", tier: Optional[str] = ""
+    ) -> Optional[int]:
+        seen_subjects.append(subject)
+        return 10
+
+    with patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', recording_get_for_subject):
+        rate_limiting_gym.bench("foo", "bar-value")
+    assert seen_subjects == ["bar-value"]