Skip to content

Commit

Permalink
Make endpoints generic & NotImplemented _combine_results
Browse files Browse the repository at this point in the history
Still raising mypy errors on OptField, willFix
  • Loading branch information
ManicJamie committed Apr 28, 2024
1 parent da34e69 commit ca22ed6
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 659 deletions.
45 changes: 24 additions & 21 deletions src/speedruncompy/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .exceptions import *
import logging
import asyncio, aiohttp
from typing import Awaitable, Callable, Any
from typing import Awaitable, Callable, Any, Generic, TypeVar

from .datatypes import Datatype, srcpyJSONEncoder, LenientDatatype

Expand Down Expand Up @@ -47,7 +47,10 @@ async def do_get(self, endpoint: str, params: dict = {}) -> tuple[bytes, int]:

async def do_post(self, endpoint:str, params: dict = {}, _setCookie=True) -> tuple[bytes, int]:
# Construct a dummy jar if we wish to ignore Set_Cookie responses (only on PutAuthLogin and PutAuthSignup)
liveJar = self.cookie_jar if _setCookie else aiohttp.CookieJar().update_cookies(self.cookie_jar._cookies)
if _setCookie: liveJar = self.cookie_jar
else:
liveJar = aiohttp.CookieJar()
liveJar.update_cookies(self.cookie_jar._cookies)
self._log.debug(f"POST {endpoint} w/ params {params}")
async with aiohttp.ClientSession(json_serialize=lambda o: json.dumps(o, separators=(",", ":"), cls=srcpyJSONEncoder),
headers=self._header, cookie_jar=liveJar) as session:
Expand All @@ -59,11 +62,13 @@ async def do_post(self, endpoint:str, params: dict = {}, _setCookie=True) -> tup
def set_default_PHPSESSID(phpsessionid):
_default.cookie_jar.update({"PHPSESSID": phpsessionid})

class BaseRequest():
R = TypeVar('R', bound=Datatype)

class BaseRequest(Generic[R]):
def __init__(self,
method: Callable[[str, dict[str, Any]], Awaitable[tuple[bytes, int]]],
endpoint: str,
returns: type = LenientDatatype,
returns: type[R],
**params):
self.method = method
self.endpoint = endpoint
Expand All @@ -74,7 +79,7 @@ def update_params(self, **kwargs):
"""Updates parameters using values set in kwargs"""
self.params.update(kwargs)

def perform(self, retries=5, delay=1, **kwargs) -> Datatype:
def perform(self, retries=5, delay=1, **kwargs) -> R:
"""Synchronously perform the request.
NB: This uses its own event loop, so if using `asyncio` use `perform_async()` instead."""
Expand All @@ -83,22 +88,22 @@ def perform(self, retries=5, delay=1, **kwargs) -> Datatype:
except RuntimeError as e:
raise AIOException("Synchronous interface called from asynchronous context - use `await perform_async` instead.") from None

async def perform_async(self, retries=5, delay=1, **kwargs) -> Datatype:
async def perform_async(self, retries=5, delay=1, **kwargs) -> R:
"""Asynchronously perform the request. Remember to `await` me!"""
self.response = await self.method(self.endpoint, self.params | kwargs)
content = self.response[0]
status = self.response[1]

if (status >= 500 and status <= 599) or status == 408:
if retries > 0:
_log.error(f"SRC returned error {status} {content}. Retrying with delay {delay}:")
_log.error(f"SRC returned error {status} {content!r}. Retrying with delay {delay}:")
for attempt in range(0, retries+1):
self.response = await self.method(self.endpoint, self.params)
content = self.response[0]
status = self.response[1]
if not (status >= 500 and status <= 599) or status == 408:
break
_log.error(f"Retry {attempt} returned error {status} {content}")
_log.error(f"Retry {attempt} returned error {status} {content!r}")
await asyncio.sleep(delay)
else:
if status == 408: raise RequestTimeout(self)
Expand All @@ -115,51 +120,49 @@ async def perform_async(self, retries=5, delay=1, **kwargs) -> Datatype:
if (status >= 500 and status <= 599): raise ServerException(self)

if status < 200 or status > 299:
_log.error(f"Unknown response error returned from SRC! {status} {self.response[0]}")
_log.error(f"Unknown response error returned from SRC! {status} {self.response[0]!r}")
raise APIException(self)

return self.return_type(json.loads(content.decode()))

class BasePaginatedRequest(BaseRequest):
def _combine_results(self, pages: dict):
_log.warning(f"""perform_all or perform_all_async on {type(self).__name__} is NOT yet implemented!
Use _perform_all_raw() or _perform_all_async_raw() to protect against future updates.""")
class BasePaginatedRequest(BaseRequest[R], Generic[R]):
def _combine_results(self, pages: dict[int, R]) -> R:
raise NotImplementedError("perform_all or perform_all_async on {type(self).__name__} is NOT yet implemented! Use _perform_all_raw() or _perform_all_async_raw()")
return pages

def perform_all(self, retries=5, delay=1) -> dict:
def perform_all(self, retries=5, delay=1) -> R:
"""Returns a combined dict of all pages. `pagination` is removed."""
pages = self._perform_all_raw(retries, delay)
return self._combine_results(pages)

def _perform_all_raw(self, retries=5, delay=1) -> dict[int, dict]:
def _perform_all_raw(self, retries=5, delay=1) -> dict[int, R]:
"""Get all pages and return a dict of {pageNo : pageData}."""
try:
return asyncio.run(self._perform_all_async_raw(retries, delay))
except RuntimeError as e:
raise AIOException("Synchronous interface called from asynchronous context - use `await perform_async` instead.") from None

async def perform_all_async(self, retries=5, delay=1) -> dict:
async def perform_all_async(self, retries=5, delay=1) -> R:
"""Returns a combined dict of all pages. `pagination` is removed."""
pages = await self._perform_all_async_raw(retries, delay)
return self._combine_results(pages)

async def _perform_all_async_raw(self, retries=5, delay=1) -> dict[int, dict]:
async def _perform_all_async_raw(self, retries=5, delay=1) -> dict[int, R]:
"""Get all pages and return a dict of {pageNo : pageData}."""
self.pages: dict[int, Datatype] = {}
self.pages: dict[int, R] = {}
self.pages[1] = await self.perform_async(retries, delay, page=1)
numpages = self.pages[1]["pagination"]["pages"]
if numpages > 1:
results = await asyncio.gather(*[self.perform_async(retries, delay, page=p) for p in range(2, numpages + 1)])
self.pages.update({p + 2:result for p, result in enumerate(results)})
return self.pages

class GetRequest(BaseRequest):
class GetRequest(BaseRequest[R], Generic[R]):
def __init__(self, endpoint, returns:type=LenientDatatype, _api:SpeedrunComPy|None=None, **params) -> None:
if _api is None: _api = _default
super().__init__(method=_api.do_get, endpoint=endpoint, returns=returns, **params)

class PostRequest(BaseRequest):
class PostRequest(BaseRequest[R], Generic[R]):
def __init__(self, endpoint, returns:type=LenientDatatype, _api:SpeedrunComPy|None=None, **params) -> None:
if _api is None: _api = _default
super().__init__(method=_api.do_post, endpoint=endpoint, returns=returns, **params)
2 changes: 1 addition & 1 deletion src/speedruncompy/datatypes/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class Game(Datatype):
emulator: EmulatorType
defaultTimer: TimerName
validTimers: list[TimerName]
releaseDate: int
releaseDate: OptField[int]
addedDate: int
touchDate: int
baseGameId: OptField[str]
Expand Down
Loading

0 comments on commit ca22ed6

Please sign in to comment.