Skip to content

Commit

Permalink
chore: pull facade type improvements from juju#1104
Browse files Browse the repository at this point in the history
  • Loading branch information
dimaqq committed Nov 28, 2024
1 parent 682c75d commit 039a9cb
Showing 1 changed file with 41 additions and 27 deletions.
68 changes: 41 additions & 27 deletions juju/client/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections import defaultdict
from glob import glob
from pathlib import Path
from typing import Any, Mapping, Sequence
from typing import Any, Mapping, Sequence, TypeVar, overload

import packaging.version
import typing_inspect
Expand Down Expand Up @@ -183,7 +183,7 @@ def ref_type(self, obj):
return self.get_ref_type(obj["$ref"])


CLASSES = {}
CLASSES: dict[str, type[Type]] = {}
factories = codegen.Capture()


Expand Down Expand Up @@ -479,37 +479,48 @@ def ReturnMapping(cls): # noqa: N802
def decorator(f):
@functools.wraps(f)
async def wrapper(*args, **kwargs):
nonlocal cls
reply = await f(*args, **kwargs)
if cls is None:
return reply
if "error" in reply:
cls = CLASSES["Error"]
if typing_inspect.is_generic_type(cls) and issubclass(
typing_inspect.get_origin(cls), Sequence
):
parameters = typing_inspect.get_parameters(cls)
result = []
item_cls = parameters[0]
for item in reply:
result.append(item_cls.from_json(item))
"""
if 'error' in item:
cls = CLASSES['Error']
else:
cls = item_cls
result.append(cls.from_json(item))
"""
else:
result = cls.from_json(reply["response"])

return result
return _convert_response(reply, cls=cls)

return wrapper

return decorator


@overload
def _convert_response(response: dict[str, Any], *, cls: type[SomeType]) -> SomeType: ...


@overload
def _convert_response(response: dict[str, Any], *, cls: None) -> dict[str, Any]: ...


def _convert_response(response: dict[str, Any], *, cls: type[Type] | None) -> Any:
if cls is None:
return response
if "error" in response:
cls = CLASSES["Error"]
if typing_inspect.is_generic_type(cls) and issubclass(
typing_inspect.get_origin(cls), Sequence
):
parameters = typing_inspect.get_parameters(cls)
result = []
item_cls = parameters[0]
for item in response:
result.append(item_cls.from_json(item))
"""
if 'error' in item:
cls = CLASSES['Error']
else:
cls = item_cls
result.append(cls.from_json(item))
"""
else:
result = cls.from_json(response["response"])

return result


def make_func(cls, name, description, params, result, _async=True):
indent = " "
args = Args(cls.schema, params)
Expand Down Expand Up @@ -663,7 +674,7 @@ async def rpc(self, msg: dict[str, _RichJson]) -> _Json:
return result

@classmethod
def from_json(cls, data):
def from_json(cls, data: Type | str | dict[str, Any] | list[Any]) -> Type | None:
def _parse_nested_list_entry(expr, result_dict):
if isinstance(expr, str):
if ">" in expr or ">=" in expr:
Expand Down Expand Up @@ -742,6 +753,9 @@ def get(self, key, default=None):
return getattr(self, attr, default)


SomeType = TypeVar("SomeType", bound=Type)


class Schema(dict):
def __init__(self, schema):
self.name = schema["Name"]
Expand Down

0 comments on commit 039a9cb

Please sign in to comment.