diff --git a/README.rst b/README.rst index a921239..7be1c2b 100644 --- a/README.rst +++ b/README.rst @@ -199,7 +199,7 @@ for convenience use *payload* argument to mock out json response. Example below. from aioresponses import CallbackResult, aioresponses def callback(url, **kwargs): - return CallbackResult(url, status=418) + return CallbackResult(status=418) @aioresponses() def test_callback(m, test_client): diff --git a/aioresponses/__init__.py b/aioresponses/__init__.py index 73a1bf6..81f0787 100644 --- a/aioresponses/__init__.py +++ b/aioresponses/__init__.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from .compat import CallbackResult -from .core import aioresponses +from .core import CallbackResult, aioresponses __version__ = '0.4.1' __all__ = [ - 'aioresponses', 'CallbackResult', + 'aioresponses', ] diff --git a/aioresponses/compat.py b/aioresponses/compat.py index 201bfd0..7b967a3 100644 --- a/aioresponses/compat.py +++ b/aioresponses/compat.py @@ -1,20 +1,12 @@ # -*- coding: utf-8 -*- import asyncio # noqa -import json import re from distutils.version import StrictVersion from typing import Dict, Optional, Tuple, Union # noqa -from unittest.mock import Mock from urllib.parse import parse_qsl, urlencode -from aiohttp import ( - __version__ as aiohttp_version, - ClientResponse, - StreamReader, - hdrs, -) -from aiohttp.helpers import TimerNoop -from multidict import CIMultiDict, MultiDict +from aiohttp import __version__ as aiohttp_version, StreamReader +from multidict import MultiDict from yarl import URL try: @@ -55,90 +47,6 @@ def normalize_url(url: 'Union[URL, str]') -> 'URL': return url.with_query(urlencode(sorted(parse_qsl(url.query_string)))) -def _build_raw_headers(headers: Dict) -> Tuple: - """ - Convert a dict of headers to a tuple of tuples - - Mimics the format of ClientResponse. - """ - raw_headers = [] - for k, v in headers.items(): - raw_headers.append((k.encode('utf8'), v.encode('utf8'))) - return tuple(raw_headers) - - -def build_response( - url: 'Union[URL, str]', - method: str = hdrs.METH_GET, - status: int = 200, - body: str = '', - content_type: str = 'application/json', - payload: Dict = None, - headers: Dict = None, - response_class: 'ClientResponse' = None, - reason: Optional[str] = None) -> ClientResponse: - if response_class is None: - response_class = ClientResponse - if payload is not None: - body = json.dumps(payload) - if not isinstance(body, bytes): - body = str.encode(body) - kwargs = {} - if AIOHTTP_VERSION >= StrictVersion('3.1.0'): - loop = Mock() - loop.get_debug = Mock() - loop.get_debug.return_value = True - kwargs['request_info'] = Mock() - kwargs['writer'] = Mock() - kwargs['continue100'] = None - kwargs['timer'] = TimerNoop() - if AIOHTTP_VERSION < StrictVersion('3.3.0'): - kwargs['auto_decompress'] = True - kwargs['traces'] = [] - kwargs['loop'] = loop - kwargs['session'] = None - _headers = CIMultiDict({hdrs.CONTENT_TYPE: content_type}) - if headers: - _headers.update(headers) - raw_headers = _build_raw_headers(_headers) - resp = response_class(method, url, **kwargs) - if AIOHTTP_VERSION >= StrictVersion('3.3.0'): - # Reified attributes - resp._headers = _headers - resp._raw_headers = raw_headers - else: - resp.headers = _headers - resp.raw_headers = raw_headers - resp.status = status - resp.reason = reason - resp.content = stream_reader_factory() - resp.content.feed_data(body) - resp.content.feed_eof() - return resp - - -class CallbackResult: - - def __init__(self, url: 'Union[URL, str]', - method: str = hdrs.METH_GET, - status: int = 200, - body: str = '', - content_type: str = 'application/json', - payload: Dict = None, - headers: Dict = None, - response_class: 'ClientResponse' = None, - reason: Optional[str] = None): - self.url = url - self.method = method - self.status = status - self.body = body - self.content_type = content_type - self.payload = payload - self.headers = headers - self.response_class = response_class - self.reason = reason - - __all__ = [ 'URL', 'Pattern', @@ -146,6 +54,4 @@ def __init__(self, url: 'Union[URL, str]', 'merge_params', 'stream_reader_factory', 'normalize_url', - 'build_response', - 'CallbackResult', ] diff --git a/aioresponses/core.py b/aioresponses/core.py index c3baa3e..8e502be 100644 --- a/aioresponses/core.py +++ b/aioresponses/core.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- import asyncio +import json from collections import namedtuple +from distutils.version import StrictVersion from functools import wraps from typing import Callable, Dict, Tuple, Union, Optional, List # noqa -from unittest.mock import patch +from unittest.mock import Mock, patch from aiohttp import ( ClientConnectionError, @@ -12,16 +14,39 @@ hdrs, http ) +from aiohttp.helpers import TimerNoop +from multidict import CIMultiDict from .compat import ( + AIOHTTP_VERSION, URL, Pattern, + stream_reader_factory, merge_params, normalize_url, - build_response, ) +class CallbackResult: + + def __init__(self, method: str = hdrs.METH_GET, + status: int = 200, + body: str = '', + content_type: str = 'application/json', + payload: Dict = None, + headers: Dict = None, + response_class: 'ClientResponse' = None, + reason: Optional[str] = None): + self.method = method + self.status = status + self.body = body + self.content_type = content_type + self.payload = payload + self.headers = headers + self.response_class = response_class + self.reason = reason + + class RequestMatch(object): url_or_pattern = None # type: Union[URL, Pattern] @@ -74,31 +99,86 @@ def match(self, method: str, url: URL) -> bool: return False return self.match_func(url) - async def _build_response( + def _build_raw_headers(self, headers: Dict) -> Tuple: + """ + Convert a dict of headers to a tuple of tuples + + Mimics the format of ClientResponse. + """ + raw_headers = [] + for k, v in headers.items(): + raw_headers.append((k.encode('utf8'), v.encode('utf8'))) + return tuple(raw_headers) + + def _build_response(self, url: 'Union[URL, str]', + method: str = hdrs.METH_GET, + status: int = 200, + body: str = '', + content_type: str = 'application/json', + payload: Dict = None, + headers: Dict = None, + response_class: 'ClientResponse' = None, + reason: Optional[str] = None) -> ClientResponse: + if response_class is None: + response_class = ClientResponse + if payload is not None: + body = json.dumps(payload) + if not isinstance(body, bytes): + body = str.encode(body) + kwargs = {} + if AIOHTTP_VERSION >= StrictVersion('3.1.0'): + loop = Mock() + loop.get_debug = Mock() + loop.get_debug.return_value = True + kwargs['request_info'] = Mock() + kwargs['writer'] = Mock() + kwargs['continue100'] = None + kwargs['timer'] = TimerNoop() + if AIOHTTP_VERSION < StrictVersion('3.3.0'): + kwargs['auto_decompress'] = True + kwargs['traces'] = [] + kwargs['loop'] = loop + kwargs['session'] = None + # We need to initialize headers manually + _headers = CIMultiDict({hdrs.CONTENT_TYPE: content_type}) + if headers: + _headers.update(headers) + raw_headers = self._build_raw_headers(_headers) + resp = response_class(method, url, **kwargs) + if AIOHTTP_VERSION >= StrictVersion('3.3.0'): + # Reified attributes + resp._headers = _headers + resp._raw_headers = raw_headers + else: + resp.headers = _headers + resp.raw_headers = raw_headers + resp.status = status + resp.reason = reason + resp.content = stream_reader_factory() + resp.content.feed_data(body) + resp.content.feed_eof() + return resp + + async def build_response( self, url: URL, **kwargs: Dict ) -> 'Union[ClientResponse, Exception]': if isinstance(self.exception, Exception): return self.exception if callable(self.callback): - resp = self.callback(url, **kwargs) - else: - resp = None - if resp is None: - url = url - resp_data = self + result = self.callback(url, **kwargs) else: - url = resp.url - resp_data = resp - resp = build_response( + result = None + result = self if result is None else result + resp = self._build_response( url=url, - method=resp_data.method, - status=resp_data.status, - body=resp_data.body, - content_type=resp_data.content_type, - payload=resp_data.payload, - headers=resp_data.headers, - response_class=resp_data.response_class, - reason=resp_data.reason) + method=result.method, + status=result.status, + body=result.body, + content_type=result.content_type, + payload=result.payload, + headers=result.headers, + response_class=result.response_class, + reason=result.reason) return resp @@ -218,7 +298,7 @@ async def match( ) -> Optional['ClientResponse']: for i, matcher in enumerate(self._matches): if matcher.match(method, url): - response = await matcher._build_response(url, **kwargs) + response = await matcher.build_response(url, **kwargs) break else: return None diff --git a/tests/test_aioresponses.py b/tests/test_aioresponses.py index 45cf2c5..5652faf 100644 --- a/tests/test_aioresponses.py +++ b/tests/test_aioresponses.py @@ -322,7 +322,7 @@ def test_callback(self, m): def callback(url, **kwargs): self.assertEqual(str(url), self.url) self.assertEqual(kwargs, {'allow_redirects': True}) - return CallbackResult(url, body=body) + return CallbackResult(body=body) m.get(self.url, callback=callback) response = self.run_async(self.request(self.url))