Skip to content

Commit

Permalink
Feature: Add method to Rule (#67)
Browse files Browse the repository at this point in the history
* add method to rule, add test and update readme

* add method rate limiting support to redis backends, update tests

* add support for by client rate limiting to memorybackend

* add method docs

* fix formatting, reduce down to one search for matching rule

* revert blocking change to block user

* fix error in test, update remove rule test to include http method

* update method in remove rule to default
  • Loading branch information
pa-t authored Oct 25, 2022
1 parent 986aaa3 commit 30dc417
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 5 deletions.
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ async def AUTH_FUNCTION(scope: Scope) -> Tuple[str, str]:
rate_limit = RateLimitMiddleware(ASGI_APP, AUTH_FUNCTION, ...)
```

The `Rule` type takes a time unit (e.g. `"second"`) and/or a `"group"`, as a param. If the `"group"` param is not specified then the `"authenticate"` method needs to return the "default group".
The `Rule` type takes a time unit (e.g. `"second"`), a `"group"`, and a `"method"` as a param. If the `"group"` param is not specified then the `"authenticate"` method needs to return the "default group". The `"method"` param corresponds to the http method, if it is not specified, the rule will be applied to all http requests.

Example:
```python
...
config={
r"^/towns": [Rule(second=1), Rule(second=10, group="admin")],
r"^/towns": [Rule(second=1, method="get"), Rule(second=10, group="admin")],
}
...

Expand Down Expand Up @@ -147,6 +147,19 @@ When the user's request frequency triggers the upper limit, all requests in the

Example: `Rule(second=5, block_time=60)`, this rule will limit the user to a maximum of 5 visits per second. Once this limit is exceeded, all requests within the next 60 seconds will return `429`.


### HTTP Method

If you want a rate limit a specifc HTTP method on an endpoint, the `Rule` object has a `method` param. If no method is specified, the default value is `"*"` for all HTTP methods.

```python
r"^/towns": [
Rule(group="admin", method="get", second=10),
Rule(group="admin", method="post", second=2)
]
```


### Custom block handler

Just specify `on_blocked` and you can customize the asgi application that is called when blocked.
Expand Down
8 changes: 7 additions & 1 deletion ratelimit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
raise exc

# Select the first rule that can be matched
match_rule = list(filter(lambda r: r.group == group, rules))
method = scope["method"].lower()
match_rule = list(
filter(
lambda r: r.group == group and r.method.lower() in [method, "*"],
rules,
)
)
if match_rule:
rule = match_rule[0]
break
Expand Down
3 changes: 2 additions & 1 deletion ratelimit/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@dataclass
class Rule:
group: str = "default"
method: str = "*"

second: Optional[int] = None
minute: Optional[int] = None
Expand All @@ -22,7 +23,7 @@ def ruleset(self, path: str, user: str) -> Dict[str, Tuple[int, int]]:
the redis keys and values is a tuple of (limit, ttl)
"""
return {
f"{path}:{user}:{name}": (limit, TTL[name])
f"{path}:{self.method}:{user}:{name}": (limit, TTL[name])
for name, limit in map(lambda name: (name, getattr(self, name)), RULENAMES)
if limit is not None
}
Expand Down
2 changes: 1 addition & 1 deletion tests/backends/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def test_other(memory_backend):
assert response.status_code == 429

assert rate_limit.backend.remove_user("user")
assert rate_limit.backend.remove_rule(path, f"{path}:user:second")
assert rate_limit.backend.remove_rule(path, f"{path}:*:user:second")

response = await client.get(path)
assert response.status_code == 200
51 changes: 51 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,54 @@ async def test_rule_zone():
"/message", headers={"user": "user", "group": "default"}
)
assert response.status_code == 429


@pytest.mark.asyncio
async def test_rule_method():
rate_limit = RateLimitMiddleware(
hello_world,
auth_func,
RedisBackend(StrictRedis()),
{
r"/message": [Rule(minute=1, method="get"), Rule(minute=2, method="post")],
r"/towns": [Rule(minute=1)],
},
)
async with httpx.AsyncClient(
app=rate_limit, base_url="http://testserver"
) as client: # type: httpx.AsyncClient
# /message is only limited on get, post should be fine
response = await client.get(
"/message", headers={"user": "user", "group": "default"}
)
assert response.status_code == 200
response = await client.get(
"/message", headers={"user": "user", "group": "default"}
)
assert response.status_code == 429
response = await client.post(
"/message", headers={"user": "user", "group": "default"}
)
assert response.status_code == 200
response = await client.post(
"/message", headers={"user": "user", "group": "default"}
)
assert response.status_code == 200
response = await client.post(
"/message", headers={"user": "user", "group": "default"}
)
assert response.status_code == 429

# /towns is limited on all methods
response = await client.get(
"/towns", headers={"user": "user", "group": "default"}
)
assert response.status_code == 200
response = await client.get(
"/towns", headers={"user": "user", "group": "default"}
)
assert response.status_code == 429
response = await client.post(
"/towns", headers={"user": "user", "group": "default"}
)
assert response.status_code == 429

0 comments on commit 30dc417

Please sign in to comment.