diff --git a/falcon/app.py b/falcon/app.py index f71a291f6..53cffe397 100644 --- a/falcon/app.py +++ b/falcon/app.py @@ -674,7 +674,7 @@ def add_static_route( self._static_routes.insert(0, (sr, sr, False)) self._update_sink_and_static_routes() - def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/'): + def add_sink(self, sink: Callable, prefix: SinkPrefix = r'/') -> None: """Register a sink method for the App. If no route matches a request, but the path in the requested URI diff --git a/falcon/app_helpers.py b/falcon/app_helpers.py index 1f0730f86..db6d7cc24 100644 --- a/falcon/app_helpers.py +++ b/falcon/app_helpers.py @@ -14,6 +14,8 @@ """Utilities for the App class.""" +from __future__ import annotations + from inspect import iscoroutinefunction from typing import IO, Iterable, List, Tuple @@ -202,7 +204,7 @@ def prepare_middleware_ws(middleware: Iterable) -> Tuple[list, list]: return request_mw, resource_mw -def default_serialize_error(req: Request, resp: Response, exception: HTTPError): +def default_serialize_error(req: Request, resp: Response, exception: HTTPError) -> None: """Serialize the given instance of HTTPError. This function determines which of the supported media types, if @@ -281,14 +283,14 @@ class CloseableStreamIterator: block_size (int): Number of bytes to read per iteration. """ - def __init__(self, stream: IO, block_size: int): + def __init__(self, stream: IO, block_size: int) -> None: self._stream = stream self._block_size = block_size - def __iter__(self): + def __iter__(self) -> CloseableStreamIterator: return self - def __next__(self): + def __next__(self) -> bytes: data = self._stream.read(self._block_size) if data == b'': @@ -296,7 +298,7 @@ def __next__(self): else: return data - def close(self): + def close(self) -> None: try: self._stream.close() except (AttributeError, TypeError): diff --git a/falcon/asgi_spec.py b/falcon/asgi_spec.py index 5e80a7241..bca282b37 100644 --- a/falcon/asgi_spec.py +++ b/falcon/asgi_spec.py @@ -14,6 +14,8 @@ """Constants, etc. defined by the ASGI specification.""" +from __future__ import annotations + class EventType: """Standard ASGI event type strings.""" diff --git a/falcon/errors.py b/falcon/errors.py index 74a50c69c..afc1faa8a 100644 --- a/falcon/errors.py +++ b/falcon/errors.py @@ -34,14 +34,21 @@ def on_get(self, req, resp): # -- snip -- """ +from __future__ import annotations + from datetime import datetime -from typing import Optional +from typing import Iterable, Optional, TYPE_CHECKING, Union from falcon.http_error import HTTPError import falcon.status_codes as status from falcon.util.deprecation import deprecated_args from falcon.util.misc import dt_to_http +if TYPE_CHECKING: + from falcon.typing import HeaderList + from falcon.typing import Headers + + __all__ = ( 'CompatibilityError', 'DelimiterError', @@ -142,7 +149,7 @@ class WebSocketDisconnected(ConnectionError): code (int): The WebSocket close code, as per the WebSocket spec. """ - def __init__(self, code: Optional[int] = None): + def __init__(self, code: Optional[int] = None) -> None: self.code = code or 1000 # Default to "Normal Closure" @@ -168,6 +175,10 @@ class WebSocketServerError(WebSocketDisconnected): pass +HTTPErrorKeywordArguments = Union[str, int, None] +RetryAfter = Union[int, datetime, None] + + class HTTPBadRequest(HTTPError): """400 Bad Request. @@ -214,7 +225,13 @@ class HTTPBadRequest(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ) -> None: super().__init__( status.HTTP_400, title=title, @@ -292,7 +309,12 @@ class HTTPUnauthorized(HTTPError): @deprecated_args(allowed_positional=0) def __init__( - self, title=None, description=None, headers=None, challenges=None, **kwargs + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + challenges: Optional[Iterable[str]] = None, + **kwargs: HTTPErrorKeywordArguments, ): if challenges: headers = _load_headers(headers) @@ -364,7 +386,13 @@ class HTTPForbidden(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_403, title=title, @@ -429,7 +457,13 @@ class HTTPNotFound(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_404, title=title, @@ -550,7 +584,12 @@ class HTTPMethodNotAllowed(HTTPError): @deprecated_args(allowed_positional=1) def __init__( - self, allowed_methods, title=None, description=None, headers=None, **kwargs + self, + allowed_methods: Iterable[str], + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, ): headers = _load_headers(headers) headers['Allow'] = ', '.join(allowed_methods) @@ -616,7 +655,13 @@ class HTTPNotAcceptable(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_406, title=title, @@ -683,7 +728,13 @@ class HTTPConflict(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_409, title=title, @@ -756,7 +807,13 @@ class HTTPGone(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_410, title=title, @@ -814,7 +871,13 @@ class HTTPLengthRequired(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_411, title=title, @@ -873,7 +936,13 @@ class HTTPPreconditionFailed(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_412, title=title, @@ -943,8 +1012,13 @@ class HTTPPayloadTooLarge(HTTPError): @deprecated_args(allowed_positional=0) def __init__( - self, title=None, description=None, retry_after=None, headers=None, **kwargs - ): + self, + title: Optional[str] = None, + description: Optional[str] = None, + retry_after: RetryAfter = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ) -> None: super().__init__( status.HTTP_413, title=title, @@ -1008,7 +1082,13 @@ class HTTPUriTooLong(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_414, title=title, @@ -1067,7 +1147,13 @@ class HTTPUnsupportedMediaType(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_415, title=title, @@ -1140,7 +1226,12 @@ class HTTPRangeNotSatisfiable(HTTPError): @deprecated_args(allowed_positional=1) def __init__( - self, resource_length, title=None, description=None, headers=None, **kwargs + self, + resource_length: int, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, ): headers = _load_headers(headers) headers['Content-Range'] = 'bytes */' + str(resource_length) @@ -1205,7 +1296,13 @@ class HTTPUnprocessableEntity(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_422, title=title, @@ -1261,7 +1358,13 @@ class HTTPLocked(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_423, title=title, @@ -1316,7 +1419,13 @@ class HTTPFailedDependency(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_424, title=title, @@ -1379,7 +1488,13 @@ class HTTPPreconditionRequired(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_428, title=title, @@ -1448,7 +1563,12 @@ class HTTPTooManyRequests(HTTPError): @deprecated_args(allowed_positional=0) def __init__( - self, title=None, description=None, headers=None, retry_after=None, **kwargs + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + retry_after: RetryAfter = None, + **kwargs: HTTPErrorKeywordArguments, ): super().__init__( status.HTTP_429, @@ -1511,7 +1631,13 @@ class HTTPRequestHeaderFieldsTooLarge(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_431, title=title, @@ -1580,7 +1706,13 @@ class HTTPUnavailableForLegalReasons(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_451, title=title, @@ -1635,7 +1767,13 @@ class HTTPInternalServerError(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_500, title=title, @@ -1697,7 +1835,13 @@ class HTTPNotImplemented(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_501, title=title, @@ -1752,7 +1896,13 @@ class HTTPBadGateway(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_502, title=title, @@ -1824,7 +1974,12 @@ class HTTPServiceUnavailable(HTTPError): @deprecated_args(allowed_positional=0) def __init__( - self, title=None, description=None, headers=None, retry_after=None, **kwargs + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + retry_after: RetryAfter = None, + **kwargs: HTTPErrorKeywordArguments, ): super().__init__( status.HTTP_503, @@ -1881,7 +2036,13 @@ class HTTPGatewayTimeout(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_504, title=title, @@ -1942,7 +2103,13 @@ class HTTPVersionNotSupported(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_505, title=title, @@ -2001,7 +2168,13 @@ class HTTPInsufficientStorage(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_507, title=title, @@ -2057,7 +2230,13 @@ class HTTPLoopDetected(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_508, title=title, @@ -2125,7 +2304,13 @@ class HTTPNetworkAuthenticationRequired(HTTPError): """ @deprecated_args(allowed_positional=0) - def __init__(self, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): super().__init__( status.HTTP_511, title=title, @@ -2178,7 +2363,13 @@ class HTTPInvalidHeader(HTTPBadRequest): """ @deprecated_args(allowed_positional=2) - def __init__(self, msg, header_name, headers=None, **kwargs): + def __init__( + self, + msg: str, + header_name: str, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): description = 'The value provided for the "{0}" header is invalid. {1}' description = description.format(header_name, msg) @@ -2232,7 +2423,12 @@ class HTTPMissingHeader(HTTPBadRequest): """ @deprecated_args(allowed_positional=1) - def __init__(self, header_name, headers=None, **kwargs): + def __init__( + self, + header_name: str, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ): description = 'The "{0}" header is required.' description = description.format(header_name) @@ -2289,7 +2485,13 @@ class HTTPInvalidParam(HTTPBadRequest): """ @deprecated_args(allowed_positional=2) - def __init__(self, msg, param_name, headers=None, **kwargs): + def __init__( + self, + msg: str, + param_name: str, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ) -> None: description = 'The "{0}" parameter is invalid. {1}' description = description.format(param_name, msg) @@ -2345,7 +2547,12 @@ class HTTPMissingParam(HTTPBadRequest): """ @deprecated_args(allowed_positional=1) - def __init__(self, param_name, headers=None, **kwargs): + def __init__( + self, + param_name: str, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ) -> None: description = 'The "{0}" parameter is required.' description = description.format(param_name) @@ -2396,7 +2603,7 @@ class MediaNotFoundError(HTTPBadRequest): base articles related to this error (default ``None``). """ - def __init__(self, media_type, **kwargs): + def __init__(self, media_type: str, **kwargs: HTTPErrorKeywordArguments) -> None: super().__init__( title='Invalid {0}'.format(media_type), description='Could not parse an empty {0} body'.format(media_type), @@ -2441,21 +2648,23 @@ class MediaMalformedError(HTTPBadRequest): base articles related to this error (default ``None``). """ - def __init__(self, media_type, **kwargs): + def __init__( + self, media_type: str, **kwargs: Union[HeaderList, HTTPErrorKeywordArguments] + ): super().__init__( title='Invalid {0}'.format(media_type), description=None, **kwargs ) self._media_type = media_type @property - def description(self): + def description(self) -> Optional[str]: msg = 'Could not parse {} body'.format(self._media_type) if self.__cause__ is not None: msg += ' - {}'.format(self.__cause__) return msg @description.setter - def description(self, value): + def description(self, value: str) -> None: pass @@ -2504,7 +2713,14 @@ class MediaValidationError(HTTPBadRequest): base articles related to this error (default ``None``). """ - def __init__(self, *, title=None, description=None, headers=None, **kwargs): + def __init__( + self, + *, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + **kwargs: HTTPErrorKeywordArguments, + ) -> None: super().__init__( title=title, description=description, @@ -2534,7 +2750,11 @@ class MultipartParseError(MediaMalformedError): description = None @deprecated_args(allowed_positional=0) - def __init__(self, description=None, **kwargs): + def __init__( + self, + description: Optional[str] = None, + **kwargs: Union[HeaderList, HTTPErrorKeywordArguments], + ) -> None: HTTPBadRequest.__init__( self, title='Malformed multipart/form-data request media', @@ -2548,7 +2768,7 @@ def __init__(self, description=None, **kwargs): # ----------------------------------------------------------------------------- -def _load_headers(headers): +def _load_headers(headers: Optional[HeaderList]) -> Headers: """Transform the headers to dict.""" if headers is None: return {} @@ -2557,7 +2777,10 @@ def _load_headers(headers): return dict(headers) -def _parse_retry_after(headers, retry_after): +def _parse_retry_after( + headers: Optional[HeaderList], + retry_after: RetryAfter, +) -> Optional[HeaderList]: """Set the Retry-After to the headers when required.""" if retry_after is None: return headers diff --git a/falcon/forwarded.py b/falcon/forwarded.py index 9fe3c03f8..f855b3661 100644 --- a/falcon/forwarded.py +++ b/falcon/forwarded.py @@ -16,9 +16,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import re import string +from typing import List, Optional from falcon.util.uri import unquote_string @@ -75,14 +77,14 @@ class Forwarded: # falcon.Request interface. __slots__ = ('src', 'dest', 'host', 'scheme') - def __init__(self): - self.src = None - self.dest = None - self.host = None - self.scheme = None + def __init__(self) -> None: + self.src: Optional[str] = None + self.dest: Optional[str] = None + self.host: Optional[str] = None + self.scheme: Optional[str] = None -def _parse_forwarded_header(forwarded): +def _parse_forwarded_header(forwarded: str) -> List[Forwarded]: """Parse the value of a Forwarded header. Makes an effort to parse Forwarded headers as specified by RFC 7239: diff --git a/falcon/hooks.py b/falcon/hooks.py index f4bb7afae..5ca50aefb 100644 --- a/falcon/hooks.py +++ b/falcon/hooks.py @@ -14,21 +14,35 @@ """Hook decorators.""" +from __future__ import annotations + from functools import wraps from inspect import getmembers from inspect import iscoroutinefunction import re +import typing as t from falcon.constants import COMBINED_METHODS from falcon.util.misc import get_argnames from falcon.util.sync import _wrap_non_coroutine_unsafe +if t.TYPE_CHECKING: # pragma: no cover + import falcon as wsgi + from falcon import asgi + _DECORABLE_METHOD_NAME = re.compile( r'^on_({})(_\w+)?$'.format('|'.join(method.lower() for method in COMBINED_METHODS)) ) +Resource = object +Responder = t.Callable +ResponderOrResource = t.Union[Responder, Resource] +Action = t.Callable -def before(action, *args, is_async=False, **kwargs): + +def before( + action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any +) -> t.Callable[[ResponderOrResource], ResponderOrResource]: """Execute the given action function *before* the responder. The `params` argument that is passed to the hook @@ -78,7 +92,7 @@ def do_something(req, resp, resource, params): *action*. """ - def _before(responder_or_resource): + def _before(responder_or_resource: ResponderOrResource) -> ResponderOrResource: if isinstance(responder_or_resource, type): resource = responder_or_resource @@ -88,7 +102,9 @@ def _before(responder_or_resource): # responder in the do_before_all closure; otherwise, they # will capture the same responder variable that is shared # between iterations of the for loop, above. - def let(responder=responder): + responder = t.cast(Responder, responder) + + def let(responder: Responder = responder) -> None: do_before_all = _wrap_with_before( responder, action, args, kwargs, is_async ) @@ -100,7 +116,7 @@ def let(responder=responder): return resource else: - responder = responder_or_resource + responder = t.cast(Responder, responder_or_resource) do_before_one = _wrap_with_before(responder, action, args, kwargs, is_async) return do_before_one @@ -108,7 +124,9 @@ def let(responder=responder): return _before -def after(action, *args, is_async=False, **kwargs): +def after( + action: Action, *args: t.Any, is_async: bool = False, **kwargs: t.Any +) -> t.Callable[[ResponderOrResource], ResponderOrResource]: """Execute the given action function *after* the responder. Args: @@ -141,14 +159,15 @@ def after(action, *args, is_async=False, **kwargs): *action*. """ - def _after(responder_or_resource): + def _after(responder_or_resource: ResponderOrResource) -> ResponderOrResource: if isinstance(responder_or_resource, type): - resource = responder_or_resource + resource = t.cast(Resource, responder_or_resource) for responder_name, responder in getmembers(resource, callable): if _DECORABLE_METHOD_NAME.match(responder_name): + responder = t.cast(Responder, responder) - def let(responder=responder): + def let(responder: Responder = responder) -> None: do_after_all = _wrap_with_after( responder, action, args, kwargs, is_async ) @@ -160,7 +179,7 @@ def let(responder=responder): return resource else: - responder = responder_or_resource + responder = t.cast(Responder, responder_or_resource) do_after_one = _wrap_with_after(responder, action, args, kwargs, is_async) return do_after_one @@ -173,7 +192,13 @@ def let(responder=responder): # ----------------------------------------------------------------------------- -def _wrap_with_after(responder, action, action_args, action_kwargs, is_async): +def _wrap_with_after( + responder: Responder, + action: Action, + action_args: t.Any, + action_kwargs: t.Any, + is_async: bool, +) -> Responder: """Execute the given action function after a responder method. Args: @@ -196,20 +221,35 @@ def _wrap_with_after(responder, action, action_args, action_kwargs, is_async): # is actually covered, but coverage isn't tracking it for # some reason. if not is_async: # pragma: nocover - action = _wrap_non_coroutine_unsafe(action) + async_action = _wrap_non_coroutine_unsafe(action) + else: + async_action = action @wraps(responder) - async def do_after(self, req, resp, *args, **kwargs): + async def do_after( + self: ResponderOrResource, + req: asgi.Request, + resp: asgi.Response, + *args: t.Any, + **kwargs: t.Any, + ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) await responder(self, req, resp, **kwargs) - await action(req, resp, self, *action_args, **action_kwargs) + assert async_action + await async_action(req, resp, self, *action_args, **action_kwargs) else: @wraps(responder) - def do_after(self, req, resp, *args, **kwargs): + def do_after( + self: ResponderOrResource, + req: wsgi.Request, + resp: wsgi.Response, + *args: t.Any, + **kwargs: t.Any, + ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) @@ -219,7 +259,13 @@ def do_after(self, req, resp, *args, **kwargs): return do_after -def _wrap_with_before(responder, action, action_args, action_kwargs, is_async): +def _wrap_with_before( + responder: Responder, + action: Action, + action_args: t.Tuple[t.Any, ...], + action_kwargs: t.Dict[str, t.Any], + is_async: bool, +) -> t.Union[t.Callable[..., t.Awaitable[None]], t.Callable[..., None]]: """Execute the given action function before a responder method. Args: @@ -242,20 +288,35 @@ def _wrap_with_before(responder, action, action_args, action_kwargs, is_async): # is actually covered, but coverage isn't tracking it for # some reason. if not is_async: # pragma: nocover - action = _wrap_non_coroutine_unsafe(action) + async_action = _wrap_non_coroutine_unsafe(action) + else: + async_action = action @wraps(responder) - async def do_before(self, req, resp, *args, **kwargs): + async def do_before( + self: ResponderOrResource, + req: asgi.Request, + resp: asgi.Response, + *args: t.Any, + **kwargs: t.Any, + ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) - await action(req, resp, self, kwargs, *action_args, **action_kwargs) + assert async_action + await async_action(req, resp, self, kwargs, *action_args, **action_kwargs) await responder(self, req, resp, **kwargs) else: @wraps(responder) - def do_before(self, req, resp, *args, **kwargs): + def do_before( + self: ResponderOrResource, + req: wsgi.Request, + resp: wsgi.Response, + *args: t.Any, + **kwargs: t.Any, + ) -> None: if args: _merge_responder_args(args, kwargs, extra_argnames) @@ -265,7 +326,9 @@ def do_before(self, req, resp, *args, **kwargs): return do_before -def _merge_responder_args(args, kwargs, argnames): +def _merge_responder_args( + args: t.Tuple[t.Any, ...], kwargs: t.Dict[str, t.Any], argnames: t.List[str] +) -> None: """Merge responder args into kwargs. The framework always passes extra args as keyword arguments. diff --git a/falcon/http_error.py b/falcon/http_error.py index 20c221fe3..3f4af635c 100644 --- a/falcon/http_error.py +++ b/falcon/http_error.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """HTTPError exception class.""" +from __future__ import annotations + from collections import OrderedDict +from typing import MutableMapping, Optional, Type, TYPE_CHECKING, Union import xml.etree.ElementTree as et from falcon.constants import MEDIA_JSON @@ -23,6 +25,12 @@ from falcon.util import uri from falcon.util.deprecation import deprecated_args +if TYPE_CHECKING: + from falcon.media import BaseHandler + from falcon.typing import HeaderList + from falcon.typing import Link + from falcon.typing import ResponseStatus + class HTTPError(Exception): """Represents a generic HTTP error. @@ -112,13 +120,13 @@ class HTTPError(Exception): @deprecated_args(allowed_positional=1) def __init__( self, - status, - title=None, - description=None, - headers=None, - href=None, - href_text=None, - code=None, + status: ResponseStatus, + title: Optional[str] = None, + description: Optional[str] = None, + headers: Optional[HeaderList] = None, + href: Optional[str] = None, + href_text: Optional[str] = None, + code: Optional[int] = None, ): self.status = status @@ -131,6 +139,7 @@ def __init__( self.description = description self.headers = headers self.code = code + self.link: Optional[Link] if href: link = self.link = OrderedDict() @@ -140,7 +149,7 @@ def __init__( else: self.link = None - def __repr__(self): + def __repr__(self) -> str: return '<%s: %s>' % (self.__class__.__name__, self.status) __str__ = __repr__ @@ -149,7 +158,9 @@ def __repr__(self): def status_code(self) -> int: return http_status_to_code(self.status) - def to_dict(self, obj_type=dict): + def to_dict( + self, obj_type: Type[MutableMapping[str, Union[str, int, None, Link]]] = dict + ) -> MutableMapping[str, Union[str, int, None, Link]]: """Return a basic dictionary representing the error. This method can be useful when serializing the error to hash-like @@ -180,7 +191,7 @@ def to_dict(self, obj_type=dict): return obj - def to_json(self, handler=None): + def to_json(self, handler: Optional[BaseHandler] = None) -> bytes: """Return a JSON representation of the error. Args: @@ -198,7 +209,7 @@ def to_json(self, handler=None): handler = _DEFAULT_JSON_HANDLER return handler.serialize(obj, MEDIA_JSON) - def to_xml(self): + def to_xml(self) -> bytes: """Return an XML-encoded representation of the error. Returns: @@ -229,4 +240,7 @@ def to_xml(self): # NOTE: initialized in falcon.media.json, that is always imported since Request/Response # are imported by falcon init. -_DEFAULT_JSON_HANDLER = None +if TYPE_CHECKING: + _DEFAULT_JSON_HANDLER: BaseHandler +else: + _DEFAULT_JSON_HANDLER = None diff --git a/falcon/http_status.py b/falcon/http_status.py index d4411391e..df7e0d455 100644 --- a/falcon/http_status.py +++ b/falcon/http_status.py @@ -11,12 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """HTTPStatus exception class.""" +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + from falcon.util import http_status_to_code from falcon.util.deprecation import AttributeRemovedError +if TYPE_CHECKING: + from falcon.typing import HeaderList + from falcon.typing import ResponseStatus + class HTTPStatus(Exception): """Represents a generic HTTP status. @@ -46,7 +53,12 @@ class HTTPStatus(Exception): __slots__ = ('status', 'headers', 'text') - def __init__(self, status, headers=None, text=None): + def __init__( + self, + status: ResponseStatus, + headers: Optional[HeaderList] = None, + text: Optional[str] = None, + ) -> None: self.status = status self.headers = headers self.text = text diff --git a/falcon/inspect.py b/falcon/inspect.py index 919165687..9aac44cb0 100644 --- a/falcon/inspect.py +++ b/falcon/inspect.py @@ -14,16 +14,30 @@ """Inspect utilities for falcon applications.""" +from __future__ import annotations + from functools import partial import inspect -from typing import Callable, Dict, List, Optional, Type +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + Union, +) from falcon import app_helpers from falcon.app import App from falcon.routing import CompiledRouter +from falcon.routing.compiled import CompiledRouterNode -def inspect_app(app: App) -> 'AppInfo': +def inspect_app(app: App) -> AppInfo: """Inspects an application. Args: @@ -43,7 +57,7 @@ def inspect_app(app: App) -> 'AppInfo': return AppInfo(routes, middleware, static, sinks, error_handlers, app._ASGI) -def inspect_routes(app: App) -> 'List[RouteInfo]': +def inspect_routes(app: App) -> List[RouteInfo]: """Inspects the routes of an application. Args: @@ -65,7 +79,9 @@ def inspect_routes(app: App) -> 'List[RouteInfo]': return inspect_function(router) -def register_router(router_class): +def register_router( + router_class: Type, +) -> Callable[..., Callable[..., List[RouteInfo]]]: """Register a function to inspect a particular router. This decorator registers a new function for a custom router @@ -83,7 +99,7 @@ def inspect_my_router(router): already registered an error will be raised. """ - def wraps(fn): + def wraps(fn: Callable[..., List[RouteInfo]]) -> Callable[..., List[RouteInfo]]: if router_class in _supported_routers: raise ValueError( 'Another function is already registered for the router {}'.format( @@ -96,8 +112,7 @@ def wraps(fn): return wraps -# router inspection registry -_supported_routers: Dict[Type, Callable] = {} +_supported_routers: Dict[Type, Callable[..., Any]] = {} def inspect_static_routes(app: App) -> 'List[StaticRouteInfo]': @@ -131,6 +146,7 @@ def inspect_sinks(app: App) -> 'List[SinkInfo]': sinks = [] for prefix, sink, _ in app._sinks: source_info, name = _get_source_info_and_name(sink) + assert source_info info = SinkInfo(prefix.pattern, name, source_info) sinks.append(info) return sinks @@ -150,6 +166,7 @@ def inspect_error_handlers(app: App) -> 'List[ErrorHandlerInfo]': errors = [] for exc, fn in app._error_handlers.items(): source_info, name = _get_source_info_and_name(fn) + assert source_info info = ErrorHandlerInfo(exc.__name__, name, source_info, _is_internal(fn)) errors.append(info) return errors @@ -188,7 +205,9 @@ def inspect_middleware(app: App) -> 'MiddlewareInfo': if method: real_func = method[0] source_info = _get_source_info(real_func) + assert source_info methods.append(MiddlewareMethodInfo(real_func.__name__, source_info)) + assert class_source_info m_info = MiddlewareClassInfo(cls_name, class_source_info, methods) middlewareClasses.append(m_info) @@ -210,7 +229,7 @@ def inspect_compiled_router(router: CompiledRouter) -> 'List[RouteInfo]': List[RouteInfo]: A list of :class:`~.RouteInfo`. """ - def _traverse(roots, parent): + def _traverse(roots: List[CompiledRouterNode], parent: str) -> None: for root in roots: path = parent + '/' + root.raw_segment if root.resource is not None: @@ -224,13 +243,16 @@ def _traverse(roots, parent): source_info = _get_source_info(real_func) internal = _is_internal(real_func) - + assert source_info, ( + 'This is for type checking only, as here source ' + 'info will always be a string' + ) method_info = RouteMethodInfo( method, source_info, real_func.__name__, internal ) methods.append(method_info) source_info, class_name = _get_source_info_and_name(root.resource) - + assert source_info route_info = RouteInfo(path, class_name, source_info, methods) routes.append(route_info) @@ -250,7 +272,7 @@ def _traverse(roots, parent): class _Traversable: __visit_name__ = 'N/A' - def to_string(self, verbose=False, internal=False) -> str: + def to_string(self, verbose: bool = False, internal: bool = False) -> str: """Return a string representation of this class. Args: @@ -264,7 +286,7 @@ def to_string(self, verbose=False, internal=False) -> str: """ return StringVisitor(verbose, internal).process(self) - def __repr__(self): + def __repr__(self) -> str: return self.to_string() @@ -520,7 +542,9 @@ def __init__( self.error_handlers = error_handlers self.asgi = asgi - def to_string(self, verbose=False, internal=False, name='') -> str: + def to_string( + self, verbose: bool = False, internal: bool = False, name: str = '' + ) -> str: """Return a string representation of this class. Args: @@ -546,7 +570,7 @@ class InspectVisitor: Subclasses must implement ``visit_`` methods for each supported class. """ - def process(self, instance: _Traversable): + def process(self, instance: _Traversable) -> str: """Process the instance, by calling the appropriate visit method. Uses the `__visit_name__` attribute of the `instance` to obtain the method @@ -577,14 +601,16 @@ class StringVisitor(InspectVisitor): beginning of the text. Defaults to ``'Falcon App'``. """ - def __init__(self, verbose=False, internal=False, name=''): + def __init__( + self, verbose: bool = False, internal: bool = False, name: str = '' + ) -> None: self.verbose = verbose self.internal = internal self.name = name self.indent = 0 @property - def tab(self): + def tab(self) -> str: """Get the current tabulation.""" return ' ' * self.indent @@ -595,13 +621,15 @@ def visit_route_method(self, route_method: RouteMethodInfo) -> str: text += ' ({0.source_info})'.format(route_method) return text - def _methods_to_string(self, methods: List): + def _methods_to_string( + self, methods: Union[List[RouteMethodInfo], List[MiddlewareMethodInfo]] + ) -> str: """Return a string from the list of methods.""" tab = self.tab + ' ' * 3 - methods = _filter_internal(methods, self.internal) - if not methods: + filtered_methods = _filter_internal(methods, self.internal) + if not filtered_methods: return '' - text_list = [self.process(m) for m in methods] + text_list = [self.process(m) for m in filtered_methods] method_text = ['{}├── {}'.format(tab, m) for m in text_list[:-1]] method_text += ['{}└── {}'.format(tab, m) for m in text_list[-1:]] return '\n'.join(method_text) @@ -751,7 +779,9 @@ def visit_app(self, app: AppInfo) -> str: # ------------------------------------------------------------------------ -def _get_source_info(obj, default='[unknown file]'): +def _get_source_info( + obj: Any, default: Optional[str] = '[unknown file]' +) -> Optional[str]: """Try to get the definition file and line of obj. Return default on error. @@ -765,11 +795,11 @@ def _get_source_info(obj, default='[unknown file]'): # responders coming from cythonized modules will # appear as built-in functions, and raise a # TypeError when trying to locate the source file. - source_info = default + return default return source_info -def _get_source_info_and_name(obj): +def _get_source_info_and_name(obj: Any) -> Tuple[Optional[str], str]: """Attempt to get the definition file and line of obj and its name.""" source_info = _get_source_info(obj, None) if source_info is None: @@ -778,10 +808,11 @@ def _get_source_info_and_name(obj): name = getattr(obj, '__name__', None) if name is None: name = getattr(type(obj), '__name__', '[unknown]') + name = cast(str, name) return source_info, name -def _is_internal(obj): +def _is_internal(obj: Any) -> bool: """Check if the module of the object is a falcon module.""" module = inspect.getmodule(obj) if module: @@ -789,7 +820,14 @@ def _is_internal(obj): return False -def _filter_internal(iterable, return_internal): +def _filter_internal( + iterable: Union[ + Iterable[RouteMethodInfo], + Iterable[ErrorHandlerInfo], + Iterable[MiddlewareMethodInfo], + ], + return_internal: bool, +) -> Union[Iterable[_Traversable], List[_Traversable]]: """Filter the internal elements of an iterable.""" if return_internal: return iterable diff --git a/falcon/media/handlers.py b/falcon/media/handlers.py index 0186e0aee..1c8ad0b48 100644 --- a/falcon/media/handlers.py +++ b/falcon/media/handlers.py @@ -1,11 +1,13 @@ from collections import UserDict import functools +from typing import Mapping from falcon import errors from falcon.constants import MEDIA_JSON from falcon.constants import MEDIA_MULTIPART from falcon.constants import MEDIA_URLENCODED from falcon.constants import PYPY +from falcon.media.base import BaseHandler from falcon.media.base import BinaryBaseHandlerWS from falcon.media.json import JSONHandler from falcon.media.multipart import MultipartFormHandler @@ -41,7 +43,7 @@ class Handlers(UserDict): def __init__(self, initial=None): self._resolve = self._create_resolver() - handlers = initial or { + handlers: Mapping[str, BaseHandler] = initial or { MEDIA_JSON: JSONHandler(), MEDIA_MULTIPART: MultipartFormHandler(), MEDIA_URLENCODED: URLEncodedFormHandler(), diff --git a/falcon/middleware.py b/falcon/middleware.py index a799fe2f8..5772e16c7 100644 --- a/falcon/middleware.py +++ b/falcon/middleware.py @@ -1,4 +1,6 @@ -from typing import Iterable, Optional, Union +from __future__ import annotations + +from typing import Any, Iterable, Optional, Union from .request import Request from .response import Response @@ -73,7 +75,9 @@ def __init__( ) self.allow_credentials = allow_credentials - def process_response(self, req: Request, resp: Response, resource, req_succeeded): + def process_response( + self, req: Request, resp: Response, resource: object, req_succeeded: bool + ) -> None: """Implement the CORS policy for all routes. This middleware provides a simple out-of-the box CORS policy, @@ -120,5 +124,5 @@ def process_response(self, req: Request, resp: Response, resource, req_succeeded resp.set_header('Access-Control-Allow-Headers', allow_headers) resp.set_header('Access-Control-Max-Age', '86400') # 24 hours - async def process_response_async(self, *args): + async def process_response_async(self, *args: Any) -> None: self.process_response(*args) diff --git a/falcon/redirects.py b/falcon/redirects.py index e308a5064..7d2381d47 100644 --- a/falcon/redirects.py +++ b/falcon/redirects.py @@ -11,12 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """HTTPStatus specializations for 3xx redirects.""" +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + import falcon from falcon.http_status import HTTPStatus +if TYPE_CHECKING: + from falcon.typing import Headers + class HTTPMovedPermanently(HTTPStatus): """301 Moved Permanently. @@ -37,7 +43,7 @@ class HTTPMovedPermanently(HTTPStatus): response. """ - def __init__(self, location, headers=None): + def __init__(self, location: str, headers: Optional[Headers] = None) -> None: if headers is None: headers = {} headers.setdefault('location', location) @@ -66,7 +72,7 @@ class HTTPFound(HTTPStatus): response. """ - def __init__(self, location, headers=None): + def __init__(self, location: str, headers: Optional[Headers] = None) -> None: if headers is None: headers = {} headers.setdefault('location', location) @@ -100,7 +106,7 @@ class HTTPSeeOther(HTTPStatus): response. """ - def __init__(self, location, headers=None): + def __init__(self, location: str, headers: Optional[Headers] = None) -> None: if headers is None: headers = {} headers.setdefault('location', location) @@ -129,7 +135,7 @@ class HTTPTemporaryRedirect(HTTPStatus): response. """ - def __init__(self, location, headers=None): + def __init__(self, location: str, headers: Optional[Headers] = None) -> None: if headers is None: headers = {} headers.setdefault('location', location) @@ -155,7 +161,7 @@ class HTTPPermanentRedirect(HTTPStatus): response. """ - def __init__(self, location, headers=None): + def __init__(self, location: str, headers: Optional[Headers] = None) -> None: if headers is None: headers = {} headers.setdefault('location', location) diff --git a/falcon/request.py b/falcon/request.py index cb037369d..f8fc6f4ab 100644 --- a/falcon/request.py +++ b/falcon/request.py @@ -12,6 +12,8 @@ """Request class.""" +from __future__ import annotations + from datetime import datetime from io import BytesIO from uuid import UUID diff --git a/falcon/response.py b/falcon/response.py index 05f6c2e10..e96f2ba2f 100644 --- a/falcon/response.py +++ b/falcon/response.py @@ -14,9 +14,11 @@ """Response class.""" +from __future__ import annotations + import functools import mimetypes -from typing import Optional +from typing import Dict, Optional from falcon.constants import _DEFAULT_STATIC_MEDIA_TYPES from falcon.constants import _UNSET @@ -1252,7 +1254,7 @@ class ResponseOptions: secure_cookies_by_default: bool default_media_type: Optional[str] media_handlers: Handlers - static_media_types: dict + static_media_types: Dict[str, str] __slots__ = ( 'secure_cookies_by_default', diff --git a/falcon/typing.py b/falcon/typing.py index bc4137027..a7095bb56 100644 --- a/falcon/typing.py +++ b/falcon/typing.py @@ -11,19 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Shorthand definitions for more complex types.""" -from typing import Any, Callable, Pattern, Union +from __future__ import annotations + +import http +from typing import ( + Any, + Callable, + Dict, + List, + Pattern, + Tuple, + TYPE_CHECKING, + Union, +) + +if TYPE_CHECKING: + from falcon.request import Request + from falcon.response import Response + -from falcon.request import Request -from falcon.response import Response +Link = Dict[str, str] # Error handlers -ErrorHandler = Callable[[Request, Response, BaseException, dict], Any] +ErrorHandler = Callable[['Request', 'Response', BaseException, dict], Any] # Error serializers -ErrorSerializer = Callable[[Request, Response, BaseException], Any] +ErrorSerializer = Callable[['Request', 'Response', BaseException], Any] # Sinks SinkPrefix = Union[str, Pattern] @@ -33,3 +48,6 @@ # arguments afterwords? # class SinkCallable(Protocol): # def __call__(sef, req: Request, resp: Response, ): ... +Headers = Dict[str, str] +HeaderList = Union[Headers, List[Tuple[str, str]]] +ResponseStatus = Union[http.HTTPStatus, str, int] diff --git a/pyproject.toml b/pyproject.toml index c56da1c31..559256ad8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,10 @@ ] [tool.mypy] - exclude = "falcon/bench/|falcon/cmd/" + exclude = [ + "falcon/bench/|falcon/cmd/", + "falcon/vendor" + ] [[tool.mypy.overrides]] module = [ "cbor2", @@ -35,8 +38,20 @@ [[tool.mypy.overrides]] module = [ + "falcon.util.*", + "falcon.app_helpers", + "falcon.asgi_spec", + "falcon.constants", + "falcon.errors", + "falcon.forwarded", + "falcon.hooks", + "falcon.http_error", + "falcon.http_status", + "falcon.http_status", + "falcon.inspect", + "falcon.middleware", + "falcon.redirects", "falcon.stream", - "falcon.util.*" ] disallow_untyped_defs = true