diff --git a/django_ratelimit/decorators.py b/django_ratelimit/decorators.py index 40c9541..017b775 100644 --- a/django_ratelimit/decorators.py +++ b/django_ratelimit/decorators.py @@ -1,4 +1,5 @@ from functools import wraps +from inspect import iscoroutinefunction from django.conf import settings from django.utils.module_loading import import_string @@ -13,6 +14,23 @@ def ratelimit(group=None, key=None, rate=None, method=ALL, block=True): def decorator(fn): + # if iscoroutinefunction(fn): + # @wraps(fn) + # async def _async_wrapped(request, *args, **kw): + # old_limited = getattr(request, 'limited', False) + # ratelimited = is_ratelimited( + # request=request, group=group, fn=fn, key=key, rate=rate, + # method=method, increment=True) + # request.limited = ratelimited or old_limited + # if ratelimited and block: + # cls = getattr( + # settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) + # if isinstance(cls, str): + # cls = import_string(cls) + # raise cls() + # return await fn(request, *args, **kw) + # return _async_wrapped + @wraps(fn) def _wrapped(request, *args, **kw): old_limited = getattr(request, 'limited', False) @@ -23,7 +41,9 @@ def _wrapped(request, *args, **kw): if ratelimited and block: cls = getattr( settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) - raise (import_string(cls) if isinstance(cls, str) else cls)() + if isinstance(cls, str): + cls = import_string(cls) + raise cls() return fn(request, *args, **kw) return _wrapped return decorator diff --git a/django_ratelimit/tests.py b/django_ratelimit/tests.py index a58c89e..518b071 100644 --- a/django_ratelimit/tests.py +++ b/django_ratelimit/tests.py @@ -1,3 +1,4 @@ +import asyncio from functools import partial from django.core.cache import cache, InvalidCacheBackendError @@ -411,6 +412,24 @@ def view(request): req.META['REMOTE_ADDR'] = '2001:db9::1000' assert not view(req) + def test_decorate_async_function(self): + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + + @ratelimit(key='ip', rate='1/m', block=False) + async def view(request): + await asyncio.sleep(0) + return request.limited + + req1 = rf.get('/') + req1.META['REMOTE_ADDR'] = '1.2.3.4' + + req2 = rf.get('/') + req2.META['REMOTE_ADDR'] = '1.2.3.4' + + assert event_loop.run_until_complete(view(req1)) is False + assert event_loop.run_until_complete(view(req2)) is True + class FunctionsTests(TestCase): def setUp(self):