Skip to content

Commit

Permalink
Merge pull request #116 from decaz/feature/callbacks
Browse files Browse the repository at this point in the history
Add callbacks to provide dynamic responses
  • Loading branch information
pnuckowski authored Jan 19, 2019
2 parents 3db2453 + 7b672ce commit 566461a
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 50 deletions.
24 changes: 22 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,28 @@ for convenience use *payload* argument to mock out json response. Example below.
# will throw an exception.
**aioresponses allows to use callbacks to provide dynamic responses**

.. code:: python
import asyncio
import aiohttp
from aioresponses import CallbackResult, aioresponses
def callback(url, **kwargs):
return CallbackResult(status=418)
@aioresponses()
def test_callback(m, test_client):
loop = asyncio.get_event_loop()
session = ClientSession()
m.get('http://example.com', callback=callback)
resp = loop.run_until_complete(session.get('http://example.com'))
assert resp.status == 418
**aioresponses can be used in a pytest fixture**

.. code:: python
Expand All @@ -203,8 +225,6 @@ for convenience use *payload* argument to mock out json response. Example below.
yield m
Features
--------
* Easy to mock out HTTP requests made by *aiohttp.ClientSession*
Expand Down
3 changes: 2 additions & 1 deletion aioresponses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
from .core import aioresponses
from .core import CallbackResult, aioresponses

__version__ = '0.4.1'

__all__ = [
'CallbackResult',
'aioresponses',
]
8 changes: 3 additions & 5 deletions aioresponses/compat.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# -*- coding: utf-8 -*-
import asyncio
import asyncio # noqa
import re
from distutils.version import StrictVersion
from typing import Dict, Optional, Union # noqa
from typing import Dict, Optional, Tuple, Union # noqa
from urllib.parse import parse_qsl, urlencode

from aiohttp import StreamReader
from aiohttp import __version__ as aiohttp_version
from aiohttp import __version__ as aiohttp_version, StreamReader
from multidict import MultiDict
from yarl import URL

