Skip to content

Commit

Permalink
SpeedrunClient as an async context manager
Browse files Browse the repository at this point in the history
Allows a single session to be used for multiple requests.
  • Loading branch information
ManicJamie committed May 25, 2024
1 parent 1d08be1 commit 84f711e
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 20 deletions.
87 changes: 68 additions & 19 deletions src/speedruncompy/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64, json
import logging
import asyncio, aiohttp
import sys
import random
from typing import Awaitable, Callable, Any, Generic, TypeVar

Expand All @@ -10,7 +11,7 @@
from .datatypes import Datatype, LenientDatatype, Pagination
from .exceptions import *

API_URI = "https://www.speedrun.com/api/v2/"
API_ROOT = "/api/v2/"
LANG = "en"
ACCEPT = "application/json"
DEFAULT_USER_AGENT = "speedruncompy/"
Expand All @@ -20,20 +21,51 @@

class SpeedrunClient():
"""Api class. Holds a unique PHPSESSID and user_agent, as well as its own logger."""
def __init__(self, user_agent: str | None = None) -> None:
self.cookie_jar = aiohttp.CookieJar()
self._header = {"Accept-Language": LANG, "Accept": ACCEPT, "User-Agent": f"{DEFAULT_USER_AGENT}{user_agent}"}

_session: aiohttp.ClientSession | None
cookie_jar: aiohttp.CookieJar | None
loose_cookies: dict[str, str]
_header: dict[str, str]

def __init__(self, user_agent: str | None = None, PHPSESSID: str | None = None) -> None:
self.cookie_jar = None
self._session = None
self.loose_cookies = {}
if PHPSESSID is not None:
self.loose_cookies["PHPSESSID"] = PHPSESSID
self._header = {"Accept-Language": LANG, "Accept": ACCEPT,
"User-Agent": f"{DEFAULT_USER_AGENT}{user_agent}"}
if user_agent is None:
self._log = _log
else:
self._log = _log.getChild(user_agent)

async def __aenter__(self):
self._session = await (await self._construct_session()).__aenter__()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if self._session is None: return
await self._session.__aexit__(exc_type, exc_val, exc_tb)
self._session = None

async def _construct_session(self):
if self.cookie_jar is None:
self.cookie_jar = aiohttp.CookieJar()
self.cookie_jar.update_cookies(self.loose_cookies)
return aiohttp.ClientSession(base_url="https://www.speedrun.com", cookie_jar=self.cookie_jar, headers=self._header,
json_serialize=lambda o: json.dumps(o, separators=(",", ":")))

def _get_PHPSESSID(self) -> str | None:
if self.cookie_jar is None: return self.loose_cookies.get("PHPSESSID", None)
cookie = self.cookie_jar.filter_cookies(URL("/")).get("PHPSESSID")
return None if cookie is None else cookie.value

def _set_PHPSESSID(self, phpsessid):
self.cookie_jar.update_cookies({"PHPSESSID": phpsessid})
if self.cookie_jar is not None:
self.cookie_jar.update_cookies({"PHPSESSID": phpsessid})
else:
self.loose_cookies.update({"PHPSESSID": phpsessid})

PHPSESSID = property(_get_PHPSESSID, _set_PHPSESSID)
"""Login token. Set by `PutAuthLogin`, or you may set it manually to a logged in session."""
Expand All @@ -45,24 +77,41 @@ def _encode_r(params: dict):
return base64.urlsafe_b64encode(paramsjson).replace(b"=", b"").decode()

async def do_get(self, endpoint: str, params: dict = {}) -> tuple[bytes, int]:
# Params passed to the API by the site are json-base64 encoded, even though std params are supported.
# We will do the same in case param support is retracted.
self._log.debug(f"GET {endpoint} w/ params {params}")
async with aiohttp.ClientSession(headers=self._header, cookie_jar=self.cookie_jar) as session:
async with session.get(url=f"{API_URI}{endpoint}", params={"_r": self._encode_r(params)}) as response:
return (await response.read(), response.status)
session = self._session
if session is None:
session = await (await self._construct_session()).__aenter__()

async def do_post(self, endpoint: str, params: dict = {}, _setCookie=True) -> tuple[bytes, int]:
# Construct a dummy jar if we wish to ignore Set_Cookie responses (only on PutAuthLogin and PutAuthSignup)
if _setCookie: liveJar = self.cookie_jar
try:
async with session.get(url=f"{API_ROOT}{endpoint}", params={"_r": self._encode_r(params)}) as response:
out = (await response.read(), response.status)
except Exception as e:
if self._session is not None:
await self._session.__aexit__(*sys.exc_info())
raise e
else:
liveJar = aiohttp.CookieJar()
liveJar.update_cookies(self.cookie_jar._cookies)
if self._session is None:
await session.__aexit__(None, None, None)
return out

async def do_post(self, endpoint: str, params: dict = {}) -> tuple[bytes, int]:
self._log.debug(f"POST {endpoint} w/ params {params}")
async with aiohttp.ClientSession(json_serialize=lambda o: json.dumps(o, separators=(",", ":")),
headers=self._header, cookie_jar=liveJar) as session:
async with session.post(url=f"{API_URI}{endpoint}", json=params) as response:
return (await response.read(), response.status)

session = self._session
if session is None:
session = await (await self._construct_session()).__aenter__()

try:
async with session.post(url=f"{API_ROOT}{endpoint}", json=params) as response:
out = (await response.read(), response.status)
except Exception as e:
if self._session is not None:
await self._session.__aexit__(*sys.exc_info())
raise e
else:
if self._session is None:
await session.__aexit__(None, None, None)
return out


_default = SpeedrunClient()
Expand Down
23 changes: 22 additions & 1 deletion test/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def disable_type_checking():
@pytest.fixture(autouse=True)
def check_api_conformance():
"""The default API must never have a PHPSESSID."""
assert len(_default.cookie_jar._cookies) == 0
if _default.cookie_jar is None:
assert len(_default.loose_cookies) == 0
else:
assert len(_default.cookie_jar._cookies) == 0
yield

def log_result(result: Datatype | dict):
Expand Down Expand Up @@ -129,6 +132,24 @@ def test_DefaultAPI_separation(self):
session = GetSession(_api=self.api).perform()
assert "signedIn" in session.session
assert session.session.signedIn is True, "High-auth api not signed in"

def test_API_Context_Manager(self):
"""Ensure expected behaviour from the context manager interface."""
async def test():
async with SpeedrunClient("Test", SESSID) as client:
assert client._session is not None
result = await GetSession(_api=client).perform_async(autovary=True)
result2 = await GetSession(_api=client).perform_async(autovary=True)

log_result(result)
check_datatype_coverage(result)

log_result(result2)
check_datatype_coverage(result2)

assert client._session is None

asyncio.run(test())

@pytest.mark.skip(reason="Test stub")
def test_Authflow(self):
Expand Down

0 comments on commit 84f711e

Please sign in to comment.