From 869b9b8971030413a1f3fa0ce29758e676b03c26 Mon Sep 17 00:00:00 2001 From: Aldo Mateli Date: Sat, 22 Jun 2024 13:32:49 +0100 Subject: [PATCH] Avoid interacting with class/fn objects during registration. --- test/integration/flask/__init__.py | 0 test/integration/flask/services/__init__.py | 0 test/integration/flask/services/factories.py | 12 ++++++++++++ .../{ => flask}/test_flask_integration.py | 19 +++++++++++++++++-- wireup/import_util.py | 9 ++++++++- 5 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 test/integration/flask/__init__.py create mode 100644 test/integration/flask/services/__init__.py create mode 100644 test/integration/flask/services/factories.py rename test/integration/{ => flask}/test_flask_integration.py (84%) diff --git a/test/integration/flask/__init__.py b/test/integration/flask/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/integration/flask/services/__init__.py b/test/integration/flask/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/integration/flask/services/factories.py b/test/integration/flask/services/factories.py new file mode 100644 index 0000000..0676410 --- /dev/null +++ b/test/integration/flask/services/factories.py @@ -0,0 +1,12 @@ +from flask import g +from wireup import service + + +class FlaskG: + def __init__(self, g): + self.g = g + + +@service +def thing() -> FlaskG: + return FlaskG(g) diff --git a/test/integration/test_flask_integration.py b/test/integration/flask/test_flask_integration.py similarity index 84% rename from test/integration/test_flask_integration.py rename to test/integration/flask/test_flask_integration.py index ef8d624..a8e7c9a 100644 --- a/test/integration/test_flask_integration.py +++ b/test/integration/flask/test_flask_integration.py @@ -1,9 +1,11 @@ import unittest from dataclasses import dataclass from test.fixtures import FooBar, FooBase +from test.integration.flask import services +from test.integration.flask.services.factories import FlaskG from test.unit.services.no_annotations.random.random_service import RandomService -from flask import Flask +from flask import Flask, g from typing_extensions import Annotated from wireup import DependencyContainer, Inject, ParameterBag from wireup.integration.flask_integration import wireup_init_flask_integration @@ -69,7 +71,7 @@ def target(foo: FooBase): self.container.abstract(FooBase) self.container.register(FooBar) - wireup_init_flask_integration(self.app, dependency_container=self.container, service_modules=[]) + wireup_init_flask_integration(self.app, dependency_container=self.container, service_modules=[services]) res = self.client.get("/intf") self.assertEqual(res.status_code, 200) @@ -90,3 +92,16 @@ def get_environment(foo: Foo): self.assertEqual(res.status_code, 200) self.assertEqual(res.json, {"test": True}) + + def test_does_not_interact_with_flask_g_when_registering(self): + @self.app.get("/") + def get_environment(foo: FlaskG): + foo.g.value = 1 + + return {"g": g.value} + + wireup_init_flask_integration(self.app, dependency_container=self.container, service_modules=[services]) + res = self.client.get("/") + + self.assertEqual(res.status_code, 200) + self.assertEqual(res.json, {"g": 1}) diff --git a/wireup/import_util.py b/wireup/import_util.py index 438f642..1d8411d 100644 --- a/wireup/import_util.py +++ b/wireup/import_util.py @@ -4,6 +4,7 @@ import importlib import inspect import re +import types import warnings from pathlib import Path from typing import TYPE_CHECKING, Any @@ -44,8 +45,14 @@ def _register_services(dependency_container: DependencyContainer, service_module abstract_registrations: set[type[Any]] = set() service_registrations: list[ServiceDeclaration] = [] + def _is_valid_wireup_target(obj: Any) -> bool: + # Check that the hasattr call is only made on user defined functions and classes. + # This is so that it avoids interacting with proxies and things such as flask.g when imported. + # "from flask import g" would cause a hasattr call to g outside of app context. + return (isinstance(obj, types.FunctionType) or inspect.isclass(obj)) and hasattr(obj, "__wireup_registration__") + for module in service_modules: - for cls in _find_objects_in_module(module, predicate=lambda obj: hasattr(obj, "__wireup_registration__")): + for cls in _find_objects_in_module(module, predicate=_is_valid_wireup_target): reg = getattr(cls, "__wireup_registration__", None) if isinstance(reg, ServiceDeclaration):