Skip to content

Commit

Permalink
Added forward reference policies
Browse files Browse the repository at this point in the history
Fixes #70.
  • Loading branch information
agronholm committed Aug 11, 2019
1 parent 6cfa1f9 commit 309c104
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-v
- Fixed bogus ``TypeError`` on ``Type[Any]``
- Fixed bogus ``TypeChecker`` warnings when an exception is raised from a type checked function
- Accept a ``bytearray`` where ``bytes`` are expected, as per `python/typing#552`_
- Added policies for dealing with unmatched forward references

.. _python/typing#552: https://github.com/python/typing/issues/552

Expand Down
33 changes: 32 additions & 1 deletion tests/test_typeguard.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gc
import sys
from concurrent.futures import ThreadPoolExecutor
from functools import wraps, partial
Expand All @@ -10,7 +11,7 @@

from typeguard import (
typechecked, check_argument_types, qualified_name, TypeChecker, TypeWarning, function_name,
check_type, Literal)
check_type, Literal, TypeHintWarning, ForwardRefPolicy)

try:
from typing import Type
Expand Down Expand Up @@ -990,3 +991,33 @@ def test_exception(self, checker):
pytest.raises(ZeroDivisionError, self.error_function)

assert len(record) == 0

@pytest.mark.parametrize('policy', [ForwardRefPolicy.WARN, ForwardRefPolicy.GUESS],
ids=['warn', 'guess'])
def test_forward_ref_policy_resolution_fails(self, checker, policy):
def unresolvable_annotation(x: 'OrderedDict'): # noqa
pass

checker.annotation_policy = policy
gc.collect() # prevent find_function() from finding more than one instance of the function
with checker, pytest.warns(TypeHintWarning) as record:
unresolvable_annotation({})

assert len(record) == 1
assert ("unresolvable_annotation: name 'OrderedDict' is not defined"
in str(record[0].message))
assert 'x' not in unresolvable_annotation.__annotations__

def test_forward_ref_policy_guess(self, checker):
import collections

def unresolvable_annotation(x: 'OrderedDict'): # noqa
pass

checker.annotation_policy = ForwardRefPolicy.GUESS
with checker, pytest.warns(TypeHintWarning) as record:
unresolvable_annotation(collections.OrderedDict())

assert len(record) == 1
assert str(record[0].message).startswith("Replaced forward declaration 'OrderedDict' in")
assert unresolvable_annotation.__annotations__['x'] is collections.OrderedDict
65 changes: 60 additions & 5 deletions typeguard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
__all__ = ('typechecked', 'check_argument_types', 'check_type', 'TypeWarning', 'TypeChecker')
__all__ = ('ForwardRefPolicy', 'TypeHintWarning', 'typechecked', 'check_argument_types',
'check_type', 'TypeWarning', 'TypeChecker')

import collections.abc
import gc
import inspect
import sys
import threading
from collections import OrderedDict
from enum import Enum
from functools import wraps, partial
from inspect import Parameter, isclass, isfunction, isgeneratorfunction
from io import TextIOBase, RawIOBase, IOBase, BufferedIOBase
Expand Down Expand Up @@ -40,12 +42,29 @@ def isasyncgenfunction(func):
T_Callable = TypeVar('T_Callable', bound=Callable)


class ForwardRefPolicy(Enum):
"""Defines how unresolved forward references are handled."""

ERROR = 1 #: propagate the NameError from get_type_hints()
WARN = 2 #: remove the annotation and emit a TypeHintWarning
#: replace the annotation with the argument's class if the qualified name matches, else remove
#: the annotation
GUESS = 3


class TypeHintWarning(UserWarning):
"""
A warning that is emitted when a type hint in string form could not be resolved to an actual
type.
"""


class _CallMemo:
__slots__ = ('func', 'func_name', 'signature', 'typevars', 'arguments', 'type_hints',
'is_generator')

def __init__(self, func: Callable, frame=None, args: tuple = None,
kwargs: Dict[str, Any] = None):
kwargs: Dict[str, Any] = None, forward_refs_policy=ForwardRefPolicy.ERROR):
self.func = func
self.func_name = function_name(func)
self.signature = inspect.signature(func)
Expand All @@ -60,7 +79,38 @@ def __init__(self, func: Callable, frame=None, args: tuple = None,

self.type_hints = _type_hints_map.get(func)
if self.type_hints is None:
hints = get_type_hints(func)
while True:
try:
hints = get_type_hints(func)
except NameError as exc:
if forward_refs_policy is ForwardRefPolicy.ERROR:
raise

typename = str(exc).split("'", 2)[1]
for param in self.signature.parameters.values():
if param.annotation == typename:
break
else:
raise

func_name = function_name(func)
if forward_refs_policy is ForwardRefPolicy.GUESS:
if param.name in self.arguments:
argtype = self.arguments[param.name].__class__
if param.annotation == argtype.__qualname__:
func.__annotations__[param.name] = argtype
msg = ('Replaced forward declaration {!r} in {} with {!r}'
.format(param.annotation, func_name, argtype))
warn(TypeHintWarning(msg))
continue

msg = 'Could not resolve type hint {!r} on {}: {}'.format(
param.annotation, function_name(func), exc)
warn(TypeHintWarning(msg))
del func.__annotations__[param.name]
else:
break

self.type_hints = OrderedDict()
for name, parameter in self.signature.parameters.items():
if name in hints:
Expand Down Expand Up @@ -617,13 +667,17 @@ class TypeChecker:
"""
A type checker that collects type violations by hooking into ``sys.setprofile()``.
:param packages: list of top level modules and packages or modules to include for type checking
:param all_threads: ``True`` to check types in all threads created while the checker is
running, ``False`` to only check in the current one
:param forward_refs_policy: how to handle unresolvable forward references in annotations
"""

def __init__(self, packages: Union[str, Sequence[str]], *, all_threads: bool = True):
def __init__(self, packages: Union[str, Sequence[str]], *, all_threads: bool = True,
forward_refs_policy: ForwardRefPolicy = ForwardRefPolicy.ERROR):
assert check_argument_types()
self.all_threads = all_threads
self.annotation_policy = forward_refs_policy
self._call_memos = {} # type: Dict[Any, _CallMemo]
self._previous_profiler = None
self._previous_thread_profiler = None
Expand Down Expand Up @@ -706,7 +760,8 @@ def __call__(self, frame, event: str, arg) -> None: # pragma: no cover
func = None

if func is not None and self.should_check_type(func):
memo = self._call_memos[frame] = _CallMemo(func, frame)
memo = self._call_memos[frame] = _CallMemo(
func, frame, forward_refs_policy=self.annotation_policy)
if memo.is_generator:
return_type_hint = memo.type_hints['return']
if return_type_hint is not None:
Expand Down

0 comments on commit 309c104

Please sign in to comment.