Expand All @@ -27,7 +26,6 @@ def stream_reader_factory(
protocol = ResponseHandler(loop=loop)
return StreamReader(protocol)


else: # pragma: no cover

def stream_reader_factory():
Expand Down
135 changes: 94 additions & 41 deletions aioresponses/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import namedtuple
from distutils.version import StrictVersion
from functools import wraps
from typing import Dict, Tuple, Union, Optional, List # noqa
from typing import Callable, Dict, Tuple, Union, Optional, List # noqa
from unittest.mock import Mock, patch

from aiohttp import (
Expand All @@ -18,15 +18,35 @@
from multidict import CIMultiDict

from .compat import (
AIOHTTP_VERSION,
URL,
Pattern,
stream_reader_factory,
merge_params,
normalize_url,
AIOHTTP_VERSION
)


class CallbackResult:

def __init__(self, method: str = hdrs.METH_GET,
status: int = 200,
body: str = '',
content_type: str = 'application/json',
payload: Dict = None,
headers: Dict = None,
response_class: 'ClientResponse' = None,
reason: Optional[str] = None):
self.method = method
self.status = status
self.body = body
self.content_type = content_type
self.payload = payload
self.headers = headers
self.response_class = response_class
self.reason = reason


class RequestMatch(object):
url_or_pattern = None # type: Union[URL, Pattern]

Expand All @@ -41,7 +61,8 @@ def __init__(self, url: Union[str, Pattern],
response_class: 'ClientResponse' = None,
timeout: bool = False,
repeat: bool = False,
reason: Optional[str] = None):
reason: Optional[str] = None,
callback: Optional[Callable] = None):
if isinstance(url, Pattern):
self.url_or_pattern = url
self.match_func = self.match_regexp
Expand All @@ -50,24 +71,22 @@ def __init__(self, url: Union[str, Pattern],
self.match_func = self.match_str
self.method = method.lower()
self.status = status
if payload is not None:
body = json.dumps(payload)
if not isinstance(body, bytes):
body = str.encode(body)
self.body = body
self.payload = payload
self.exception = exception
if timeout:
self.exception = asyncio.TimeoutError('Connection timeout test')
self.headers = headers
self.content_type = content_type
self.response_class = response_class or ClientResponse
self.response_class = response_class
self.repeat = repeat
self.reason = reason
if self.reason is None:
try:
self.reason = http.RESPONSES[self.status][0]
except (IndexError, KeyError):
self.reason = ''
self.callback = callback

def match_str(self, url: URL) -> bool:
return self.url_or_pattern == url
Expand All @@ -80,11 +99,32 @@ def match(self, method: str, url: URL) -> bool:
return False
return self.match_func(url)

async def build_response(
self, url: URL
) -> 'Union[ClientResponse, Exception]':
if isinstance(self.exception, Exception):
return self.exception
def _build_raw_headers(self, headers: Dict) -> Tuple:
"""
Convert a dict of headers to a tuple of tuples
Mimics the format of ClientResponse.
"""
raw_headers = []
for k, v in headers.items():
raw_headers.append((k.encode('utf8'), v.encode('utf8')))
return tuple(raw_headers)

def _build_response(self, url: 'Union[URL, str]',
method: str = hdrs.METH_GET,
status: int = 200,
body: str = '',
content_type: str = 'application/json',
payload: Dict = None,
headers: Dict = None,
response_class: 'ClientResponse' = None,
reason: Optional[str] = None) -> ClientResponse:
if response_class is None:
response_class = ClientResponse
if payload is not None:
body = json.dumps(payload)
if not isinstance(body, bytes):
body = str.encode(body)
kwargs = {}
if AIOHTTP_VERSION >= StrictVersion('3.1.0'):
loop = Mock()
Expand All @@ -94,43 +134,52 @@ async def build_response(
kwargs['writer'] = Mock()
kwargs['continue100'] = None
kwargs['timer'] = TimerNoop()
if AIOHTTP_VERSION >= StrictVersion('3.3.0'):
pass
else:
if AIOHTTP_VERSION < StrictVersion('3.3.0'):
kwargs['auto_decompress'] = True
kwargs['traces'] = []
kwargs['loop'] = loop
kwargs['session'] = None
resp = self.response_class(self.method, url, **kwargs)
# we need to initialize headers manually
headers = CIMultiDict({hdrs.CONTENT_TYPE: self.content_type})
if self.headers:
headers.update(self.headers)
raw_headers = self._build_raw_headers(headers)
# We need to initialize headers manually
_headers = CIMultiDict({hdrs.CONTENT_TYPE: content_type})
if headers:
_headers.update(headers)
raw_headers = self._build_raw_headers(_headers)
resp = response_class(method, url, **kwargs)
if AIOHTTP_VERSION >= StrictVersion('3.3.0'):
# Reified attributes
resp._headers = headers
resp._headers = _headers
resp._raw_headers = raw_headers
else:
resp.headers = headers
resp.headers = _headers
resp.raw_headers = raw_headers
resp.status = self.status
resp.reason = self.reason
resp.status = status
resp.reason = reason
resp.content = stream_reader_factory()
resp.content.feed_data(self.body)
resp.content.feed_data(body)
resp.content.feed_eof()
return resp

def _build_raw_headers(self, headers: Dict) -> Tuple:
"""
Convert a dict of headers to a tuple of tuples
Mimics the format of ClientResponse.
"""
raw_headers = []
for k, v in headers.items():
raw_headers.append((k.encode('utf8'), v.encode('utf8')))
return tuple(raw_headers)
async def build_response(
self, url: URL, **kwargs: Dict
) -> 'Union[ClientResponse, Exception]':
if isinstance(self.exception, Exception):
return self.exception
if callable(self.callback):
result = self.callback(url, **kwargs)
else:
result = None
result = self if result is None else result
resp = self._build_response(
url=url,
method=result.method,
status=result.status,
body=result.body,
content_type=result.content_type,
payload=result.payload,
headers=result.headers,
response_class=result.response_class,
reason=result.reason)
return resp


RequestCall = namedtuple('RequestCall', ['args', 'kwargs'])
Expand Down Expand Up @@ -226,7 +275,8 @@ def add(self, url: 'Union[URL, str]', method: str = hdrs.METH_GET,
response_class: 'ClientResponse' = None,
repeat: bool = False,
timeout: bool = False,
reason: Optional[str] = None) -> None:
reason: Optional[str] = None,
callback: Optional[Callable] = None) -> None:
self._matches.append(RequestMatch(
url,
method=method,
Expand All @@ -240,12 +290,15 @@ def add(self, url: 'Union[URL, str]', method: str = hdrs.METH_GET,
repeat=repeat,
timeout=timeout,
reason=reason,
callback=callback,
))

async def match(self, method: str, url: URL) -> Optional['ClientResponse']:
async def match(
self, method: str, url: URL, **kwargs: Dict
) -> Optional['ClientResponse']:
for i, matcher in enumerate(self._matches):
if matcher.match(method, url):
response = await matcher.build_response(url)
response = await matcher.build_response(url, **kwargs)
break
else:
return None
Expand All @@ -269,7 +322,7 @@ async def _request_mock(self, orig_self: ClientSession,
orig_self, method, url, *args, **kwargs
))

response = await self.match(method, url)
response = await self.match(method, url, **kwargs)
if response is None:
raise ClientConnectionError(
'Connection refused: {} {}'.format(method, url)
Expand Down
16 changes: 15 additions & 1 deletion tests/test_aioresponses.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from aiohttp.http_exceptions import HttpProcessingError

from aioresponses.compat import URL
from aioresponses import aioresponses
from aioresponses import CallbackResult, aioresponses


@ddt
Expand Down Expand Up @@ -314,3 +314,17 @@ def test_timeout(self, mocked):

with self.assertRaises(asyncio.TimeoutError):
self.run_async(self.request(self.url))

@aioresponses()
def test_callback(self, m):
body = b'New body'

def callback(url, **kwargs):
self.assertEqual(str(url), self.url)
self.assertEqual(kwargs, {'allow_redirects': True})
return CallbackResult(body=body)

m.get(self.url, callback=callback)
response = self.run_async(self.request(self.url))
data = self.run_async(response.read())
assert data == body
3 changes: 3 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ envlist =
py35-aiohttp{20,21,22,23,30,31,32,33,34,35}
py36-aiohttp-master
py36-aiohttp{20,21,22,23,30,31,32,33,34,35}
py37-aiohttp-master
py37-aiohttp{20,21,22,23,30,31,32,33,34,35}
skipsdist = True

[testenv:flake8]
Expand Down Expand Up @@ -38,6 +40,7 @@ deps =
basepython =
py35: python3.5
py36: python3.6
py37: python3.7

commands = python setup.py test

Expand Down

0 comments on commit 566461a

Please sign in to comment.