Skip to content

Commit

Permalink
Expose build_response method
Browse files Browse the repository at this point in the history
Also remove `match` and `*args` params from callback signature
and make callback simple function instead of coroutine.
  • Loading branch information
decaz committed Jan 19, 2019
1 parent 8930a8b commit f7e99a2
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 88 deletions.
6 changes: 3 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,10 @@ for convenience use *payload* argument to mock out json response. Example below.
import asyncio
import aiohttp
from aioresponses import aioresponses
from aioresponses import aioresponses, build_response
async def callback(match, url, *args, **kwargs):
return match.build_response(url, status=418)
def callback(url, **kwargs):
return build_response(url, status=418)
@aioresponses()
def test_callback(m, test_client):
Expand Down
1 change: 1 addition & 0 deletions aioresponses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
from .compat import build_response
from .core import aioresponses

__version__ = '0.4.1'
Expand Down
79 changes: 74 additions & 5 deletions aioresponses/compat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
# -*- coding: utf-8 -*-
import asyncio # noqa
import json
import re
from distutils.version import StrictVersion
from typing import Dict, Optional, Union # noqa
from typing import Dict, Optional, Tuple, Union # noqa
from unittest.mock import Mock
from urllib.parse import parse_qsl, urlencode

from aiohttp import StreamReader
from aiohttp import __version__ as aiohttp_version
from multidict import MultiDict
from aiohttp import (
__version__ as aiohttp_version,
ClientResponse,
StreamReader,
hdrs,
)
from aiohttp.helpers import TimerNoop
from multidict import CIMultiDict, MultiDict
from yarl import URL

