Skip to content

Commit

Permalink
Use dependency injection StrictRedis object (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
abersheeran authored May 7, 2022
1 parent 159e704 commit 6498223
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 46 deletions.
22 changes: 7 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@ The following example will limit users under the `"default"` group to access `/t
from typing import Tuple

from ratelimit import RateLimitMiddleware, Rule
from ratelimit.auths import EmptyInformation
from ratelimit.backends.redis import RedisBackend
from ratelimit.backends.simple import MemoryBackend

# Simple rate-limiter in memory:
from ratelimit.backends.simple import MemoryBackend

rate_limit = RateLimitMiddleware(
ASGI_APP,
Expand All @@ -45,34 +43,28 @@ rate_limit = RateLimitMiddleware(
)

# with Redis:
from redis.asyncio import StrictRedis
from ratelimit.backends.redis import RedisBackend

rate_limit = RateLimitMiddleware(
ASGI_APP,
AUTH_FUNCTION,
RedisBackend(),
RedisBackend(StrictRedis()),
{
r"^/towns": [Rule(second=1, group="default"), Rule(group="admin")],
r"^/forests": [Rule(minute=1, group="default"), Rule(group="admin")],
},
)

# Or if using Starlette, FastApi, or index.py framework
app.add_middleware(
RateLimitMiddleware,
authenticate=AUTH_FUNCTION,
backend=RedisBackend(),
config={
r"^/towns": [Rule(second=1, group="default"), Rule(group="admin")],
r"^/forests": [Rule(minute=1, group="default"), Rule(group="admin")],
},
)
```

:warning: **The pattern's order is important, rules are set on the first match**: Be careful here !

Next, provide a custom authenticate function, or use one of the [existing auth methods](#built-in-auth-functions).

```python
from ratelimit.auths import EmptyInformation


async def AUTH_FUNCTION(scope: Scope) -> Tuple[str, str]:
"""
Resolve the user's unique identifier and the user's group from ASGI SCOPE.
Expand Down
13 changes: 2 additions & 11 deletions ratelimit/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,8 @@


class RedisBackend(BaseBackend):
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: str = None,
ssl: bool = False,
) -> None:
self._redis = StrictRedis(
host=host, port=port, db=db, password=password, ssl=ssl
)
def __init__(self, redis: StrictRedis) -> None:
self._redis = redis
self.lua_script: Script = self._redis.register_script(SCRIPT)

async def set_block_time(self, user: str, block_time: int) -> None:
Expand Down
13 changes: 2 additions & 11 deletions ratelimit/backends/slidingredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,8 @@


class SlidingRedisBackend(BaseBackend):
def __init__(
self,
host: str = "localhost",
port: int = 6379,
db: int = 0,
password: str = None,
ssl: bool = False,
) -> None:
self._redis = StrictRedis(
host=host, port=port, db=db, password=password, ssl=ssl
)
def __init__(self, redis: StrictRedis) -> None:
self._redis = redis
self.sliding_function = self._redis.register_script(SLIDING_WINDOW_SCRIPT)

async def get_limits(self, path: str, user: str, rule: Rule) -> dict:
Expand Down
6 changes: 3 additions & 3 deletions tests/backends/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def test_redis(redis_backend):
rate_limit = RateLimitMiddleware(
hello_world,
auth_func,
redis_backend(),
redis_backend(StrictRedis()),
{
r"/second_limit": [Rule(second=1), Rule(group="admin")],
r"/minute.*": [Rule(minute=1), Rule(group="admin")],
Expand All @@ -70,7 +70,7 @@ async def test_multiple(redis_backend):
rate_limit = RateLimitMiddleware(
hello_world,
auth_func,
redis_backend(),
redis_backend(StrictRedis()),
{r"/multiple": [Rule(second=1, minute=3)]},
)
async with httpx.AsyncClient(
Expand Down Expand Up @@ -115,7 +115,7 @@ async def test_multiple_with_punitive(redis_backend):
rate_limit = RateLimitMiddleware(
hello_world,
auth_func,
redis_backend(),
redis_backend(StrictRedis()),
{r"/multiple": [Rule(second=1, minute=3)]},
)
async with httpx.AsyncClient(
Expand Down
13 changes: 7 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import httpx
import pytest
from redis.asyncio import StrictRedis

from ratelimit import RateLimitMiddleware, Rule
from ratelimit.auths import EmptyInformation
Expand Down Expand Up @@ -50,7 +51,7 @@ def test_invalid_init_config():
RateLimitMiddleware(
hello_world,
auth_func,
RedisBackend(),
RedisBackend(StrictRedis()),
{
r"??.*": [Rule(group="admin")],
},
Expand All @@ -61,7 +62,7 @@ def test_invalid_init_config():
RateLimitMiddleware(
hello_world,
"123",
RedisBackend(),
RedisBackend(StrictRedis()),
{
r"/test": [Rule(group="admin")],
},
Expand All @@ -84,7 +85,7 @@ async def test_on_auth_error_default():
rate_limit = RateLimitMiddleware(
hello_world,
auth_func,
RedisBackend(),
RedisBackend(StrictRedis()),
{
r"/": [Rule(group="admin")],
},
Expand All @@ -111,7 +112,7 @@ async def test_on_auth_error_with_handler():
rate_limit = RateLimitMiddleware(
hello_world,
auth_func,
RedisBackend(),
RedisBackend(StrictRedis()),
{
r"/": [Rule(group="admin")],
},
Expand Down Expand Up @@ -149,7 +150,7 @@ async def test_custom_blocked():
rate_limit = RateLimitMiddleware(
hello_world,
authenticate=auth_func,
backend=RedisBackend(),
backend=RedisBackend(StrictRedis()),
config={r"/": [Rule(second=1), Rule(group="admin")]},
on_blocked=yourself_429,
)
Expand All @@ -171,7 +172,7 @@ async def test_rule_zone():
rate_limit = RateLimitMiddleware(
hello_world,
auth_func,
RedisBackend(),
RedisBackend(StrictRedis()),
{
r"/message": [Rule(second=1, zone="common")],
r"/\d+": [Rule(second=1, zone="common")],
Expand Down

0 comments on commit 6498223

Please sign in to comment.