From e6b48323d4c17414904b6ed2209917347ccb0aaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=83=8A=E3=83=8B=E3=82=AB?= <101696371+nanika2@users.noreply.github.com> Date: Fri, 12 Jan 2024 04:23:49 +1100 Subject: [PATCH] only close session if owner (#13) --- mystbin/client.py | 13 +++++++++++++ mystbin/http.py | 5 ++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/mystbin/client.py b/mystbin/client.py index f9c09dd..980c703 100755 --- a/mystbin/client.py +++ b/mystbin/client.py @@ -30,8 +30,10 @@ if TYPE_CHECKING: import datetime + from types import TracebackType from aiohttp import ClientSession + from typing_extensions import Self __all__ = ("Client",) @@ -42,6 +44,17 @@ class Client: def __init__(self, *, token: str | None = None, session: ClientSession | None = None) -> None: self.http: HTTPClient = HTTPClient(token=token, session=session) + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_cls: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None + ) -> None: + await self.close() + async def close(self) -> None: """|coro| diff --git a/mystbin/http.py b/mystbin/http.py index 2ebd430..664f972 100755 --- a/mystbin/http.py +++ b/mystbin/http.py @@ -126,6 +126,7 @@ def __init__(self, verb: SupportedHTTPVerb, path: str, **params: Any) -> None: class HTTPClient: __slots__ = ( "_session", + "_owns_session", "_async", "_token", "_locks", @@ -135,16 +136,18 @@ class HTTPClient: def __init__(self, *, token: str | None, session: aiohttp.ClientSession | None = None) -> None: self._token: str | None = token self._session: aiohttp.ClientSession | None = session + self._owns_session: bool = False self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() user_agent = "mystbin.py (https://github.com/PythonistaGuild/mystbin.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__) async def close(self) -> None: - if self._session: + if self._session and self._owns_session: await self._session.close() async def _generate_session(self) -> aiohttp.ClientSession: self._session = aiohttp.ClientSession() + self._owns_session = True return self._session async def request(self, route: Route, **kwargs: Any) -> Any: