diff --git a/twitchio/authentication/oauth.py b/twitchio/authentication/oauth.py index 343e59a3..247b4c72 100644 --- a/twitchio/authentication/oauth.py +++ b/twitchio/authentication/oauth.py @@ -29,6 +29,7 @@ from typing import TYPE_CHECKING, ClassVar from ..http import HTTPClient, Route +from ..utils import MISSING from .payloads import * @@ -55,7 +56,7 @@ def __init__( client_secret: str, redirect_uri: str | None = None, scopes: Scopes | None = None, - session: aiohttp.ClientSession | None = None, + session: aiohttp.ClientSession = MISSING, ) -> None: super().__init__(session=session, client_id=client_id) diff --git a/twitchio/authentication/tokens.py b/twitchio/authentication/tokens.py index 4cde8b7b..b061b76e 100644 --- a/twitchio/authentication/tokens.py +++ b/twitchio/authentication/tokens.py @@ -37,6 +37,7 @@ from ..exceptions import HTTPException, InvalidTokenException from ..http import HTTPAsyncIterator, PaginatedConverter from ..types_.tokens import TokenMappingData +from ..utils import MISSING from .oauth import OAuth from .payloads import ClientCredentialsPayload, ValidateTokenPayload from .scopes import Scopes @@ -61,7 +62,7 @@ def __init__( client_secret: str, redirect_uri: str | None = None, scopes: Scopes | None = None, - session: aiohttp.ClientSession | None = None, + session: aiohttp.ClientSession = MISSING, nested_key: str | None = None, ) -> None: super().__init__( diff --git a/twitchio/http.py b/twitchio/http.py index c4fe458b..13a6a569 100644 --- a/twitchio/http.py +++ b/twitchio/http.py @@ -377,11 +377,12 @@ async def __anext__(self) -> T: class HTTPClient: - __slots__ = ("_client_id", "_session", "_should_close", "user_agent") + __slots__ = ("_client_id", "_session", "_session_set", "_should_close", "user_agent") def __init__(self, session: aiohttp.ClientSession = MISSING, *, client_id: str) -> None: self._session: aiohttp.ClientSession = session self._should_close: bool = session is MISSING + self._session_set: bool = False self._client_id: str = client_id @@ -395,6 +396,11 @@ def headers(self) -> dict[str, str]: return {"User-Agent": self.user_agent, "Client-ID": self._client_id} async def _init_session(self) -> None: + if self._session_set: + return + + self._session_set = True + if self._session is not MISSING: self._session.headers.update(self.headers) return @@ -408,6 +414,7 @@ def clear(self) -> None: "Clearing %s session. A new session will be created on the next request.", self.__class__.__qualname__ ) self._session = MISSING + self._session_set = False async def close(self) -> None: if not self._should_close: