Skip to content

Commit

Permalink
add annotations extension in field & schema
Browse files Browse the repository at this point in the history
  • Loading branch information
voidZXL committed Mar 30, 2024
1 parent b5a729f commit 764f3fe
Show file tree
Hide file tree
Showing 8 changed files with 339 additions and 15 deletions.
2 changes: 1 addition & 1 deletion utype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
register_transformer = TypeTransformer.registry.register


VERSION = (0, 4, 1, None)
VERSION = (0, 5, 0, None)


def _get_version():
Expand Down
1 change: 1 addition & 0 deletions utype/parser/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import sys
import warnings
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union

from ..utils import exceptions as exc
Expand Down
12 changes: 11 additions & 1 deletion utype/parser/cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import warnings
from collections.abc import Mapping
from functools import partial
from types import FunctionType
from typing import Callable, Dict, Type, TypeVar

from ..utils import exceptions as exc
Expand Down Expand Up @@ -525,6 +524,17 @@ def __init__(_obj_self, _d: dict = None, **kwargs):

return __init__

@property
def schema_annotations(self):
# this is meant to be extended and override
# if the result is not None, it will become the x-annotation of the JSON schema output
data = dict()
if self.options.mode:
data.update(mode=self.options.mode)
if self.options.case_insensitive:
data.update(case_insensitive=self.options.case_insensitive)
return data


def init_dataclass(
cls: Type[T], data, options: Options = None, context: RuntimeContext = None
Expand Down
37 changes: 26 additions & 11 deletions utype/parser/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,14 @@ def __call__(self, fn_or_cls, *args, **kwargs):
setattr(fn_or_cls, "__field__", self)
return fn_or_cls

@property
def schema_annotations(self):
return {}

@property
def default_type(self):
return None


class Param(Field):
def __init__(
Expand Down Expand Up @@ -729,17 +737,18 @@ def is_case_insensitive(self, options: Options) -> bool:
# return value()
# return copy_value(value)

def get_default(self, options: Options, defer: bool = False):
def get_default(self, options: Options, defer: Optional[bool] = False):
# options = options or self.options
if options.no_default:
return unprovided

if not defer:
if self.defer_default or options.defer_default:
return unprovided
else:
if not self.defer_default and not options.defer_default:
return unprovided
if isinstance(defer, bool):
if not defer:
if self.defer_default or options.defer_default:
return unprovided
else:
if not self.defer_default and not options.defer_default:
return unprovided

if not unprovided(options.force_default):
default = options.force_default
Expand All @@ -763,10 +772,6 @@ def get_on_error(self, options: Options):
return self.on_error
return options.invalid_values

def get_example(self):
if not unprovided(self.field.example):
return self.field.example

def is_required(self, options: Options):
if options.ignore_required or not self.required:
return False
Expand Down Expand Up @@ -1068,6 +1073,12 @@ def get_field(cls, annotation: Any, default, **kwargs):
else:
return default

@property
def schema_annotations(self):
# this is meant to be extended and override
# if the result is not None, it will become the x-annotation of the JSON schema output
return self.field.schema_annotations

@classmethod
def generate(
cls,
Expand Down Expand Up @@ -1235,6 +1246,10 @@ def generate(
if not dependencies and field.dependencies:
dependencies = field.dependencies

if annotation is None:
# a place to inject
annotation = field.default_type

input_type = _cls.rule_cls.parse_annotation(
annotation=annotation,
constraints=field.constraints,
Expand Down
59 changes: 57 additions & 2 deletions utype/specs/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ipaddress import IPv4Address, IPv6Address
from typing import Optional, Type, Union, Dict
from ..utils.datastructures import unprovided
from ..utils.compat import JSON_TYPES
from enum import EnumMeta


class JsonSchemaGenerator:
Expand Down Expand Up @@ -118,6 +120,30 @@ def generate_for_type(self, t: type):
return self.generate_for_dataclass(t)
elif isinstance(t, LogicalType) and t.combinator:
return self.generate_for_logical(t)
elif isinstance(t, EnumMeta):
base = t.__base__
enum_type = None
enum_values = []
enum_map = {}
for key, val in t.__members__.items():
enum_values.append(val.value)
enum_map[key] = val.value
enum_type = type(val.value)
if not isinstance(base, EnumMeta):
enum_type = base
prim = self._get_primitive(enum_type)
fmt = self._get_format(enum_type)
data = {
"type": prim,
"enum": enum_values,
"x-annotation": {
"enums": enum_map
}
}
if fmt:
data.update(format=fmt)
return data

# default common type
prim = self._get_primitive(t)
fmt = self._get_format(t)
Expand All @@ -138,6 +164,9 @@ def generate_for_logical(self, t: LogicalType):
def _get_format(self, origin: type) -> Optional[str]:
if not origin:
return None
format = getattr(origin, 'format', None)
if format and isinstance(format, str):
return format
for types, f in self.FORMAT_MAP.items():
if issubclass(origin, types):
return f
Expand Down Expand Up @@ -289,9 +318,21 @@ def generate_for_field(self, f: ParserField, options: Options = None) -> Optiona
elif f.field.mode == 'w':
data.update(writeOnly=True)
if not unprovided(f.field.example) and f.field.example is not None:
data.update(examples=[f.field.example])
example = f.field.example
if type(f.field.example) not in JSON_TYPES:
example = str(f.field.example)
data.update(examples=[example])
if f.aliases:
data.update(aliases=list(f.aliases))
aliases = list(f.aliases)
if aliases:
# sort to stay identical
aliases.sort()
data.update(aliases=aliases)
annotations = f.schema_annotations
if annotations:
data.update({
'x-annotation': annotations
})
return data

# todo: de-duplicate generated schema class like UserSchema['a']
Expand Down Expand Up @@ -337,6 +378,12 @@ def generate_for_dataclass(self, t):
else:
data.update(additionalProperties=addition)

annotations = parser.schema_annotations
if annotations:
data.update({
'x-annotation': annotations
})

if isinstance(self.defs, dict):
return {"$ref": f"{self.ref_prefix}{self.set_def(cls_name, t, data)}"}
return data
Expand Down Expand Up @@ -372,3 +419,11 @@ def generate_for_function(self, f):
else:
data.update(additionalParameters=addition)
return data

# REVERSE ACTION OF GENERATE:
# --- GENERATE Schema and types based on Json schema


class JsonSchemaParser:
def __init__(self, json_schema: dict):
pass
1 change: 1 addition & 0 deletions utype/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"is_classvar",
"is_annotated",
"evaluate_forward_ref",
'JSON_TYPES'
]

if sys.version_info < (3, 8):
Expand Down
11 changes: 11 additions & 0 deletions utype/utils/encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .base import TypeRegistry
import json
from .datastructures import unprovided
from ipaddress import IPv4Address, IPv6Address, IPv4Network, IPv6Network


encoder_registry = TypeRegistry('encoder', cache=True, shortcut='__encoder__')
Expand Down Expand Up @@ -98,6 +99,16 @@ def from_datetime(data: Union[datetime, date]):
return data.isoformat()


@register_encoder(IPv4Network, IPv4Address, IPv6Network, IPv6Address)
def from_ip(data):
return str(data)


@register_encoder(IPv4Network)
def from_datetime(data):
return str(data)


@register_encoder(timedelta)
def from_duration(data: timedelta):
return duration_iso_string(data)
Expand Down
Loading

0 comments on commit 764f3fe

Please sign in to comment.