diff --git a/README.rst b/README.rst index 0b00c5a..7be1c2b 100644 --- a/README.rst +++ b/README.rst @@ -190,6 +190,28 @@ for convenience use *payload* argument to mock out json response. Example below. # will throw an exception. +**aioresponses allows to use callbacks to provide dynamic responses** + +.. code:: python + + import asyncio + import aiohttp + from aioresponses import CallbackResult, aioresponses + + def callback(url, **kwargs): + return CallbackResult(status=418) + + @aioresponses() + def test_callback(m, test_client): + loop = asyncio.get_event_loop() + session = ClientSession() + m.get('http://example.com', callback=callback) + + resp = loop.run_until_complete(session.get('http://example.com')) + + assert resp.status == 418 + + **aioresponses can be used in a pytest fixture** .. code:: python @@ -203,8 +225,6 @@ for convenience use *payload* argument to mock out json response. Example below. yield m - - Features -------- * Easy to mock out HTTP requests made by *aiohttp.ClientSession* diff --git a/aioresponses/__init__.py b/aioresponses/__init__.py index 0888a38..81f0787 100644 --- a/aioresponses/__init__.py +++ b/aioresponses/__init__.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- -from .core import aioresponses +from .core import CallbackResult, aioresponses __version__ = '0.4.1' __all__ = [ + 'CallbackResult', 'aioresponses', ] diff --git a/aioresponses/compat.py b/aioresponses/compat.py index dfd4e11..7b967a3 100644 --- a/aioresponses/compat.py +++ b/aioresponses/compat.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- -import asyncio +import asyncio # noqa import re from distutils.version import StrictVersion -from typing import Dict, Optional, Union # noqa +from typing import Dict, Optional, Tuple, Union # noqa from urllib.parse import parse_qsl, urlencode -from aiohttp import StreamReader -from aiohttp import __version__ as aiohttp_version +from aiohttp import __version__ as aiohttp_version, StreamReader from multidict import MultiDict from yarl import URL @@ -27,7 +26,6 @@ def stream_reader_factory( protocol = ResponseHandler(loop=loop) return StreamReader(protocol) - else: # pragma: no cover def stream_reader_factory(): diff --git a/aioresponses/core.py b/aioresponses/core.py index 2f9e835..8e502be 100644 --- a/aioresponses/core.py +++ b/aioresponses/core.py @@ -4,7 +4,7 @@ from collections import namedtuple from distutils.version import StrictVersion from functools import wraps -from typing import Dict, Tuple, Union, Optional, List # noqa +from typing import Callable, Dict, Tuple, Union, Optional, List # noqa from unittest.mock import Mock, patch from aiohttp import ( @@ -18,15 +18,35 @@ from multidict import CIMultiDict from .compat import ( + AIOHTTP_VERSION, URL, Pattern, stream_reader_factory, merge_params, normalize_url, - AIOHTTP_VERSION ) +class CallbackResult: + + def __init__(self, method: str = hdrs.METH_GET, + status: int = 200, + body: str = '', + content_type: str = 'application/json', + payload: Dict = None, + headers: Dict = None, + response_class: 'ClientResponse' = None, + reason: Optional[str] = None): + self.method = method + self.status = status + self.body = body + self.content_type = content_type + self.payload = payload + self.headers = headers + self.response_class = response_class + self.reason = reason + + class RequestMatch(object): url_or_pattern = None # type: Union[URL, Pattern] @@ -41,7 +61,8 @@ def __init__(self, url: Union[str, Pattern], response_class: 'ClientResponse' = None, timeout: bool = False, repeat: bool = False, - reason: Optional[str] = None): + reason: Optional[str] = None, + callback: Optional[Callable] = None): if isinstance(url, Pattern): self.url_or_pattern = url self.match_func = self.match_regexp @@ -50,17 +71,14 @@ def __init__(self, url: Union[str, Pattern], self.match_func = self.match_str self.method = method.lower() self.status = status - if payload is not None: - body = json.dumps(payload) - if not isinstance(body, bytes): - body = str.encode(body) self.body = body + self.payload = payload self.exception = exception if timeout: self.exception = asyncio.TimeoutError('Connection timeout test') self.headers = headers self.content_type = content_type - self.response_class = response_class or ClientResponse + self.response_class = response_class self.repeat = repeat self.reason = reason if self.reason is None: @@ -68,6 +86,7 @@ def __init__(self, url: Union[str, Pattern], self.reason = http.RESPONSES[self.status][0] except (IndexError, KeyError): self.reason = '' + self.callback = callback def match_str(self, url: URL) -> bool: return self.url_or_pattern == url @@ -80,11 +99,32 @@ def match(self, method: str, url: URL) -> bool: return False return self.match_func(url) - async def build_response( - self, url: URL - ) -> 'Union[ClientResponse, Exception]': - if isinstance(self.exception, Exception): - return self.exception + def _build_raw_headers(self, headers: Dict) -> Tuple: + """ + Convert a dict of headers to a tuple of tuples + + Mimics the format of ClientResponse. + """ + raw_headers = [] + for k, v in headers.items(): + raw_headers.append((k.encode('utf8'), v.encode('utf8'))) + return tuple(raw_headers) + + def _build_response(self, url: 'Union[URL, str]', + method: str = hdrs.METH_GET, + status: int = 200, + body: str = '', + content_type: str = 'application/json', + payload: Dict = None, + headers: Dict = None, + response_class: 'ClientResponse' = None, + reason: Optional[str] = None) -> ClientResponse: + if response_class is None: + response_class = ClientResponse + if payload is not None: + body = json.dumps(payload) + if not isinstance(body, bytes): + body = str.encode(body) kwargs = {} if AIOHTTP_VERSION >= StrictVersion('3.1.0'): loop = Mock() @@ -94,43 +134,52 @@ async def build_response( kwargs['writer'] = Mock() kwargs['continue100'] = None kwargs['timer'] = TimerNoop() - if AIOHTTP_VERSION >= StrictVersion('3.3.0'): - pass - else: + if AIOHTTP_VERSION < StrictVersion('3.3.0'): kwargs['auto_decompress'] = True kwargs['traces'] = [] kwargs['loop'] = loop kwargs['session'] = None - resp = self.response_class(self.method, url, **kwargs) - # we need to initialize headers manually - headers = CIMultiDict({hdrs.CONTENT_TYPE: self.content_type}) - if self.headers: - headers.update(self.headers) - raw_headers = self._build_raw_headers(headers) + # We need to initialize headers manually + _headers = CIMultiDict({hdrs.CONTENT_TYPE: content_type}) + if headers: + _headers.update(headers) + raw_headers = self._build_raw_headers(_headers) + resp = response_class(method, url, **kwargs) if AIOHTTP_VERSION >= StrictVersion('3.3.0'): # Reified attributes - resp._headers = headers + resp._headers = _headers resp._raw_headers = raw_headers else: - resp.headers = headers + resp.headers = _headers resp.raw_headers = raw_headers - resp.status = self.status - resp.reason = self.reason + resp.status = status + resp.reason = reason resp.content = stream_reader_factory() - resp.content.feed_data(self.body) + resp.content.feed_data(body) resp.content.feed_eof() return resp - def _build_raw_headers(self, headers: Dict) -> Tuple: - """ - Convert a dict of headers to a tuple of tuples - - Mimics the format of ClientResponse. - """ - raw_headers = [] - for k, v in headers.items(): - raw_headers.append((k.encode('utf8'), v.encode('utf8'))) - return tuple(raw_headers) + async def build_response( + self, url: URL, **kwargs: Dict + ) -> 'Union[ClientResponse, Exception]': + if isinstance(self.exception, Exception): + return self.exception + if callable(self.callback): + result = self.callback(url, **kwargs) + else: + result = None + result = self if result is None else result + resp = self._build_response( + url=url, + method=result.method, + status=result.status, + body=result.body, + content_type=result.content_type, + payload=result.payload, + headers=result.headers, + response_class=result.response_class, + reason=result.reason) + return resp RequestCall = namedtuple('RequestCall', ['args', 'kwargs']) @@ -226,7 +275,8 @@ def add(self, url: 'Union[URL, str]', method: str = hdrs.METH_GET, response_class: 'ClientResponse' = None, repeat: bool = False, timeout: bool = False, - reason: Optional[str] = None) -> None: + reason: Optional[str] = None, + callback: Optional[Callable] = None) -> None: self._matches.append(RequestMatch( url, method=method, @@ -240,12 +290,15 @@ def add(self, url: 'Union[URL, str]', method: str = hdrs.METH_GET, repeat=repeat, timeout=timeout, reason=reason, + callback=callback, )) - async def match(self, method: str, url: URL) -> Optional['ClientResponse']: + async def match( + self, method: str, url: URL, **kwargs: Dict + ) -> Optional['ClientResponse']: for i, matcher in enumerate(self._matches): if matcher.match(method, url): - response = await matcher.build_response(url) + response = await matcher.build_response(url, **kwargs) break else: return None @@ -269,7 +322,7 @@ async def _request_mock(self, orig_self: ClientSession, orig_self, method, url, *args, **kwargs )) - response = await self.match(method, url) + response = await self.match(method, url, **kwargs) if response is None: raise ClientConnectionError( 'Connection refused: {} {}'.format(method, url) diff --git a/tests/test_aioresponses.py b/tests/test_aioresponses.py index 3e51a84..5652faf 100644 --- a/tests/test_aioresponses.py +++ b/tests/test_aioresponses.py @@ -26,7 +26,7 @@ from aiohttp.http_exceptions import HttpProcessingError from aioresponses.compat import URL -from aioresponses import aioresponses +from aioresponses import CallbackResult, aioresponses @ddt @@ -314,3 +314,17 @@ def test_timeout(self, mocked): with self.assertRaises(asyncio.TimeoutError): self.run_async(self.request(self.url)) + + @aioresponses() + def test_callback(self, m): + body = b'New body' + + def callback(url, **kwargs): + self.assertEqual(str(url), self.url) + self.assertEqual(kwargs, {'allow_redirects': True}) + return CallbackResult(body=body) + + m.get(self.url, callback=callback) + response = self.run_async(self.request(self.url)) + data = self.run_async(response.read()) + assert data == body diff --git a/tox.ini b/tox.ini index c62fc08..277f966 100644 --- a/tox.ini +++ b/tox.ini @@ -6,6 +6,8 @@ envlist = py35-aiohttp{20,21,22,23,30,31,32,33,34,35} py36-aiohttp-master py36-aiohttp{20,21,22,23,30,31,32,33,34,35} + py37-aiohttp-master + py37-aiohttp{20,21,22,23,30,31,32,33,34,35} skipsdist = True [testenv:flake8] @@ -38,6 +40,7 @@ deps = basepython = py35: python3.5 py36: python3.6 + py37: python3.7 commands = python setup.py test