Skip to content

Commit

Permalink
Tidy Datatype, fix errors induced by debuggers & OptField[Optional[li…
Browse files Browse the repository at this point in the history
…st[str]]]
  • Loading branch information
ManicJamie committed Apr 13, 2024
1 parent e7ae20f commit 48c065c
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 70 deletions.
114 changes: 60 additions & 54 deletions src/speedruncompy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

from enum import Enum
from numbers import Real
from types import NoneType
from typing import Any, Optional, Union, get_type_hints, get_origin, get_args, _SpecialForm, _type_check
from json import JSONEncoder
import typing

from .enums import *
from .exceptions import IncompleteDatatype
Expand All @@ -37,7 +39,7 @@ class _OptFieldMarker(): pass
def OptField(self, parameters):
"""Field that may not be present. Will return `None` if not present."""
arg = _type_check(parameters, f"{self} requires a single type.")
return Union[arg, type(_OptFieldMarker)]
return Union[arg, _OptFieldMarker]

class srcpyJSONEncoder(JSONEncoder):
"""Converts Datatypes to dicts when encountered"""
Expand All @@ -48,14 +50,34 @@ def default(self, o: Any) -> Any:

_log = logging.getLogger("speedruncompy.datatypes")

def is_optional(field):
return get_origin(field) is Union and type(_OptFieldMarker) in get_args(field)

def get_true_type(hint):
return get_args(hint)[0] if (get_origin(hint) is Union) or (get_origin(hint) is Optional) else hint

def get_optional_type(hint: type):
if get_origin(hint) is Union and type(None) in get_args(hint): return Union[get_true_type(hint), None]
def is_optional_field(field):
return get_origin(field) is Union and _OptFieldMarker in get_args(field)

def is_optional(test: type):
return (get_origin(test) is Union and (NoneType in get_args(test))) or (get_origin(test) is Optional)

def is_compliant_type(hint: type):
"""Whether the type of the response should be coerced to the hint"""
hint = degrade_union(hint, _OptFieldMarker, NoneType)
if get_origin(hint) == list: return False
return issubclass(hint, Datatype) or issubclass(hint, Enum) or hint == float

def is_type(value, hint: type):
if value is None:
return is_optional(hint)
else:
check = degrade_union(hint, _OptFieldMarker, NoneType)
check = get_origin(check) if get_origin(check) is not None else check
return isinstance(value, check)

def degrade_union(union: type, *to_remove: type):
"""Removes types from a union type."""
if get_origin(union) is typing.Optional:
union = get_args(union)[0] # In case Optional does not become Union (hates me)
if get_origin(union) is Union:
newargs : set = set(get_args(union)) - set(to_remove)
return Union[tuple(newargs)]
return union

class Datatype():
def __init__(self, template: Union[dict, tuple, "Datatype", None] = None, skipChecking: bool = False) -> None:
Expand All @@ -76,51 +98,34 @@ def get_type_hints(cls) -> dict[str, Any]:
return get_type_hints(cls)

def enforce_types(self):
#TODO: This is the messiest function i've ever written. do better, me (to be fixed Soon(tm))
"""Enforces this datatype's fields to conform to specified types."""
hints = self.get_type_hints()
missing_attrs = []
for attr, hint in hints.items():
base_hint = get_true_type(hint)
if get_origin(hint) is list:
list_subhint = get_args(hint)[0]
else:
list_subhint = None
if attr not in self.__dict__:
if is_optional(hint): continue
else: missing_attrs.append(attr)
elif issubclass(base_hint, Datatype) or issubclass(base_hint, Enum) or base_hint == float:
if not isinstance(self[attr], base_hint):
self[attr] = base_hint(self[attr]) # Force contained types to comply
elif get_origin(hint) is list:
list_subhint = get_args(hint)[0]
if issubclass(list_subhint, Datatype) or issubclass(list_subhint, Enum):
raw = self[attr]
self[attr] = []
for r in raw:
self[attr].append(list_subhint(r))

if len(missing_attrs) > 0:
if STRICT_TYPE_CONFORMANCE: raise IncompleteDatatype(f"Datatype {type(self).__name__} constructed missing mandatory fields {missing_attrs}")
else: _log.warning(f"Datatype {type(self).__name__} constructed missing mandatory fields {missing_attrs}")

opt_hint = get_optional_type(hint)
check = get_origin(base_hint) if get_origin(base_hint) is not None else base_hint
if check == Any:
_log.debug(f"Undocumented attr {attr} has value {self[attr]} of type {type(self[attr])}")
continue # Can't do enforcement against Any
if not isinstance(self[attr], check) and not isinstance(self[attr], opt_hint):
if STRICT_TYPE_CONFORMANCE:
raise AttributeError(f"Datatype {type(self).__name__}'s attribute {attr} expects {check} but received {type(self[attr]).__name__}")
else: _log.warning(f"Datatype {type(self).__name__}'s attribute {attr} expects {check} but received {type(self[attr]).__name__}")
if isinstance(self[attr], list) and len(self[attr]) > 0:
instance = self[attr][0]
subhints = get_args(hint)
list_subhint = subhints[0]
if not isinstance(instance, list_subhint):
if STRICT_TYPE_CONFORMANCE:
raise AttributeError(f"Datatype {type(self).__name__}'s attribute {attr} expects list[{list_subhint}] but received {type(self[attr][0]).__name__}")
else: _log.warning(f"Datatype {type(self).__name__}'s attribute {attr} expects list[{list_subhint}] but received {type(self[attr][0]).__name__}")

