diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f51dfb..0572c6e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -76,3 +76,5 @@ repos: rev: v1.3.0 hooks: - id: mypy + additional_dependencies: + - django-stubs==1.14.0 diff --git a/pyproject.toml b/pyproject.toml index 2120520..75a69ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,16 +23,11 @@ addopts = """\ django_find_project = false [tool.mypy] -check_untyped_defs = true -disallow_any_generics = true -disallow_incomplete_defs = true -disallow_untyped_defs = true mypy_path = "src/" namespace_packages = false -no_implicit_optional = true show_error_codes = true +strict = true warn_unreachable = true -warn_unused_ignores = true [[tool.mypy.overrides]] module = "tests.*" diff --git a/src/auto_prefetch/__init__.py b/src/auto_prefetch/__init__.py index 32bb279..9bb02bb 100644 --- a/src/auto_prefetch/__init__.py +++ b/src/auto_prefetch/__init__.py @@ -80,7 +80,7 @@ class ReverseOneToOneDescriptor( def _is_cached(self, instance: models.Model) -> bool: return self.related.is_cached(instance) - def _field_name(self) -> str: + def _field_name(self) -> str | None: return self.related.get_accessor_name() @@ -113,6 +113,8 @@ def _fetch_all(self) -> None: class Model(models.Model): + _peers: WeakValueDictionary[str, Model] + class Meta: abstract = True base_manager_name = "prefetch_manager" @@ -131,8 +133,8 @@ def __getstate__(self) -> dict[str, Any]: return res @classmethod - def check(cls, **kwargs: Any) -> list[checks.Error]: - errors: list[checks.Error] = super().check(**kwargs) + def check(cls, **kwargs: Any) -> list[checks.CheckMessage]: + errors: list[checks.CheckMessage] = super().check(**kwargs) errors.extend(cls._check_meta_inheritance()) return errors diff --git a/tests/test_basic.py b/tests/test_basic.py index c955894..e040df3 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -204,7 +204,7 @@ def test_multiples(django_assert_num_queries, Model, queries): @pytest.mark.django_db def test_garbage_collection(): - def check_instances(num): + def check_instances(num: int) -> None: gc.collect() objs = [o for o in gc.get_objects() if isinstance(o, Prefetch)] assert len(objs) == num