diff --git a/docs/examples.rst b/docs/examples.rst index 9b59ae606..e714e0495 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -35,6 +35,7 @@ Advance Topics examples/custom-property-names examples/custom-class-factory examples/wrapped-list + examples/custom-type-mapping Test Suites diff --git a/docs/examples/custom-type-mapping.rst b/docs/examples/custom-type-mapping.rst new file mode 100644 index 000000000..85305fc91 --- /dev/null +++ b/docs/examples/custom-type-mapping.rst @@ -0,0 +1,75 @@ +=================== +Custom type mapping +=================== + +When managing a big collection of models, it sometimes is tricky to split them +into multiple python modules. Even more so if they depend on each other. For +the models to be serializable by xsdata, they need to be able to import all +other referenced models, which might not be possible due to circular imports. + +One solution to get around this problem is to fence the imports within the +python modules by using :data:`python:typing.TYPE_CHECKING` and passing a +dedicated type-map dictionary to the +:class:`~xsdata.formats.dataclass.serializers.config.SerializerConfig`. + + +.. tab:: city.py + + .. literalinclude:: /../tests/models/typemapping/city.py + :language: python + +.. tab:: street.py + + .. literalinclude:: /../tests/models/typemapping/street.py + :language: python + +.. tab:: house.py + + .. literalinclude:: /../tests/models/typemapping/house.py + :language: python + + +By fencing the imports, we are able to keep our models in different python +modules that are cleanly importable and considered valid by static type +checkers. + +Passing the type-map dictionary, which maps the class/model-names directly to +imported objects, enables xsdata to serialize the models. + + +.. testcode:: + + from xsdata.formats.dataclass.serializers import XmlSerializer + from xsdata.formats.dataclass.serializers.config import SerializerConfig + + from tests.models.typemapping.city import City + from tests.models.typemapping.house import House + from tests.models.typemapping.street import Street + + + city1 = City(name="footown") + street1 = Street(name="foostreet") + house1 = House(number=23) + city1.streets.append(street1) + street1.houses.append(house1) + + type_map = {"City": City, "Street": Street, "House": House} + serializer_config = SerializerConfig(pretty_print=True, globalns=type_map) + + xml_serializer = XmlSerializer(config=serializer_config) + serialized_house = xml_serializer.render(city1) + print(serialized_house) + + +.. testoutput:: + + + + footown + + foostreet + + 23 + + + diff --git a/tests/models/test_type_mapping.py b/tests/models/test_type_mapping.py new file mode 100644 index 000000000..f89571699 --- /dev/null +++ b/tests/models/test_type_mapping.py @@ -0,0 +1,30 @@ +from unittest import TestCase + +from tests.models.typemapping.city import City +from tests.models.typemapping.house import House +from tests.models.typemapping.street import Street +from xsdata.formats.dataclass.serializers import JsonSerializer +from xsdata.formats.dataclass.serializers import PycodeSerializer +from xsdata.formats.dataclass.serializers import XmlSerializer +from xsdata.formats.dataclass.serializers.config import SerializerConfig + + +class TypeMappingTests(TestCase): + def test_type_mapping(self): + city1 = City(name="footown") + street1 = Street(name="foostreet") + house1 = House(number=23) + city1.streets.append(street1) + street1.houses.append(house1) + + type_mapping = {"City": City, "Street": Street, "House": House} + serializer_config = SerializerConfig(globalns=type_mapping) + + json_serializer = JsonSerializer(config=serializer_config) + xml_serializer = XmlSerializer(config=serializer_config) + pycode_serializer = PycodeSerializer(config=serializer_config) + + for model in (city1, street1, house1): + json_serializer.render(model) + xml_serializer.render(model) + pycode_serializer.render(model) diff --git a/tests/models/typemapping/__init__.py b/tests/models/typemapping/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/models/typemapping/city.py b/tests/models/typemapping/city.py new file mode 100644 index 000000000..4c2ab3258 --- /dev/null +++ b/tests/models/typemapping/city.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from dataclasses import field +from typing import List +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tests.models.typemapping.street import Street + + +@dataclass +class City: + class Meta: + global_type = False + + name: str + streets: List["Street"] = field(default_factory=list) diff --git a/tests/models/typemapping/house.py b/tests/models/typemapping/house.py new file mode 100644 index 000000000..4574415e8 --- /dev/null +++ b/tests/models/typemapping/house.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tests.models.typemapping.street import Street + + +@dataclass +class House: + class Meta: + global_type = False + + number: int + street: Optional["Street"] = None diff --git a/tests/models/typemapping/street.py b/tests/models/typemapping/street.py new file mode 100644 index 000000000..11778f51f --- /dev/null +++ b/tests/models/typemapping/street.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from dataclasses import field +from typing import List +from typing import Optional +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from tests.models.typemapping.city import City + from tests.models.typemapping.house import House + + +@dataclass +class Street: + class Meta: + global_type = False + + name: str + city: Optional["City"] = None + houses: List["House"] = field(default_factory=list) diff --git a/xsdata/formats/dataclass/context.py b/xsdata/formats/dataclass/context.py index 5ace31174..39c808f07 100644 --- a/xsdata/formats/dataclass/context.py +++ b/xsdata/formats/dataclass/context.py @@ -174,7 +174,12 @@ def find_subclass(self, clazz: Type, qname: str) -> Optional[Type]: return None - def build(self, clazz: Type, parent_ns: Optional[str] = None) -> XmlMeta: + def build( + self, + clazz: Type, + parent_ns: Optional[str] = None, + globalns: Optional[Dict[str, Callable]] = None, + ) -> XmlMeta: """ Fetch from cache or build the binding metadata for the given class and parent namespace. @@ -188,6 +193,7 @@ def build(self, clazz: Type, parent_ns: Optional[str] = None) -> XmlMeta: class_type=self.class_type, element_name_generator=self.element_name_generator, attribute_name_generator=self.attribute_name_generator, + globalns=globalns, ) self.cache[clazz] = builder.build(clazz, parent_ns) return self.cache[clazz] diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index d97d86ef0..754766cf7 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -40,18 +40,25 @@ class ClassMeta(NamedTuple): class XmlMetaBuilder: - __slots__ = "class_type", "element_name_generator", "attribute_name_generator" + __slots__ = ( + "class_type", + "element_name_generator", + "attribute_name_generator", + "globalns", + ) def __init__( self, class_type: ClassType, element_name_generator: Callable, attribute_name_generator: Callable, + globalns: Optional[Dict[str, Callable]] = None, ): self.class_type = class_type self.element_name_generator = element_name_generator self.attribute_name_generator = attribute_name_generator + self.globalns = globalns def build(self, clazz: Type, parent_namespace: Optional[str]) -> XmlMeta: """Build the binding metadata for a dataclass and its fields.""" @@ -112,7 +119,7 @@ def build_vars( attribute_name_generator: Callable, ): """Build the binding metadata for the given dataclass fields.""" - type_hints = get_type_hints(clazz) + type_hints = get_type_hints(clazz, globalns=self.globalns) builder = XmlVarBuilder( class_type=self.class_type, default_xml_type=self.default_xml_type(clazz), diff --git a/xsdata/formats/dataclass/serializers/config.py b/xsdata/formats/dataclass/serializers/config.py index b140d75cd..8d910b357 100644 --- a/xsdata/formats/dataclass/serializers/config.py +++ b/xsdata/formats/dataclass/serializers/config.py @@ -1,3 +1,5 @@ +from typing import Callable +from typing import Dict from typing import Optional @@ -16,6 +18,8 @@ class SerializerConfig: :param schema_location: xsi:schemaLocation attribute value :param no_namespace_schema_location: xsi:noNamespaceSchemaLocation attribute value + :param globalns: Dictionary containing global variables to extend + or overwrite for typing """ __slots__ = ( @@ -26,6 +30,7 @@ class SerializerConfig: "ignore_default_attributes", "schema_location", "no_namespace_schema_location", + "globalns", ) def __init__( @@ -37,6 +42,7 @@ def __init__( ignore_default_attributes: bool = False, schema_location: Optional[str] = None, no_namespace_schema_location: Optional[str] = None, + globalns: Optional[Dict[str, Callable]] = None, ): self.encoding = encoding self.xml_version = xml_version @@ -45,3 +51,4 @@ def __init__( self.ignore_default_attributes = ignore_default_attributes self.schema_location = schema_location self.no_namespace_schema_location = no_namespace_schema_location + self.globalns = globalns diff --git a/xsdata/formats/dataclass/serializers/json.py b/xsdata/formats/dataclass/serializers/json.py index f62e24ce3..e02d4a662 100644 --- a/xsdata/formats/dataclass/serializers/json.py +++ b/xsdata/formats/dataclass/serializers/json.py @@ -96,7 +96,9 @@ def convert(self, obj: Any, var: Optional[XmlVar] = None) -> Any: def next_value(self, obj: Any) -> Iterator[Tuple[str, Any]]: ignore_optionals = self.config.ignore_default_attributes - for var in self.context.build(obj.__class__).get_all_vars(): + for var in self.context.build( + obj.__class__, globalns=self.config.globalns + ).get_all_vars(): value = getattr(obj, var.name) if var.is_attribute and ignore_optionals and var.is_optional(value): continue diff --git a/xsdata/formats/dataclass/serializers/xml.py b/xsdata/formats/dataclass/serializers/xml.py index a13c24607..785e957f0 100644 --- a/xsdata/formats/dataclass/serializers/xml.py +++ b/xsdata/formats/dataclass/serializers/xml.py @@ -78,7 +78,9 @@ def write_object(self, obj: Any): """Produce an events stream from a dataclass or a derived element.""" qname = xsi_type = None if isinstance(obj, self.context.class_type.derived_element): - meta = self.context.build(obj.value.__class__) + meta = self.context.build( + obj.value.__class__, globalns=self.config.globalns + ) qname = obj.qname obj = obj.value xsi_type = namespaces.real_xsi_type(qname, meta.target_qname) @@ -99,8 +101,9 @@ def write_dataclass( Optionally override the qualified name and the xsi properties type and nil. """ - - meta = self.context.build(obj.__class__, namespace) + meta = self.context.build( + obj.__class__, namespace, globalns=self.config.globalns + ) qname = qname or meta.qname nillable = nillable or meta.nillable namespace, tag = namespaces.split_qname(qname)