From e6b48323d4c17414904b6ed2209917347ccb0aaf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=83=8A=E3=83=8B=E3=82=AB?=
 <101696371+nanika2@users.noreply.github.com>
Date: Fri, 12 Jan 2024 04:23:49 +1100
Subject: [PATCH] only close session if owner (#13)

---
 mystbin/client.py | 13 +++++++++++++
 mystbin/http.py   |  5 ++++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/mystbin/client.py b/mystbin/client.py
index f9c09dd..980c703 100755
--- a/mystbin/client.py
+++ b/mystbin/client.py
@@ -30,8 +30,10 @@
 
 if TYPE_CHECKING:
     import datetime
+    from types import TracebackType
 
     from aiohttp import ClientSession
+    from typing_extensions import Self
 
 __all__ = ("Client",)
 
@@ -42,6 +44,17 @@ class Client:
     def __init__(self, *, token: str | None = None, session: ClientSession | None = None) -> None:
         self.http: HTTPClient = HTTPClient(token=token, session=session)
 
+    async def __aenter__(self) -> Self:
+        return self
+
+    async def __aexit__(
+        self,
+        exc_cls: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None
+    ) -> None:
+        await self.close()
+
     async def close(self) -> None:
         """|coro|
 
diff --git a/mystbin/http.py b/mystbin/http.py
index 2ebd430..664f972 100755
--- a/mystbin/http.py
+++ b/mystbin/http.py
@@ -126,6 +126,7 @@ def __init__(self, verb: SupportedHTTPVerb, path: str, **params: Any) -> None:
 class HTTPClient:
     __slots__ = (
         "_session",
+        "_owns_session",
         "_async",
         "_token",
         "_locks",
@@ -135,16 +136,18 @@ class HTTPClient:
     def __init__(self, *, token: str | None, session: aiohttp.ClientSession | None = None) -> None:
         self._token: str | None = token
         self._session: aiohttp.ClientSession | None = session
+        self._owns_session: bool = False
         self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
         user_agent = "mystbin.py (https://github.com/PythonistaGuild/mystbin.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}"
         self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__)
 
     async def close(self) -> None:
-        if self._session:
+        if self._session and self._owns_session:
             await self._session.close()
 
     async def _generate_session(self) -> aiohttp.ClientSession:
         self._session = aiohttp.ClientSession()
+        self._owns_session = True
         return self._session
 
     async def request(self, route: Route, **kwargs: Any) -> Any: