# -*- coding: utf-8 -*-
import asyncio
import copy
import inspect
import json
from collections import namedtuple
from functools import wraps
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from unittest.mock import Mock, patch
from uuid import uuid4
from aiohttp import (
ClientConnectionError,
ClientResponse,
ClientSession,
hdrs,
http
)
from aiohttp.helpers import TimerNoop
from multidict import CIMultiDict, CIMultiDictProxy
from pkg_resources import parse_version
from .compat import (
AIOHTTP_VERSION,
URL,
Pattern,
stream_reader_factory,
merge_params,
normalize_url,
RequestInfo,
)
# Type variable bound to any callable. ``aioresponses.__call__`` is annotated
# ``(f: _FuncT) -> _FuncT`` so that decorating a function keeps its static type.
_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])
class CallbackResult:
    """Container for the attributes of a mocked response.

    The constructor only stores its arguments; nothing is validated or
    converted here. ``RequestMatch.build_response`` reads exactly these
    attribute names (``method``, ``status``, ``body``, ``content_type``,
    ``payload``, ``headers``, ``response_class``, ``reason``) from the value
    returned by a callback, so an instance of this class can be returned
    from one.
    """

    def __init__(
        self,
        method: str = hdrs.METH_GET,
        status: int = 200,
        body: Union[str, bytes] = '',
        content_type: str = 'application/json',
        payload: Optional[Dict] = None,
        headers: Optional[Dict] = None,
        response_class: Optional[Type[ClientResponse]] = None,
        reason: Optional[str] = None,
    ):
        # Status line information.
        self.method = method
        self.status = status
        self.reason = reason

        # Response content: either a raw ``body`` or a ``payload`` object.
        self.body = body
        self.payload = payload
        self.content_type = content_type
        self.headers = headers

        # Optional ClientResponse subclass used to build the response.
        self.response_class = response_class
class RequestMatch(object):
    """A registered mock response plus the rule that selects it.

    An instance matches a request by HTTP method and by URL (either an exact,
    normalized URL or a compiled regular expression) and knows how to build
    the ``ClientResponse`` (or hand back the exception) for that request.
    """
    url_or_pattern = None  # type: Union[URL, Pattern]

    def __init__(self, url: Union[URL, str, Pattern],
                 method: str = hdrs.METH_GET,
                 status: int = 200,
                 body: Union[str, bytes] = '',
                 payload: Optional[Dict] = None,
                 exception: Optional[Exception] = None,
                 headers: Optional[Dict] = None,
                 content_type: str = 'application/json',
                 response_class: Optional[Type[ClientResponse]] = None,
                 timeout: bool = False,
                 repeat: bool = False,
                 reason: Optional[str] = None,
                 callback: Optional[Callable] = None):
        """Store the match rule and the response description.

        :param url: URL to match exactly (normalized with ``normalize_url``)
            or a compiled pattern matched against ``str(url)``.
        :param method: HTTP method; compared case-insensitively.
        :param status: status code of the mocked response.
        :param body: response body; ``str`` bodies are encoded to bytes.
        :param payload: when not ``None``, JSON-encoded and used as the body.
        :param exception: returned by ``build_response`` instead of a
            response when not ``None``.
        :param headers: extra response headers.
        :param content_type: value of the ``Content-Type`` response header.
        :param response_class: class used to build the response; defaults to
            ``ClientResponse``.
        :param timeout: when true, ``exception`` is replaced with an
            ``asyncio.TimeoutError``.
        :param repeat: stored as-is; ``aioresponses.match`` keeps the matcher
            registered after use unless this is ``False``.
        :param reason: reason phrase; derived from ``status`` when ``None``.
        :param callback: optional (possibly coroutine) function called with
            the URL and request kwargs; a non-``None`` return value replaces
            this matcher as the source of the response attributes.
        """
        if isinstance(url, Pattern):
            self.url_or_pattern = url
            self.match_func = self.match_regexp
        else:
            self.url_or_pattern = normalize_url(url)
            self.match_func = self.match_str
        self.method = method.lower()
        self.status = status
        self.body = body
        self.payload = payload
        self.exception = exception
        if timeout:
            # ``timeout`` takes precedence over an explicitly given exception.
            self.exception = asyncio.TimeoutError('Connection timeout test')
        self.headers = headers
        self.content_type = content_type
        self.response_class = response_class
        self.repeat = repeat
        self.reason = reason
        if self.reason is None:
            self.reason = self._default_reason(self.status)
        self.callback = callback

    @staticmethod
    def _default_reason(status: int) -> str:
        """Return the reason phrase registered for *status*, or ``''``."""
        try:
            return http.RESPONSES[status][0]
        except (IndexError, KeyError):
            return ''

    def match_str(self, url: URL) -> bool:
        """Return True when *url* equals the stored URL."""
        return self.url_or_pattern == url

    def match_regexp(self, url: URL) -> bool:
        """Return True when the stored pattern matches ``str(url)``."""
        # This method is used if and only if self.url_or_pattern is a pattern.
        return bool(
            self.url_or_pattern.match(str(url))  # type:ignore[union-attr]
        )

    def match(self, method: str, url: URL) -> bool:
        """Return True when both the method and the URL match."""
        if self.method != method.lower():
            return False
        return self.match_func(url)

    def _build_raw_headers(self, headers: Dict) -> Tuple:
        """
        Convert a dict of headers to a tuple of tuples

        Mimics the format of ClientResponse.
        """
        return tuple(
            (name.encode('utf8'), value.encode('utf8'))
            for name, value in headers.items()
        )

    def _build_response(self, url: 'Union[URL, str]',
                        method: str = hdrs.METH_GET,
                        request_headers: Optional[Dict] = None,
                        status: int = 200,
                        body: Union[str, bytes] = '',
                        content_type: str = 'application/json',
                        payload: Optional[Dict] = None,
                        headers: Optional[Dict] = None,
                        response_class: Optional[Type[ClientResponse]] = None,
                        reason: Optional[str] = None) -> ClientResponse:
        """Instantiate *response_class* and fill it with the mocked data.

        The constructor arguments and the way headers are attached depend on
        the installed aiohttp version (``AIOHTTP_VERSION``). The body is fed
        into a stream reader created by ``stream_reader_factory`` and
        immediately terminated with EOF.
        """
        if response_class is None:
            response_class = ClientResponse
        if payload is not None:
            # A payload always wins over an explicit body.
            body = json.dumps(payload)
        if not isinstance(body, bytes):
            body = str.encode(body)
        if request_headers is None:
            request_headers = {}
        kwargs = {}  # type: Dict[str, Any]
        if AIOHTTP_VERSION >= parse_version('3.1.0'):
            loop = Mock()
            loop.get_debug = Mock()
            loop.get_debug.return_value = True
            kwargs['request_info'] = RequestInfo(
                url=url,
                method=method,
                # Pass the headers positionally (as done for the response
                # headers below) rather than unpacking them into keyword
                # arguments, so multidict inputs with repeated keys work.
                headers=CIMultiDictProxy(CIMultiDict(request_headers)),
            )
            kwargs['writer'] = Mock()
            kwargs['continue100'] = None
            kwargs['timer'] = TimerNoop()
            if AIOHTTP_VERSION < parse_version('3.3.0'):
                kwargs['auto_decompress'] = True
            kwargs['traces'] = []
            kwargs['loop'] = loop
            kwargs['session'] = None
        else:
            loop = None
        # We need to initialize headers manually
        _headers = CIMultiDict({hdrs.CONTENT_TYPE: content_type})
        if headers:
            _headers.update(headers)
        raw_headers = self._build_raw_headers(_headers)
        resp = response_class(method, url, **kwargs)

        for hdr in _headers.getall(hdrs.SET_COOKIE, ()):
            resp.cookies.load(hdr)

        if AIOHTTP_VERSION >= parse_version('3.3.0'):
            # Reified attributes
            resp._headers = _headers
            resp._raw_headers = raw_headers
        else:
            resp.headers = _headers
            resp.raw_headers = raw_headers
        resp.status = status
        resp.reason = reason
        resp.content = stream_reader_factory(loop)
        resp.content.feed_data(body)
        resp.content.feed_eof()
        return resp

    async def build_response(
        self, url: URL, **kwargs: Any
    ) -> 'Union[ClientResponse, Exception]':
        """Return the response (or the configured exception) for *url*.

        The callback, when set, is always invoked first (and awaited if it is
        a coroutine function), even if an exception is configured. The
        exception, when set, is returned rather than raised. Otherwise the
        response is built from the callback's return value, or from this
        matcher itself when the callback is absent or returned ``None``.
        """
        if callable(self.callback):
            if asyncio.iscoroutinefunction(self.callback):
                result = await self.callback(url, **kwargs)
            else:
                result = self.callback(url, **kwargs)
        else:
            result = None

        if self.exception is not None:
            return self.exception

        result = self if result is None else result
        reason = result.reason
        if reason is None:
            # A callback result may omit the reason; derive it from the
            # status exactly like __init__ does, so the response never ends
            # up with ``reason=None``.
            reason = self._default_reason(result.status)
        resp = self._build_response(
            url=url,
            method=result.method,
            request_headers=kwargs.get("headers"),
            status=result.status,
            body=result.body,
            content_type=result.content_type,
            payload=result.payload,
            headers=result.headers,
            response_class=result.response_class,
            reason=reason)
        return resp
