From 3f2746cb26e31004eda0ffcfd323ad95076305a0 Mon Sep 17 00:00:00 2001 From: yakimka Date: Tue, 23 Apr 2024 21:19:47 +0300 Subject: [PATCH] Refactoring (#8) * Rename Depends to Provide * Use weakref * Remove ResolvedDependency * Rearrange code and add docstrings * Refactor resource decorator * More refactoring * More refactoring * Duplicated code; I don't know how to rewrite this in a nice way * Add init_resources * trying to refactor more [meh] * Add lock * Update README.md --- README.md | 25 ++- nanodi/__init__.py | 4 +- nanodi/nanodi.py | 340 +++++++++++++++++++++-------- nanodi/providers.py | 44 ---- nanodi/scopes.py | 33 +++ setup.cfg | 2 +- tests/test_complex_logic.py | 12 +- tests/test_sync_di.py | 20 +- tests/test_sync_di_with_closing.py | 6 +- tests/test_sync_resource.py | 16 +- 10 files changed, 315 insertions(+), 187 deletions(-) delete mode 100644 nanodi/providers.py create mode 100644 nanodi/scopes.py diff --git a/README.md b/README.md index dec3284..06cefb5 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,7 @@ Simple Dependency Injection for Python - -## Features - -- Add yours! - +Experimental dependency injection library for Python. Use it at your own risk. ## Installation @@ -19,16 +15,23 @@ Simple Dependency Injection for Python pip install nanodi ``` - ## Example -Showcase how your project can be used: - ```python -from nanodi.example import some_function +from nanodi import inject, Provide + + +def get_redis() -> str: + yield "redis" + print("closing redis") + + +@inject +def get_storage_service(redis: str = Provide(get_redis)) -> str: + return f"storage_service({redis})" + -print(some_function(3, 4)) -# => 7 +assert get_storage_service() == "storage_service(redis)" ``` ## License diff --git a/nanodi/__init__.py b/nanodi/__init__.py index d633b2d..fcd94e4 100644 --- a/nanodi/__init__.py +++ b/nanodi/__init__.py @@ -1,3 +1,3 @@ -from nanodi.nanodi import Depends, inject, resource, shutdown_resources +from nanodi.nanodi import Provide, init_resources, inject, resource, shutdown_resources -__all__ = ["Depends", "inject", "shutdown_resources", "resource"] +__all__ = ["Provide", "inject", "init_resources", "shutdown_resources", "resource"] diff --git a/nanodi/nanodi.py b/nanodi/nanodi.py index 7589903..13cca09 100644 --- a/nanodi/nanodi.py +++ b/nanodi/nanodi.py @@ -1,136 +1,282 @@ from __future__ import annotations +import asyncio import functools import inspect -from collections.abc import Callable, Coroutine, Generator +import threading +from collections.abc import Awaitable, Callable, Coroutine, Generator from contextlib import ( - AbstractAsyncContextManager, - AbstractContextManager, + AsyncExitStack, ExitStack, asynccontextmanager, contextmanager, + nullcontext, ) from dataclasses import dataclass, field -from typing import Any, AsyncContextManager, ContextManager, ParamSpec, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + ContextManager, + ParamSpec, + TypeVar, +) -Dependency = Callable[..., Any] +from nanodi.scopes import NullScope, Scope, SingletonScope +if TYPE_CHECKING: + from inspect import BoundArguments + +Dependency = Callable[..., Any] +T = TypeVar("T") +P = ParamSpec("P") +TC = TypeVar("TC", bound=Callable) _unset = object() -_resources_exit_stack = ExitStack() -_resources: dict[Dependency, AsyncContextManager | ContextManager] = {} -_resources_result_cache: dict[Dependency, Any] = {} +_exit_stack = ExitStack() +_async_exit_stack = AsyncExitStack() +_resources: list[Depends] = [] +_lock = threading.RLock() +_scopes: dict[str, Scope] = { + "null": NullScope(), + "singleton": SingletonScope(), +} -def Depends(dependency: Dependency, /, use_cache: bool = True) -> Any: # noqa: N802 - if dependency in _resources and not use_cache: - raise ValueError("use_cache=False is not supported for resources") - return _Depends(dependency, use_cache) +def Provide(dependency: Dependency, /, use_cache: bool = True) -> Any: # noqa: N802 + """ + Declare a provider. + It takes a single "dependency" callable (like a function). + Don't call it directly, nanodi will call it for you. + Dependency can be a regular function or a generator with one yield. + If the dependency is a generator, it will be used as a context manager. + Any generator that is valid for `contextlib.contextmanager` + can be used as a dependency. + Example: + ``` + from functools import lru_cache + from nanodi import Provide, inject + from my_conf import Settings -@dataclass(frozen=True) -class _Depends: - dependency: Dependency - use_cache: bool + def get_db(): + yield "db connection" + print("closing db connection") + @lru_cache # for calling the dependency only once + def get_settings(): + return Settings() -T = TypeVar("T") -P = ParamSpec("P") + @inject + def my_service(db=Provide(get_db), settings=Provide(get_settings)): + assert db == "db connection" + assert isinstance(settings, Settings) + ``` + """ + if not getattr(dependency, "_scope_", None): + dependency._scope_ = "null" # type: ignore[attr-defined] # noqa: SF01 + return Depends.from_dependency(dependency, use_cache) -@dataclass(frozen=True) -class ResolvedDependency: - original: Dependency - context_manager: ContextManager | AsyncContextManager | None = field(compare=False) - is_async: bool = field(default=False, compare=False) - use_cache: bool = True +def inject(fn: Callable[P, T]) -> Callable[P, T | Coroutine[Any, Any, T]]: + """ + Decorator to inject dependencies into a function. + Use it in combination with `Provide` to declare dependencies. - @classmethod - def resolve(cls, depends: _Depends) -> ResolvedDependency: - context_manager: ContextManager | AsyncContextManager | None = None - is_async = False - if inspect.isasyncgenfunction(depends.dependency): - context_manager = asynccontextmanager(depends.dependency)() - is_async = True - elif inspect.isgeneratorfunction(depends.dependency): - context_manager = contextmanager(depends.dependency)() - return cls(depends.dependency, context_manager, is_async, depends.use_cache) + Example: + ``` + from nanodi import inject, Provide + @inject + def my_service(db=Provide(some_dependency_func)): + ... + ``` + """ + signature = inspect.signature(fn) + if inspect.iscoroutinefunction(fn): -TC = TypeVar("TC", bound=Callable) + @functools.wraps(fn) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + exit_stack = AsyncExitStack() + for depends, names, get_value in _resolve_depends( + bound, exit_stack, is_async=True + ): + if depends.use_cache: + value = await get_value() + bound.arguments.update({name: value for name in names}) + else: + bound.arguments.update({name: await get_value() for name in names}) + async with exit_stack: + result = await fn(*bound.args, **bound.kwargs) + return result -def resource(fn: TC) -> TC: - manager: ContextManager | AsyncContextManager - if inspect.isasyncgenfunction(fn): - manager = asynccontextmanager(fn)() - elif inspect.isgeneratorfunction(fn): - manager = contextmanager(fn)() else: - raise ValueError("Resource must be a generator or async generator function") - _resources[fn] = manager - _resources_result_cache[fn] = _unset + @functools.wraps(fn) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + bound = signature.bind(*args, **kwargs) + bound.apply_defaults() + exit_stack = ExitStack() + for depends, names, get_value in _resolve_depends( + bound, exit_stack, is_async=False + ): + if depends.use_cache: + value = get_value() + bound.arguments.update({name: value for name in names}) + else: + bound.arguments.update({name: get_value() for name in names}) + + with exit_stack: + result = fn(*bound.args, **bound.kwargs) + return result + + return wrapper + + +def resource(fn: TC) -> TC: + """ + Decorator to declare a resource. Resource is a dependency that should be + called only once, cached and shared across the application. + On shutdown, all resources will be closed + (you need to call `shutdown_resources` manually). + Use it with a dependency generator function to declare a resource. + + Example: + ``` + from nanodi import resource + + # will be called only once + @resource + def get_db(): + yield "db connection" + print("closing db connection") + ``` + """ + if not inspect.isgeneratorfunction(fn) and not inspect.isasyncgenfunction(fn): + raise TypeError("Resource should be a generator function") + fn._scope_ = "singleton" # type: ignore[attr-defined] # noqa: SF01 + with _lock: + _resources.append(Depends.from_dependency(fn, use_cache=True)) return fn -def inject(fn: Callable[P, T]) -> Callable[P, T | Coroutine[Any, Any, T]]: - signature = inspect.signature(fn) +def init_resources() -> Awaitable: + """ + Call this function to close all resources. Usually, it should be called + when your application is shutting down. + """ + async_resources = [] + for depends in _resources: + if depends.is_async: + async_resources.append( + _get_value_from_depends_async(depends, _async_exit_stack) + ) + else: + _get_value_from_depends(depends, _exit_stack) - @functools.wraps(fn) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - bound = signature.bind(*args, **kwargs) - bound.apply_defaults() + return asyncio.gather(*async_resources) - dependencies: dict[ResolvedDependency, list[str]] = {} - for name, value in bound.arguments.items(): - if isinstance(value, _Depends): - dependencies.setdefault(ResolvedDependency.resolve(value), []).append( - name - ) - with _call_dependencies(dependencies) as arguments: - bound.arguments.update(arguments) - return fn(*bound.args, **bound.kwargs) +def shutdown_resources() -> Awaitable: + """ + Call this function to close all resources. Usually, it should be called + when your application is shutting down. + """ + _exit_stack.close() + return _async_exit_stack.aclose() - return wrapper +@dataclass(frozen=True) +class Depends: + dependency: Dependency + use_cache: bool + context_manager: ContextManager | AsyncContextManager | None = field(compare=False) + is_async: bool = field(compare=False) + + def get_scope_name(self) -> str: + return self.dependency._scope_ # type: ignore[attr-defined] # noqa: SF01 + + def value_as_context_manager(self) -> Any: + if self.context_manager: + return self.context_manager + return nullcontext(self.dependency()) -def shutdown_resources() -> None: - _resources_exit_stack.close() - - -@contextmanager -def _call_dependencies( - dependencies: dict[ResolvedDependency, list[str]], -) -> Generator[dict[str, Any], None, None]: - managers: list[tuple[AbstractContextManager, list[str]]] = [] - async_managers: list[tuple[AbstractAsyncContextManager, list[str]]] = [] - results = {} - for dependency, names in dependencies.items(): - if context_manager := _resources.get(dependency.original): - if isinstance(context_manager, AbstractContextManager): - if _resources_result_cache.get(dependency.original) is _unset: - result = _resources_exit_stack.enter_context(context_manager) - _resources_result_cache[dependency.original] = result - - result = _resources_result_cache[dependency.original] - results.update({name: result for name in names}) - elif dependency.context_manager: - if isinstance(dependency.context_manager, AbstractAsyncContextManager): - async_managers.append((dependency.context_manager, names)) - else: - managers.append((dependency.context_manager, names)) + @classmethod + def from_dependency(cls, dependency: Dependency, use_cache: bool) -> Depends: + context_manager: ContextManager | AsyncContextManager | None = None + is_async = False + if inspect.isasyncgenfunction(dependency): + context_manager = asynccontextmanager(dependency)() + is_async = True + elif inspect.isgeneratorfunction(dependency): + context_manager = contextmanager(dependency)() + + return cls(dependency, use_cache, context_manager, is_async) + + +def _resolve_depends( + bound: BoundArguments, exit_stack: AsyncExitStack | ExitStack, is_async: bool +) -> Generator[tuple[Depends, list[str], Callable[[], Any]], None, None]: + dependencies: dict[Depends, list[str]] = {} + for name, value in bound.arguments.items(): + if isinstance(value, Depends): + dependencies.setdefault(value, []).append(name) + + get_val = _get_value_from_depends_async if is_async else _get_value_from_depends + + for depends, names in dependencies.items(): + get_value = functools.partial(get_val, depends, exit_stack) # type: ignore + yield depends, names, get_value + + +def _get_value_from_depends( + depends: Depends, + local_exit_stack: ExitStack, +) -> Any: + scope_name = depends.get_scope_name() + scope = _scopes[scope_name] + try: + value = scope.get(depends.dependency) + except KeyError: + context_manager = depends.value_as_context_manager() + exit_stack = local_exit_stack + if scope_name == "singleton": + exit_stack = _exit_stack + if depends.is_async: + value = depends.dependency else: - if dependency.use_cache: - result = dependency.original() - results.update({name: result for name in names}) - else: - results.update({name: dependency.original() for name in names}) - - with ExitStack() as stack: - values = {manager: stack.enter_context(manager) for manager, _ in managers} - for manager, names in managers: - for name in names: - results[name] = values[manager] - yield results + with _lock: + try: + value = scope.get(depends.dependency) + except KeyError: + value = exit_stack.enter_context(context_manager) + scope.set(depends.dependency, value) + return value + + +async def _get_value_from_depends_async( + depends: Depends, + local_exit_stack: AsyncExitStack, +) -> Any: + scope_name = depends.get_scope_name() + scope = _scopes[scope_name] + try: + value = scope.get(depends.dependency) + except KeyError: + context_manager = depends.value_as_context_manager() + exit_stack = local_exit_stack + if scope_name == "singleton": + exit_stack = _async_exit_stack + with _lock: + try: + value = scope.get(depends.dependency) + except KeyError: + if depends.is_async: + value = await exit_stack.enter_async_context(context_manager) + else: + value = exit_stack.enter_context(context_manager) + scope.set(depends.dependency, value) + return value diff --git a/nanodi/providers.py b/nanodi/providers.py deleted file mode 100644 index 8225eaf..0000000 --- a/nanodi/providers.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from types import TracebackType - - -class Resource: - def init(self) -> Any: - raise NotImplementedError - - def close(self) -> None: - raise NotImplementedError - - def __enter__(self) -> Any: - return self.init() - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc: BaseException | None, - traceback: TracebackType | None, - ) -> None: - self.close() - - -class AsyncResource: - async def init(self) -> Any: - raise NotImplementedError - - async def close(self) -> None: - raise NotImplementedError - - async def __aenter__(self) -> Any: - return await self.init() - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc: BaseException | None, - traceback: TracebackType | None, - ) -> None: - await self.close() diff --git a/nanodi/scopes.py b/nanodi/scopes.py new file mode 100644 index 0000000..746daf6 --- /dev/null +++ b/nanodi/scopes.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Hashable + + +class Scope: + def get(self, key: Hashable) -> Any: + raise NotImplementedError + + def set(self, key: Hashable, value: Any) -> None: + raise NotImplementedError + + +class NullScope(Scope): + def get(self, key: Hashable) -> Any: + raise KeyError(key) + + def set(self, key: Hashable, value: Any) -> None: + pass + + +class SingletonScope(Scope): + def __init__(self) -> None: + self._store: dict[Hashable, Any] = {} + + def get(self, key: Hashable) -> Any: + return self._store[key] + + def set(self, key: Hashable, value: Any) -> None: + self._store[key] = value diff --git a/setup.cfg b/setup.cfg index e99f2e1..2331a4a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,7 +40,7 @@ per-file-ignores = ### Plugins # flake8-bugbear -extend-immutable-calls = nanodi.Depends, Depends +extend-immutable-calls = nanodi.Provide, Provide # flake8-pytest-style diff --git a/tests/test_complex_logic.py b/tests/test_complex_logic.py index d67947f..aa4d60e 100644 --- a/tests/test_complex_logic.py +++ b/tests/test_complex_logic.py @@ -1,6 +1,6 @@ from __future__ import annotations -from nanodi import Depends, inject +from nanodi import Provide, inject def get_redis() -> str: @@ -8,7 +8,7 @@ def get_redis() -> str: @inject -def get_sessions_storage(redis: str = Depends(get_redis)) -> str: +def get_sessions_storage(redis: str = Provide(get_redis)) -> str: return f"SessionsStorage({redis})" @@ -17,19 +17,19 @@ def get_postgres_connection() -> str: @inject -def get_db(postgres: str = Depends(get_postgres_connection)) -> str: +def get_db(postgres: str = Provide(get_postgres_connection)) -> str: return f"{postgres} DB" @inject -def get_users_repository(db: str = Depends(get_db)) -> str: +def get_users_repository(db: str = Provide(get_db)) -> str: return f"UsersRepository({db})" @inject def get_users_service( - users_repository: str = Depends(get_users_repository), - sessions_storage: str = Depends(get_sessions_storage), + users_repository: str = Provide(get_users_repository), + sessions_storage: str = Provide(get_sessions_storage), ) -> str: return f"UsersService({users_repository}, {sessions_storage})" diff --git a/tests/test_sync_di.py b/tests/test_sync_di.py index 11f400e..f923ba6 100644 --- a/tests/test_sync_di.py +++ b/tests/test_sync_di.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from nanodi import Depends, inject +from nanodi import Provide, inject @dataclass @@ -14,17 +14,17 @@ def get_redis() -> Redis: def test_resolve_dependency(): @inject - def my_service(redis: Redis = Depends(get_redis)): + def my_service(redis: Redis = Provide(get_redis)): return redis redis = my_service() - assert isinstance(redis, Redis) + assert isinstance(redis, Redis), redis def test_can_pass_dependency(): @inject - def my_service(redis: Redis | str = Depends(get_redis)): + def my_service(redis: Redis | str = Provide(get_redis)): return redis redis = my_service(redis="override") @@ -35,7 +35,7 @@ def my_service(redis: Redis | str = Depends(get_redis)): def test_dependencies_in_single_call_must_use_cache(): @inject def my_service( - redis1: Redis = Depends(get_redis), redis2: Redis = Depends(get_redis) + redis1: Redis = Provide(get_redis), redis2: Redis = Provide(get_redis) ): return redis1, redis2 @@ -47,7 +47,7 @@ def my_service( def test_dependencies_dont_share_cache_between_calls(): @inject - def my_service(redis: Redis = Depends(get_redis)): + def my_service(redis: Redis = Provide(get_redis)): return redis redis1 = my_service() @@ -61,8 +61,8 @@ def my_service(redis: Redis = Depends(get_redis)): def test_dependencies_in_single_call_dont_use_cache_if_specified(): @inject def my_service( - redis1: Redis = Depends(get_redis, use_cache=False), - redis2: Redis = Depends(get_redis, use_cache=False), + redis1: Redis = Provide(get_redis, use_cache=False), + redis2: Redis = Provide(get_redis, use_cache=False), ): return redis1, redis2 @@ -75,11 +75,11 @@ def my_service( def test_nested_dependencies(): @inject - def my_service_inner(redis: Redis = Depends(get_redis)): + def my_service_inner(redis: Redis = Provide(get_redis)): return redis @inject - def my_service_outer(inner_service: Redis = Depends(my_service_inner)): + def my_service_outer(inner_service: Redis = Provide(my_service_inner)): return inner_service inner_service = my_service_outer() diff --git a/tests/test_sync_di_with_closing.py b/tests/test_sync_di_with_closing.py index 2650a54..90ed281 100644 --- a/tests/test_sync_di_with_closing.py +++ b/tests/test_sync_di_with_closing.py @@ -5,7 +5,7 @@ import pytest -from nanodi import Depends, inject +from nanodi import Provide, inject if TYPE_CHECKING: from collections.abc import Generator @@ -33,7 +33,7 @@ def get_redis() -> Generator[Redis, None, None]: def test_resolve_dependency(): @inject - def my_service(redis: Redis = Depends(get_redis)): + def my_service(redis: Redis = Provide(get_redis)): return redis redis = my_service() @@ -44,7 +44,7 @@ def my_service(redis: Redis = Depends(get_redis)): @pytest.mark.parametrize("use_cache", [True, False]) def test_close_dependency_after_call(use_cache): @inject - def my_service(redis: Redis = Depends(get_redis, use_cache=use_cache)): + def my_service(redis: Redis = Provide(get_redis, use_cache=use_cache)): redis.make_request() return redis diff --git a/tests/test_sync_resource.py b/tests/test_sync_resource.py index 1101fd6..c2be207 100644 --- a/tests/test_sync_resource.py +++ b/tests/test_sync_resource.py @@ -5,7 +5,7 @@ import pytest -from nanodi import Depends, inject, resource, shutdown_resources +from nanodi import Provide, inject, resource, shutdown_resources if TYPE_CHECKING: from collections.abc import Generator @@ -38,7 +38,7 @@ def get_redis() -> Generator[Redis, None, None]: def test_resources_dont_close_automatically(redis_dependency): @inject - def my_service(redis: Redis = Depends(redis_dependency)): + def my_service(redis: Redis = Provide(redis_dependency)): redis.make_request() return redis @@ -49,7 +49,7 @@ def my_service(redis: Redis = Depends(redis_dependency)): def test_resources_can_be_closed_manually(redis_dependency): @inject - def my_service(redis: Redis = Depends(redis_dependency)): + def my_service(redis: Redis = Provide(redis_dependency)): redis.make_request() return redis @@ -57,13 +57,3 @@ def my_service(redis: Redis = Depends(redis_dependency)): shutdown_resources() assert redis.closed is True - - -def test_resources_cant_be_used_if_specified_without_cache(redis_dependency): - with pytest.raises( - ValueError, match="use_cache=False is not supported for resources" - ): - - @inject - def my_service(redis1: Redis = Depends(redis_dependency, use_cache=False)): - return redis1