-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Description of changes *Summarize the changes made by this PR.* - New functionality - Add rate limiting service. If no rate limit service is provided, it will not do anything. ## Test plan *How are these changes tested?* Unit test on rate limiting service. --------- Co-authored-by: nicolas <[email protected]>
- Loading branch information
1 parent
44e8ff7
commit 5e9c7a7
Showing
5 changed files
with
139 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import inspect | ||
from abc import abstractmethod | ||
from functools import wraps | ||
from typing import Optional, Any, Dict, Callable, cast | ||
|
||
from chromadb.config import Component | ||
from chromadb.quota import QuotaProvider, Resource | ||
|
||
|
||
class RateLimitError(Exception): | ||
def __init__(self, resource: Resource, quota: int): | ||
super().__init__(f"rate limit error. resource: {resource} quota: {quota}") | ||
self.quota = quota | ||
self.resource = resource | ||
|
||
class RateLimitingProvider(Component): | ||
@abstractmethod | ||
def is_allowed(self, key: str, quota: int, point: Optional[int] = 1) -> bool: | ||
""" | ||
Determines if a request identified by `key` can proceed given the current rate limit. | ||
:param key: The identifier for the requestor (unused in this simplified implementation). | ||
:param quota: The quota which will be used for bucket size. | ||
:param point: The number of tokens required to fulfill the request. | ||
:return: True if the request can proceed, False otherwise. | ||
""" | ||
pass | ||
|
||
|
||
def rate_limit( | ||
subject: str, | ||
resource: str | ||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: | ||
def decorator(f: Callable[..., Any]) -> Callable[..., Any]: | ||
args_name = inspect.getfullargspec(f)[0] | ||
if subject not in args_name: | ||
raise Exception(f'rate_limit decorator have unknown subject "{subject}", available {args_name}') | ||
key_index = args_name.index(subject) | ||
|
||
@wraps(f) | ||
def wrapper(self, *args: Any, **kwargs: Dict[Any, Any]) -> Any: | ||
# If not rate limiting provider is present, just run and return the function. | ||
if self._system.settings.chroma_rate_limiting_provider_impl is None: | ||
return f(self, *args, **kwargs) | ||
|
||
if subject in kwargs: | ||
subject_value = kwargs[subject] | ||
else: | ||
if len(args) < key_index: | ||
return f(self, *args, **kwargs) | ||
subject_value = args[key_index-1] | ||
key_value = resource + "-" + subject_value | ||
self._system.settings.chroma_rate_limiting_provider_impl | ||
quota_provider = self._system.require(QuotaProvider) | ||
rate_limiter = self._system.require(RateLimitingProvider) | ||
quota = quota_provider.get_for_subject(resource=resource,subject=subject) | ||
is_allowed = rate_limiter.is_allowed(key_value, quota) | ||
if is_allowed is False: | ||
raise RateLimitError(resource=resource, quota=quota) | ||
return f(self, *args, **kwargs) | ||
return wrapper | ||
|
||
return decorator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from typing import Optional, Dict | ||
|
||
from overrides import overrides | ||
|
||
from chromadb.config import System | ||
from chromadb.rate_limiting import RateLimitingProvider | ||
|
||
|
||
class RateLimitingTestProvider(RateLimitingProvider): | ||
def __init__(self, system: System): | ||
super().__init__(system) | ||
|
||
@overrides | ||
def is_allowed(self, key: str, quota: int, point: Optional[int] = 1) -> bool: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Optional | ||
from unittest.mock import patch | ||
|
||
from chromadb.config import System, Settings, Component | ||
from chromadb.quota import QuotaEnforcer, Resource | ||
import pytest | ||
|
||
from chromadb.rate_limiting import rate_limit | ||
|
||
|
||
class RateLimitingGym(Component): | ||
def __init__(self, system: System): | ||
super().__init__(system) | ||
self.system = system | ||
|
||
@rate_limit(subject="bar", resource="FAKE_RESOURCE") | ||
def bench(self, foo: str, bar: str) -> str: | ||
return foo | ||
|
||
def mock_get_for_subject(self, resource: Resource, subject: Optional[str] = "", tier: Optional[str] = "") -> Optional[ | ||
int]: | ||
"""Mock function to simulate quota retrieval.""" | ||
return 10 | ||
|
||
@pytest.fixture(scope="module") | ||
def rate_limiting_gym() -> QuotaEnforcer: | ||
settings = Settings( | ||
chroma_quota_provider_impl="chromadb.quota.test_provider.QuotaProviderForTest", | ||
chroma_rate_limiting_provider_impl="chromadb.rate_limiting.test_provider.RateLimitingTestProvider" | ||
) | ||
system = System(settings) | ||
return RateLimitingGym(system) | ||
|
||
|
||
@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject) | ||
@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: False) | ||
def test_rate_limiting_should_raise(rate_limiting_gym: RateLimitingGym): | ||
with pytest.raises(Exception) as exc_info: | ||
rate_limiting_gym.bench("foo", "bar") | ||
assert "FAKE_RESOURCE" in str(exc_info.value.resource) | ||
|
||
@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject) | ||
@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: True) | ||
def test_rate_limiting_should_not_raise(rate_limiting_gym: RateLimitingGym): | ||
assert rate_limiting_gym.bench(foo="foo", bar="bar") is "foo" | ||
|
||
@patch('chromadb.quota.test_provider.QuotaProviderForTest.get_for_subject', mock_get_for_subject) | ||
@patch('chromadb.rate_limiting.test_provider.RateLimitingTestProvider.is_allowed', lambda self, key, quota, point=1: True) | ||
def test_rate_limiting_should_not_raise(rate_limiting_gym: RateLimitingGym): | ||
assert rate_limiting_gym.bench("foo", "bar") is "foo" |