From b2677ab6908e641d23d78654545ff4c12c3dc488 Mon Sep 17 00:00:00 2001 From: James Socol Date: Mon, 24 Jul 2023 15:47:37 -0400 Subject: [PATCH] Add support for async functions to decorator Tries to add support for async functions to the decorator, but trips over not having a failing test to fix. --- .github/actions/test/action.yml | 22 ++++++++++++++++------ django_ratelimit/decorators.py | 27 ++++++++++++++++++++++++++- django_ratelimit/tests.py | 29 ++++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/.github/actions/test/action.yml b/.github/actions/test/action.yml index a9bbd60..9f07876 100644 --- a/.github/actions/test/action.yml +++ b/.github/actions/test/action.yml @@ -14,13 +14,23 @@ runs: with: python-version: ${{ inputs.python-version }} - - name: Install dependencies + - name: Update pip shell: sh - run: | - python -m pip install --upgrade pip - if [[ ${{ inputs.django-version }} != 'main' ]]; then pip install --pre -q "Django>=${{ inputs.django-version }},<${{ inputs.django-version }}.99"; fi - if [[ ${{ inputs.django-version }} == 'main' ]]; then pip install https://github.com/django/django/archive/main.tar.gz; fi - pip install flake8 django-redis pymemcache + run: python -m pip install --upgrade pip + + - name: Install Django + shell: sh + run: python -m pip install "Django>=${{ inputs.django-version }},<${{ inputs.django-version }}.99" + if: ${{ inputs.django-version != 'main' }} + + - name: Install Django main + shell: sh + run: python -m pip install https://github.com/django/django/archive/main.tar.gz + if: ${{ inputs.django-version == 'main' }} + + - name: Install Django dependencies + shell: sh + run: pip install flake8 django-redis pymemcache - name: Test shell: sh diff --git a/django_ratelimit/decorators.py b/django_ratelimit/decorators.py index 40c9541..0d50cea 100644 --- a/django_ratelimit/decorators.py +++ b/django_ratelimit/decorators.py @@ -1,4 +1,10 @@ from functools import wraps +import django +if django.VERSION >= (4, 1): + from asgiref.sync import iscoroutinefunction +else: + def iscoroutinefunction(func): + return False from django.conf import settings from django.utils.module_loading import import_string @@ -13,6 +19,23 @@ def ratelimit(group=None, key=None, rate=None, method=ALL, block=True): def decorator(fn): + # if iscoroutinefunction(fn): + # @wraps(fn) + # async def _async_wrapped(request, *args, **kw): + # old_limited = getattr(request, 'limited', False) + # ratelimited = is_ratelimited( + # request=request, group=group, fn=fn, key=key, rate=rate, + # method=method, increment=True) + # request.limited = ratelimited or old_limited + # if ratelimited and block: + # cls = getattr( + # settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) + # if isinstance(cls, str): + # cls = import_string(cls) + # raise cls() + # return await fn(request, *args, **kw) + # return _async_wrapped + @wraps(fn) def _wrapped(request, *args, **kw): old_limited = getattr(request, 'limited', False) @@ -23,7 +46,9 @@ def _wrapped(request, *args, **kw): if ratelimited and block: cls = getattr( settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) - raise (import_string(cls) if isinstance(cls, str) else cls)() + if isinstance(cls, str): + cls = import_string(cls) + raise cls() return fn(request, *args, **kw) return _wrapped return decorator diff --git a/django_ratelimit/tests.py b/django_ratelimit/tests.py index a58c89e..4561a1e 100644 --- a/django_ratelimit/tests.py +++ b/django_ratelimit/tests.py @@ -1,4 +1,8 @@ +import asyncio + +import django from functools import partial +from unittest import skipIf from django.core.cache import cache, InvalidCacheBackendError from django.core.exceptions import ImproperlyConfigured @@ -12,7 +16,10 @@ from django_ratelimit.core import (get_usage, is_ratelimited, _split_rate, _get_ip) - +if django.VERSION >= (4, 1): + from asgiref.sync import iscoroutinefunction + from django.test import AsyncRequestFactory + arf = AsyncRequestFactory() rf = RequestFactory() @@ -411,6 +418,26 @@ def view(request): req.META['REMOTE_ADDR'] = '2001:db9::1000' assert not view(req) + @skipIf( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", + ) + async def test_decorate_async_function(self): + @ratelimit(key='ip', rate='1/m', block=False) + async def view(request): + await asyncio.sleep(0) + return request.limited + + req1 = arf.get('/') + req1.META['REMOTE_ADDR'] = '1.2.3.4' + + req2 = arf.get('/') + req2.META['REMOTE_ADDR'] = '1.2.3.4' + + assert iscoroutinefunction(view) + assert await view(req1) is False + assert await view(req2) is True + class FunctionsTests(TestCase): def setUp(self):