Skip to content

Commit

Permalink
Avoid interacting with class/fn objects during registration.
Browse files Browse the repository at this point in the history
  • Loading branch information
maldoinc committed Jun 22, 2024
1 parent b670616 commit 869b9b8
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 3 deletions.
Empty file.
Empty file.
12 changes: 12 additions & 0 deletions test/integration/flask/services/factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from flask import g
from wireup import service


class FlaskG:
def __init__(self, g):
self.g = g


@service
def thing() -> FlaskG:
return FlaskG(g)
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import unittest
from dataclasses import dataclass
from test.fixtures import FooBar, FooBase
from test.integration.flask import services
from test.integration.flask.services.factories import FlaskG
from test.unit.services.no_annotations.random.random_service import RandomService

from flask import Flask
from flask import Flask, g
from typing_extensions import Annotated
from wireup import DependencyContainer, Inject, ParameterBag
from wireup.integration.flask_integration import wireup_init_flask_integration
Expand Down Expand Up @@ -69,7 +71,7 @@ def target(foo: FooBase):

self.container.abstract(FooBase)
self.container.register(FooBar)
wireup_init_flask_integration(self.app, dependency_container=self.container, service_modules=[])
wireup_init_flask_integration(self.app, dependency_container=self.container, service_modules=[services])

res = self.client.get("/intf")
self.assertEqual(res.status_code, 200)
Expand All @@ -90,3 +92,16 @@ def get_environment(foo: Foo):

self.assertEqual(res.status_code, 200)
self.assertEqual(res.json, {"test": True})

def test_does_not_interact_with_flask_g_when_registering(self):
@self.app.get("/")
def get_environment(foo: FlaskG):
foo.g.value = 1

return {"g": g.value}

wireup_init_flask_integration(self.app, dependency_container=self.container, service_modules=[services])
res = self.client.get("/")

self.assertEqual(res.status_code, 200)
self.assertEqual(res.json, {"g": 1})
9 changes: 8 additions & 1 deletion wireup/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib
import inspect
import re
import types
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -44,8 +45,14 @@ def _register_services(dependency_container: DependencyContainer, service_module
abstract_registrations: set[type[Any]] = set()
service_registrations: list[ServiceDeclaration] = []

def _is_valid_wireup_target(obj: Any) -> bool:
# Check that the hasattr call is only made on user defined functions and classes.
# This is so that it avoids interacting with proxies and things such as flask.g when imported.
# "from flask import g" would cause a hasattr call to g outside of app context.
return (isinstance(obj, types.FunctionType) or inspect.isclass(obj)) and hasattr(obj, "__wireup_registration__")

for module in service_modules:
for cls in _find_objects_in_module(module, predicate=lambda obj: hasattr(obj, "__wireup_registration__")):
for cls in _find_objects_in_module(module, predicate=_is_valid_wireup_target):
reg = getattr(cls, "__wireup_registration__", None)

if isinstance(reg, ServiceDeclaration):
Expand Down

0 comments on commit 869b9b8

Please sign in to comment.