Skip to content

Commit

Permalink
feat(rate-limit/redis): Use redis to store throttling data for admin …
Browse files Browse the repository at this point in the history
…endpoints (#2863)
  • Loading branch information
gagantrivedi authored Jan 31, 2024
1 parent 31af594 commit 61537ce
Show file tree
Hide file tree
Showing 18 changed files with 434 additions and 45 deletions.
30 changes: 30 additions & 0 deletions api/app/settings/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,32 @@
CACHE_ENVIRONMENT_DOCUMENT_SECONDS = env.int("CACHE_ENVIRONMENT_DOCUMENT_SECONDS", 0)
ENVIRONMENT_DOCUMENT_CACHE_LOCATION = "environment-documents"

USER_THROTTLE_CACHE_NAME = "user-throttle"
USER_THROTTLE_CACHE_BACKEND = env.str(
"USER_THROTTLE_CACHE_BACKEND", "django.core.cache.backends.locmem.LocMemCache"
)
USER_THROTTLE_CACHE_LOCATION = env.str("USER_THROTTLE_CACHE_LOCATION", "admin-throttle")

# Using Redis for cache
# To use Redis for caching, set the cache backend to `django_redis.cache.RedisCache`.
# and set the cache location to the redis url
# ref: https://github.com/jazzband/django-redis/tree/5.4.0#configure-as-cache-backend

# Set this to `core.redis_cluster.ClusterConnectionFactory` when using Redis Cluster.
DJANGO_REDIS_CONNECTION_FACTORY = env.str("DJANGO_REDIS_CONNECTION_FACTORY", "")

# Avoid raising exceptions if redis is down
# ref: https://github.com/jazzband/django-redis/tree/5.4.0#memcached-exceptions-behavior
DJANGO_REDIS_IGNORE_EXCEPTIONS = env.bool(
"DJANGO_REDIS_IGNORE_EXCEPTIONS", default=True
)

# Log exceptions generated by django-redis
# ref:https://github.com/jazzband/django-redis/tree/5.4.0#log-ignored-exceptions
DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS = env.bool(
"DJANGO_REDIS_LOG_IGNORED_EXCEPTIONS", True
)

CACHES = {
"default": {
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
Expand Down Expand Up @@ -676,6 +702,10 @@
"LOCATION": ENVIRONMENT_SEGMENTS_CACHE_LOCATION,
"TIMEOUT": ENVIRONMENT_SEGMENTS_CACHE_SECONDS,
},
USER_THROTTLE_CACHE_NAME: {
"BACKEND": USER_THROTTLE_CACHE_BACKEND,
"LOCATION": USER_THROTTLE_CACHE_LOCATION,
},
}

TRENCH_AUTH = {
Expand Down
1 change: 1 addition & 0 deletions api/app/settings/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# We dont want to track tests
ENABLE_TELEMETRY = False
MAX_PROJECTS_IN_FREE_PLAN = 10
REST_FRAMEWORK["DEFAULT_THROTTLE_CLASSES"] = ["core.throttling.UserRateThrottle"]
REST_FRAMEWORK["DEFAULT_THROTTLE_RATES"] = {
"login": "100/min",
"mfa_code": "5/min",
Expand Down
2 changes: 2 additions & 0 deletions api/app_analytics/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SDKAnalyticsFlags(GenericAPIView):

permission_classes = (EnvironmentKeyPermissions,)
authentication_classes = (EnvironmentKeyAuthentication,)
throttle_classes = []

def get_serializer_class(self):
if getattr(self, "swagger_fake_view", False):
Expand Down Expand Up @@ -116,6 +117,7 @@ class SelfHostedTelemetryAPIView(CreateAPIView):

permission_classes = ()
authentication_classes = ()
throttle_classes = []
serializer_class = TelemetrySerializer


Expand Down
12 changes: 9 additions & 3 deletions api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import boto3
import pytest
from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.core.cache import caches
from flag_engine.segments.constants import EQUAL
from moto import mock_dynamodb
from mypy_boto3_dynamodb.service_resource import DynamoDBServiceResource, Table
Expand Down Expand Up @@ -350,9 +350,15 @@ def reset_cache():
# https://groups.google.com/g/django-developers/c/zlaPsP13dUY
# TL;DR: Use this if your test interacts with cache since django
# does not clear cache after every test
cache.clear()
# Clear all caches before the test
for cache in caches.all():
cache.clear()

yield
cache.clear()

# Clear all caches after the test
for cache in caches.all():
cache.clear()


@pytest.fixture()
Expand Down
73 changes: 73 additions & 0 deletions api/core/redis_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Temporary module that adds support for Redis Cluster to django-redis by implementing
a connection factory class(`ClusterConnectionFactory`).
This module should be removed once [this](https://github.com/jazzband/django-redis/issues/606)
is resolved.
Usage:
------
Include the following configuration in Django project's settings.py file:
```python
# settings.py
DJANGO_REDIS_CONNECTION_FACTORY = "core.redis_cluster.ClusterConnectionFactory"
"""

import threading
from copy import deepcopy

from django.core.exceptions import ImproperlyConfigured
from django_redis.pool import ConnectionFactory
from redis.cluster import RedisCluster


class ClusterConnectionFactory(ConnectionFactory):
"""A connection factory for redis.cluster.RedisCluster
The cluster client manages connection pools internally, so we don't want to
do it at this level like the base ConnectionFactory does.
"""

# A global cache of URL->client so that within a process, we will reuse a
# single client, and therefore a single set of connection pools.
_clients = {}
_clients_lock = threading.Lock()

def connect(self, url: str) -> RedisCluster:
"""Given a connection url, return a client instance.
Prefer to return from our cache but if we don't yet have one build it
to populate the cache.
"""
if url not in self._clients:
with self._clients_lock:
if url not in self._clients:
params = self.make_connection_params(url)
self._clients[url] = self.get_connection(params)

return self._clients[url]

def get_connection(self, connection_params: dict) -> RedisCluster:
"""
Given connection_params, return a new client instance.
Basic django-redis ConnectionFactory manages a cache of connection
pools and builds a fresh client each time. because the cluster client
manages its own connection pools, we will instead merge the
"connection" and "client" kwargs and throw them all at the client to
sort out.
If we find conflicting client and connection kwargs, we'll raise an
error.
"""
client_cls_kwargs = deepcopy(self.redis_client_cls_kwargs)
# ... and smash 'em together (crashing if there's conflicts)...
for key, value in connection_params.items():
if key in client_cls_kwargs:
raise ImproperlyConfigured(
f"Found '{key}' in both the connection and the client kwargs"
)
client_cls_kwargs[key] = value

# ... and then build and return the client
return RedisCluster(**client_cls_kwargs)

def disconnect(self, connection: RedisCluster):
connection.disconnect_connection_pools()
7 changes: 7 additions & 0 deletions api/core/throttling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.conf import settings
from django.core.cache import caches
from rest_framework import throttling


class UserRateThrottle(throttling.UserRateThrottle):
cache = caches[settings.USER_THROTTLE_CACHE_NAME]
1 change: 1 addition & 0 deletions api/environments/identities/traits/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class SDKTraitsDeprecated(SDKAPIView):
# API to handle /api/v1/identities/<identifier>/traits/<trait_key> endpoints
# if Identity or Trait does not exist it will create one, otherwise will fetch existing
serializer_class = TraitSerializerBasic
throttle_classes = []

schema = None

Expand Down
1 change: 1 addition & 0 deletions api/environments/identities/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class SDKIdentitiesDeprecated(SDKAPIView):
# if Identity does not exist it will create one, otherwise will fetch existing

serializer_class = IdentifyWithTraitsSerializer
throttle_classes = []

schema = None

Expand Down
2 changes: 1 addition & 1 deletion api/features/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,8 @@ class SDKFeatureStates(GenericAPIView):
permission_classes = (EnvironmentKeyPermissions,)
authentication_classes = (EnvironmentKeyAuthentication,)
renderer_classes = [JSONRenderer]
throttle_classes = []
pagination_class = None
throttle_classes = []

@swagger_auto_schema(
query_serializer=SDKFeatureStatesQuerySerializer(),
Expand Down
4 changes: 2 additions & 2 deletions api/integrations/sentry/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from django.conf import settings

NON_FUNCTIONAL_ENDPOINTS = ("/health", "")
SDK_ENDPOINTS = (
SDK_ENDPOINTS = {
"/api/v1/flags",
"/api/v1/identities",
"/api/v1/traits",
"/api/v1/traits/bulk",
"/api/v1/environment-document",
"/api/v1/analytics/flags",
)
}


def traces_sampler(ctx):
Expand Down
40 changes: 38 additions & 2 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ pydantic = "~1.10.9"
pyngo = "~1.6.0"
flagsmith = "^3.4.0"
python-gnupg = "^0.5.1"
django-redis = "^5.4.0"

[tool.poetry.group.auth-controller]
optional = true
Expand Down
12 changes: 12 additions & 0 deletions api/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ def sdk_client(environment_api_key):
return client


@pytest.fixture()
def server_side_sdk_client(
admin_client: APIClient, environment: int, environment_api_key: str
) -> APIClient:
url = reverse("api-v1:environments:api-keys-list", args={environment_api_key})
response = admin_client.post(url, data={"name": "Some key"})

client = APIClient()
client.credentials(HTTP_X_ENVIRONMENT_KEY=response.json()["key"])
return client


@pytest.fixture()
def default_feature_value():
return "default_value"
Expand Down
Loading

0 comments on commit 61537ce

Please sign in to comment.