# Record of one intercepted request: the positional arguments and the
# (deep-copied when possible) keyword arguments, as produced by
# ``aioresponses._build_request_call`` and stored in ``aioresponses.requests``.
RequestCall = namedtuple('RequestCall', ['args', 'kwargs'])
class aioresponses(object):
    """Mock aiohttp requests made by ClientSession.

    Patches ``aiohttp.client.ClientSession._request`` while active. Can be
    used as a context manager (``with aioresponses() as m:``) or as a
    decorator, in which case the instance is passed to the decorated function
    (as the keyword named by ``param`` if given, otherwise as an extra
    trailing positional argument).
    """
    # Registered matchers, keyed by a random UUID string; set by start().
    _matches = None  # type: Dict[str, RequestMatch]
    # Responses handed out so far, closed on stop(); set by start().
    _responses = None  # type: List[ClientResponse]
    # Recorded calls: (method, URL) -> list of RequestCall.
    requests = None  # type: Dict

    def __init__(self, **kwargs: Any):
        """Create the patcher (it is not started here).

        Recognised keyword arguments:

        * ``param`` -- name of the keyword argument used to inject this
          instance when used as a decorator.
        * ``passthrough`` -- iterable of URL prefixes; requests whose URL
          starts with one of them are forwarded to the original ``_request``.

        Any other keyword argument is ignored.
        """
        self._param = kwargs.pop('param', None)
        self._passthrough = kwargs.pop('passthrough', [])
        self.patcher = patch('aiohttp.client.ClientSession._request',
                             side_effect=self._request_mock,
                             autospec=True)
        self.requests = {}

    def __enter__(self) -> 'aioresponses':
        self.start()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.stop()

    def __call__(self, f: _FuncT) -> _FuncT:
        """Decorate *f* so that it runs with the mock active."""
        def _pack_arguments(ctx, *args, **kwargs) -> Tuple[Tuple, Dict]:
            # Inject the mock either by keyword or as the last positional.
            if self._param:
                kwargs[self._param] = ctx
            else:
                args += (ctx,)
            return args, kwargs

        if asyncio.iscoroutinefunction(f):
            @wraps(f)
            async def wrapped(*args, **kwargs):
                with self as ctx:
                    args, kwargs = _pack_arguments(ctx, *args, **kwargs)
                    return await f(*args, **kwargs)
        else:
            @wraps(f)
            def wrapped(*args, **kwargs):
                with self as ctx:
                    args, kwargs = _pack_arguments(ctx, *args, **kwargs)
                    return f(*args, **kwargs)
        return cast(_FuncT, wrapped)

    def clear(self) -> None:
        """Forget all handed-out responses and registered matchers."""
        self._responses.clear()
        self._matches.clear()

    def start(self) -> None:
        """Reset matchers/responses and start patching ``_request``."""
        self._responses = []
        self._matches = {}
        self.patcher.start()
        self.patcher.return_value = self._request_mock

    def stop(self) -> None:
        """Close handed-out responses, stop patching and clear state."""
        for response in self._responses:
            response.close()
        self.patcher.stop()
        self.clear()

    def head(self, url: 'Union[URL, str, Pattern]', **kwargs: Any) -> None:
        """Register a response for a HEAD request; see ``add``."""
        self.add(url, method=hdrs.METH_HEAD, **kwargs)

    def get(self, url: 'Union[URL, str, Pattern]', **kwargs: Any) -> None:
        """Register a response for a GET request; see ``add``."""
        self.add(url, method=hdrs.METH_GET, **kwargs)

    def post(self, url: 'Union[URL, str, Pattern]', **kwargs: Any) -> None:
        """Register a response for a POST request; see ``add``."""
        self.add(url, method=hdrs.METH_POST, **kwargs)

    def put(self, url: 'Union[URL, str, Pattern]', **kwargs: Any) -> None:
        """Register a response for a PUT request; see ``add``."""
        self.add(url, method=hdrs.METH_PUT, **kwargs)

    def patch(self, url: 'Union[URL, str, Pattern]', **kwargs: Any) -> None:
        """Register a response for a PATCH request; see ``add``."""
        self.add(url, method=hdrs.METH_PATCH, **kwargs)

    def delete(self, url: 'Union[URL, str, Pattern]', **kwargs: Any) -> None:
        """Register a response for a DELETE request; see ``add``."""
        self.add(url, method=hdrs.METH_DELETE, **kwargs)

    def options(self, url: 'Union[URL, str, Pattern]', **kwargs: Any) -> None:
        """Register a response for an OPTIONS request; see ``add``."""
        self.add(url, method=hdrs.METH_OPTIONS, **kwargs)

    def add(self, url: 'Union[URL, str, Pattern]', method: str = hdrs.METH_GET,
            status: int = 200,
            body: Union[str, bytes] = '',
            exception: Optional[Exception] = None,
            content_type: str = 'application/json',
            payload: Optional[Dict] = None,
            headers: Optional[Dict] = None,
            response_class: Optional[Type[ClientResponse]] = None,
            repeat: bool = False,
            timeout: bool = False,
            reason: Optional[str] = None,
            callback: Optional[Callable] = None) -> None:
        """Register a ``RequestMatch`` built from the given arguments.

        All arguments are forwarded unchanged to ``RequestMatch``. The
        matcher is stored under a fresh UUID key, so matchers are tried in
        registration order.
        """
        self._matches[str(uuid4())] = (RequestMatch(
            url,
            method=method,
            status=status,
            content_type=content_type,
            body=body,
            exception=exception,
            payload=payload,
            headers=headers,
            response_class=response_class,
            repeat=repeat,
            timeout=timeout,
            reason=reason,
            callback=callback,
        ))

    def _call_count(self) -> int:
        """Return the total number of recorded calls over all keys."""
        return sum(len(calls) for calls in self.requests.values())

    def _format_call_signature(self, *args, **kwargs) -> str:
        """Render ``ClassName(arg, ..., key=value, ...)`` for messages."""
        # Parenthesised so the 'mock' fallback applies to the name itself;
        # unparenthesised, ``%`` binds tighter than ``or``.
        message = '%s(%%s)' % (self.__class__.__name__ or 'mock')
        parts = [repr(arg) for arg in args]
        parts.extend('%s=%r' % (key, value) for key, value in kwargs.items())
        return message % ', '.join(parts)

    def assert_not_called(self):
        """assert that the mock was never called.
        """
        if self.requests:
            # Count recorded requests, not handed-out responses: a request
            # that matched nothing is still a call.
            msg = ("Expected '%s' to not have been called. Called %s times."
                   % (self.__class__.__name__,
                      self._call_count()))
            raise AssertionError(msg)

    def assert_called(self):
        """assert that the mock was called at least once.
        """
        if not self.requests:
            msg = ("Expected '%s' to have been called."
                   % (self.__class__.__name__,))
            raise AssertionError(msg)

    def assert_called_once(self):
        """assert that the mock was called only once.
        """
        # Total over every (method, URL) key, so the reported number is the
        # real call count even when several different URLs were requested.
        call_count = self._call_count()
        if call_count != 1:
            msg = ("Expected '%s' to have been called once. Called %s times."
                   % (self.__class__.__name__,
                      call_count))
            raise AssertionError(msg)

    def assert_called_with(self, url: 'Union[URL, str, Pattern]',
                           method: str = hdrs.METH_GET,
                           *args: Any,
                           **kwargs: Any):
        """assert that the last call was made with the specified arguments.

        "Last call" means the most recent call recorded for this method and
        URL (after merging ``params`` into the URL and normalizing it).

        Raises an AssertionError if the args and keyword args passed in are
        different to the last call to the mock."""
        url = normalize_url(merge_params(url, kwargs.get('params')))
        method = method.upper()
        key = (method, url)
        try:
            recorded = self.requests[key][-1]
        except KeyError:
            expected_string = self._format_call_signature(
                url, *args, method=method, **kwargs
            )
            raise AssertionError(
                '%s call not found' % expected_string
            )
        requested = self._build_request_call(method, *args, **kwargs)
        if not recorded == requested:
            # Message order: the recorded call first, then the call described
            # by the arguments of this assertion.
            raise AssertionError(
                '%s != %s' % (
                    self._format_call_signature(recorded),
                    self._format_call_signature(requested),
                )
            )

    def assert_any_call(self, url: 'Union[URL, str, Pattern]',
                        method: str = hdrs.METH_GET,
                        *args: Any,
                        **kwargs: Any):
        """assert the mock has been called with the specified arguments.

        The assert passes if the mock has *ever* been called, unlike
        `assert_called_with` and `assert_called_once_with` that only pass if
        the call is the most recent one.

        Only the method and the URL (with ``params`` merged in) are checked;
        the remaining args and keyword args are used solely for the failure
        message."""
        url = normalize_url(merge_params(url, kwargs.get('params')))
        method = method.upper()
        key = (method, url)

        try:
            self.requests[key]
        except KeyError:
            expected_string = self._format_call_signature(
                url, *args, method=method, **kwargs
            )
            raise AssertionError(
                '%s call not found' % expected_string
            )

    def assert_called_once_with(self, *args: Any, **kwargs: Any):
        """assert that the mock was called once with the specified arguments.
        Raises an AssertionError if the args and keyword args passed in are
        different to the only call to the mock."""
        self.assert_called_once()
        self.assert_called_with(*args, **kwargs)

    @staticmethod
    def is_exception(resp_or_exc: Union[ClientResponse, Exception]) -> bool:
        """Return True for an exception instance or an exception class."""
        if inspect.isclass(resp_or_exc):
            return issubclass(resp_or_exc, BaseException)
        return isinstance(resp_or_exc, BaseException)

    async def match(
        self, method: str,
        url: URL,
        allow_redirects: bool = True,
        **kwargs: Any
    ) -> Optional['ClientResponse']:
        """Find the first registered matcher for the request and use it.

        Returns ``None`` when no matcher matches. A matcher whose ``repeat``
        is ``False`` is removed once used. If the matcher yields an exception
        it is raised. When ``allow_redirects`` is true and the response has a
        redirect status with a ``Location`` header, matching restarts for the
        target URL and the intermediate responses are collected into the
        final response's ``_history``.
        """
        history = []
        while True:
            for key, matcher in self._matches.items():
                if matcher.match(method, url):
                    response_or_exc = await matcher.build_response(
                        url, allow_redirects=allow_redirects, **kwargs
                    )
                    break
            else:
                return None

            if matcher.repeat is False:
                del self._matches[key]

            if self.is_exception(response_or_exc):
                raise response_or_exc
            # If response_or_exc was an exception, it would have been raised.
            # At this point we can be sure it's a ClientResponse
            response: ClientResponse
            response = response_or_exc  # type:ignore[assignment]
            is_redirect = response.status in (301, 302, 303, 307, 308)
            if is_redirect and allow_redirects:
                if hdrs.LOCATION not in response.headers:
                    break
                history.append(response)
                redirect_url = URL(response.headers[hdrs.LOCATION])
                if redirect_url.is_absolute():
                    url = redirect_url
                else:
                    url = url.join(redirect_url)
                # NOTE(review): the method is reset to 'get' for every
                # redirect status, including 307/308 -- confirm intended.
                method = 'get'
                continue
            else:
                break

        response._history = tuple(history)
        return response

    async def _request_mock(self, orig_self: ClientSession,
                            method: str, url: 'Union[URL, str]',
                            *args: Tuple,
                            **kwargs: Any) -> 'ClientResponse':
        """Return mocked response object or raise connection error."""
        if orig_self.closed:
            raise RuntimeError('Session is closed')

        url_origin = url
        url = normalize_url(merge_params(url, kwargs.get('params')))
        url_str = str(url)
        for prefix in self._passthrough:
            if url_str.startswith(prefix):
                # Pass-through requests are forwarded with the original,
                # un-normalized URL and are not recorded in ``requests``.
                return (await self.patcher.temp_original(
                    orig_self, method, url_origin, *args, **kwargs
                ))

        key = (method, url)
        self.requests.setdefault(key, [])
        request_call = self._build_request_call(method, *args, **kwargs)
        self.requests[key].append(request_call)

        response = await self.match(method, url, **kwargs)

        if response is None:
            raise ClientConnectionError(
                'Connection refused: {} {}'.format(method, url)
            )
        self._responses.append(response)

        # Automatically call response.raise_for_status() on a request if the
        # request was initialized with raise_for_status=True. Also call
        # response.raise_for_status() if the client session was initialized
        # with raise_for_status=True, unless the request was called with
        # raise_for_status=False.
        raise_for_status = kwargs.get('raise_for_status')
        if raise_for_status is None:
            raise_for_status = getattr(
                orig_self, '_raise_for_status', False
            )
        if raise_for_status:
            response.raise_for_status()

        return response

    def _build_request_call(self, method: str = hdrs.METH_GET,
                            *args: Any,
                            allow_redirects: bool = True,
                            **kwargs: Any):
        """Return request call."""
        # Defaults are filled in so that recorded calls and the calls built
        # by the assert_* helpers compare equal when they are left implicit.
        kwargs.setdefault('allow_redirects', allow_redirects)
        if method == 'POST':
            kwargs.setdefault('data', None)

        try:
            kwargs_copy = copy.deepcopy(kwargs)
        except (TypeError, ValueError):
            # Handle the fact that some values cannot be deep copied
            kwargs_copy = kwargs
        return RequestCall(args, kwargs_copy)