try:
Expand All @@ -27,7 +34,6 @@ def stream_reader_factory(
protocol = ResponseHandler(loop=loop)
return StreamReader(protocol)


else: # pragma: no cover

def stream_reader_factory():
Expand All @@ -49,11 +55,74 @@ def normalize_url(url: 'Union[URL, str]') -> 'URL':
return url.with_query(urlencode(sorted(parse_qsl(url.query_string))))


def _build_raw_headers(headers: Dict) -> Tuple:
"""
Convert a dict of headers to a tuple of tuples
Mimics the format of ClientResponse.
"""
raw_headers = []
for k, v in headers.items():
raw_headers.append((k.encode('utf8'), v.encode('utf8')))
return tuple(raw_headers)


def build_response(
url: 'Union[URL, str]',
method: str = hdrs.METH_GET,
status: int = 200,
body: str = '',
content_type: str = 'application/json',
payload: Dict = None,
headers: Dict = None,
response_class: 'ClientResponse' = None,
reason: Optional[str] = None) -> ClientResponse:
if response_class is None:
response_class = ClientResponse
if payload is not None:
body = json.dumps(payload)
if not isinstance(body, bytes):
body = str.encode(body)
kwargs = {}
if AIOHTTP_VERSION >= StrictVersion('3.1.0'):
loop = Mock()
loop.get_debug = Mock()
loop.get_debug.return_value = True
kwargs['request_info'] = Mock()
kwargs['writer'] = Mock()
kwargs['continue100'] = None
kwargs['timer'] = TimerNoop()
if AIOHTTP_VERSION < StrictVersion('3.3.0'):
kwargs['auto_decompress'] = True
kwargs['traces'] = []
kwargs['loop'] = loop
kwargs['session'] = None
_headers = CIMultiDict({hdrs.CONTENT_TYPE: content_type})
if headers:
_headers.update(headers)
raw_headers = _build_raw_headers(_headers)
resp = response_class(method, url, **kwargs)
if AIOHTTP_VERSION >= StrictVersion('3.3.0'):
# Reified attributes
resp._headers = _headers
resp._raw_headers = raw_headers
else:
resp.headers = _headers
resp.raw_headers = raw_headers
resp.status = status
resp.reason = reason
resp.content = stream_reader_factory()
resp.content.feed_data(body)
resp.content.feed_eof()
return resp


__all__ = [
'URL',
'Pattern',
'AIOHTTP_VERSION',
'merge_params',
'stream_reader_factory',
'normalize_url',
'build_response',
]
84 changes: 9 additions & 75 deletions aioresponses/core.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,23 @@
# -*- coding: utf-8 -*-
import asyncio
import json
from collections import namedtuple
from distutils.version import StrictVersion
from functools import wraps
from typing import Callable, Dict, Tuple, Union, Optional, List # noqa
from unittest.mock import Mock, patch
from unittest.mock import patch

from aiohttp import (
ClientConnectionError,
ClientResponse,
ClientSession,
hdrs,
http
)
from aiohttp.helpers import TimerNoop
from multidict import CIMultiDict

from .compat import (
URL,
Pattern,
stream_reader_factory,
merge_params,
normalize_url,
AIOHTTP_VERSION
build_response,
)


Expand Down Expand Up @@ -79,66 +73,17 @@ def match(self, method: str, url: URL) -> bool:
return False
return self.match_func(url)

def build_response(
self, url: 'Union[URL, str]',
method: str = hdrs.METH_GET,
status: int = 200,
body: str = '',
content_type: str = 'application/json',
payload: Dict = None,
headers: Dict = None,
response_class: 'ClientResponse' = None,
reason: Optional[str] = None) -> ClientResponse:
if response_class is None:
response_class = ClientResponse
if payload is not None:
body = json.dumps(payload)
if not isinstance(body, bytes):
body = str.encode(body)
kwargs = {}
if AIOHTTP_VERSION >= StrictVersion('3.1.0'):
loop = Mock()
loop.get_debug = Mock()
loop.get_debug.return_value = True
kwargs['request_info'] = Mock()
kwargs['writer'] = Mock()
kwargs['continue100'] = None
kwargs['timer'] = TimerNoop()
if AIOHTTP_VERSION < StrictVersion('3.3.0'):
kwargs['auto_decompress'] = True
kwargs['traces'] = []
kwargs['loop'] = loop
kwargs['session'] = None
_headers = CIMultiDict({hdrs.CONTENT_TYPE: content_type})
if headers:
_headers.update(headers)
raw_headers = self._build_raw_headers(_headers)
resp = response_class(method, url, **kwargs)
if AIOHTTP_VERSION >= StrictVersion('3.3.0'):
# Reified attributes
resp._headers = _headers
resp._raw_headers = raw_headers
else:
resp.headers = _headers
resp.raw_headers = raw_headers
resp.status = status
resp.reason = reason
resp.content = stream_reader_factory()
resp.content.feed_data(body)
resp.content.feed_eof()
return resp

async def _build_response(
self, url: URL, *args: Tuple, **kwargs: Dict
self, url: URL, **kwargs: Dict
) -> 'Union[ClientResponse, Exception]':
if isinstance(self.exception, Exception):
return self.exception
if self.callback and callable(self.callback):
resp = await self.callback(self, url, *args, **kwargs)
if callable(self.callback):
resp = self.callback(url, **kwargs)
else:
resp = None
if resp is None:
resp = self.build_response(
resp = build_response(
url=url,
method=self.method,
status=self.status,
Expand All @@ -150,17 +95,6 @@ async def _build_response(
reason=self.reason)
return resp

def _build_raw_headers(self, headers: Dict) -> Tuple:
"""
Convert a dict of headers to a tuple of tuples
Mimics the format of ClientResponse.
"""
raw_headers = []
for k, v in headers.items():
raw_headers.append((k.encode('utf8'), v.encode('utf8')))
return tuple(raw_headers)


RequestCall = namedtuple('RequestCall', ['args', 'kwargs'])

Expand Down Expand Up @@ -274,11 +208,11 @@ def add(self, url: 'Union[URL, str]', method: str = hdrs.METH_GET,
))

async def match(
self, method: str, url: URL, *args: Tuple, **kwargs: Dict
self, method: str, url: URL, **kwargs: Dict
) -> Optional['ClientResponse']:
for i, matcher in enumerate(self._matches):
if matcher.match(method, url):
response = await matcher._build_response(url, *args, **kwargs)
response = await matcher._build_response(url, **kwargs)
break
else:
return None
Expand All @@ -302,7 +236,7 @@ async def _request_mock(self, orig_self: ClientSession,
orig_self, method, url, *args, **kwargs
))

response = await self.match(method, url, *args, **kwargs)
response = await self.match(method, url, **kwargs)
if response is None:
raise ClientConnectionError(
'Connection refused: {} {}'.format(method, url)
Expand Down
8 changes: 3 additions & 5 deletions tests/test_aioresponses.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from aiohttp.http_exceptions import HttpProcessingError

from aioresponses.compat import URL
from aioresponses.compat import URL, build_response
from aioresponses import aioresponses


Expand Down Expand Up @@ -319,12 +319,10 @@ def test_timeout(self, mocked):
def test_callback(self, m):
body = b'New body'

@asyncio.coroutine
async def callback(match, url, *args, **kwargs):
def callback(url, **kwargs):
self.assertEqual(str(url), self.url)
self.assertEqual(args, ())
self.assertEqual(kwargs, {'allow_redirects': True})
return match.build_response(url, body=body)
return build_response(url, body=body)

m.get(self.url, callback=callback)
response = self.run_async(self.request(self.url))
Expand Down

0 comments on commit f7e99a2

Please sign in to comment.