missing_fields = set() # fields that are specified as non-optional that are missing from
for fieldname, hint in hints.items():
nullable_type = degrade_union(hint, _OptFieldMarker) # type that may be nullable but not optional
true_type = degrade_union(nullable_type, NoneType) # base type (no union)
raw = self[fieldname]

if fieldname not in self.__dict__:
if is_optional_field(hint): continue
else: missing_fields.add(fieldname) # Non-optional fields must be present, report if not
elif is_compliant_type(true_type):
if not isinstance(self[fieldname], true_type):
self[fieldname] = true_type(raw) # Coerce compliant types
elif get_origin(true_type) is list:
list_type = get_args(true_type)[0]
if is_compliant_type(list_type): # Coerce list types
self[fieldname] = [list_type(r) if not isinstance(self[fieldname], list_type) else r for r in raw]

attr = self[fieldname]
if true_type == Any: _log.debug(f"Undocumented attr {fieldname} has value {raw} of type {type(raw)}")
elif not is_type(attr, hint):
if STRICT_TYPE_CONFORMANCE: raise AttributeError(f"Datatype {type(self).__name__}'s attribute {fieldname} expects {nullable_type} but received {type(attr).__name__}")
else: _log.warning(f"Datatype {type(self).__name__}'s attribute {attr} expects {nullable_type} but received {type(self[attr]).__name__}")

if len(missing_fields) > 0:
if STRICT_TYPE_CONFORMANCE: raise IncompleteDatatype(f"Datatype {type(self).__name__} constructed missing mandatory fields {missing_fields}")
else: _log.warning(f"Datatype {type(self).__name__} constructed missing mandatory fields {missing_fields}")


# Allow interacting with these types as if they were dicts (in all reasonable ways)
Expand All @@ -146,6 +151,7 @@ def __getitem__(self, key):
else: raise e
# __getattr__ only called for missing attributes
def __getattr__(self, __name: str) -> Any:
if __name.startswith("__"): return None # Special handling for python reserved calls
if __name in get_type_hints(self.__class__).keys(): return None #TODO: warn here?
else: raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{__name}'" ,name=__name, obj=self)

Expand Down Expand Up @@ -636,7 +642,7 @@ class Run(Datatype):
orphaned: OptField[bool]
estimated: OptField[bool] #TODO: Figure out what this means
"""Only shown in GetModerationRuns"""
issues: OptField[Optional[list]] #TODO: fails when present
issues: OptField[Optional[list[str]]] #TODO: fails when present

class ChallengeStanding(Datatype):
challengeId: str
Expand Down
7 changes: 5 additions & 2 deletions src/speedruncompy/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,11 @@ def _combine_results(self, pages: dict):
for v in (v for v in page["values"] if v not in values): values.append(v)
for v in (v for v in page["variables"] if v not in variables): variables.append(v)
runs += page["runs"]
return {"categories": categories, "games": games, "levels": levels, "platforms": platforms, "players": players,
"regions": regions, "runs": runs, "users": users, "values": values, "variables": variables}

extras: r_GetModerationRuns = pages[1]
extras.pagination.page = 0
return extras | r_GetModerationRuns({"categories": categories, "games": games, "levels": levels, "platforms": platforms, "players": players,
"regions": regions, "runs": runs, "users": users, "values": values, "variables": variables}, skipChecking=True)

class PutRunAssignee(PostRequest):
def __init__(self, assigneeId: str, runId: str, **params) -> None:
Expand Down
6 changes: 2 additions & 4 deletions test/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,15 +435,13 @@ def test_GetModerationRuns(self):
log_result(result)
check_datatype_coverage(result)

@pytest.mark.skip(reason="Test stub")
def test_GetModerationRuns_paginated(self):
result = GetModerationRuns(_api=self.api, gameId=game_id).perform_all()
result = GetModerationRuns(_api=self.api, gameId=game_id, verified=verified.PENDING).perform_all()
log_result(result)
check_datatype_coverage(result)

@pytest.mark.skip(reason="Test stub")
def test_GetModerationRuns_paginated_raw(self):
result = ...
result = GetModerationRuns(_api=self.api, gameId=game_id, verified=verified.PENDING)._perform_all_raw()
log_result(result)
check_datatype_coverage(result)

Expand Down
14 changes: 4 additions & 10 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from types import NoneType
from typing import get_origin, get_args, Optional, Union, get_type_hints
from speedruncompy.datatypes import Datatype
from speedruncompy.datatypes import _OptFieldMarker, Datatype, degrade_union

def get_true_type(t: type):
origin = get_origin(t)
if origin is None: return t
else:
args = get_args(t)
if origin is Union or origin is Optional:
return args[0]
else:
return origin
return degrade_union(t, NoneType, _OptFieldMarker)

def check_datatype_coverage(dt: Datatype):
keys = set(dt.keys())
Expand All @@ -22,7 +16,7 @@ def check_datatype_coverage(dt: Datatype):
if issubclass(true, Datatype):
if dt[attr] is not None:
check_datatype_coverage(dt[attr])
elif true is list:
elif get_origin(true) is list:
list_type = get_args(subtype)[0]
if issubclass(list_type, Datatype):
for item in dt[attr]:
Expand Down

0 comments on commit 48c065c

Please sign in to comment.