Skip to content

Commit

Permalink
chore: Code cleanup (#1038)
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra authored May 11, 2024
1 parent 3713c6a commit 584374d
Show file tree
Hide file tree
Showing 33 changed files with 61 additions and 73 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- id: typos
exclude: ^tests/|.xsd|xsdata/models/datatype.py$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.3
rev: v0.4.4
hooks:
- id: ruff
args: [ --fix, --show-fixes]
Expand Down
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
[![image](https://github.com/tefra/xsdata/workflows/tests/badge.svg)](https://github.com/tefra/xsdata/actions)
[![image](https://readthedocs.org/projects/xsdata/badge)](https://xsdata.readthedocs.io/)
[![image](https://codecov.io/gh/tefra/xsdata/branch/main/graph/badge.svg)](https://codecov.io/gh/tefra/xsdata)
[![image](https://img.shields.io/github/languages/top/tefra/xsdata.svg)](https://xsdata.readthedocs.io/)
[![image](https://www.codefactor.io/repository/github/tefra/xsdata/badge)](https://www.codefactor.io/repository/github/tefra/xsdata)
[![image](https://img.shields.io/pypi/pyversions/xsdata.svg)](https://pypi.org/pypi/xsdata/)
[![image](https://img.shields.io/pypi/v/xsdata.svg)](https://pypi.org/pypi/xsdata/)
Expand Down
4 changes: 2 additions & 2 deletions tests/codegen/handlers/test_create_compound_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_group_fields(self):
)
expected_res = Restrictions(min_occurs=0, max_occurs=20)

self.processor.group_fields(target, list(target.attrs))
self.processor.group_fields(target, target.attrs.copy())
self.assertEqual(1, len(target.attrs))
self.assertEqual(expected, target.attrs[0])
self.assertEqual(expected_res, target.attrs[0].restrictions)
Expand All @@ -134,7 +134,7 @@ def test_group_fields_with_effective_choices_sums_occurs(self):

expected_res = Restrictions(min_occurs=4, max_occurs=6)

self.processor.group_fields(target, list(target.attrs))
self.processor.group_fields(target, target.attrs.copy())
self.assertEqual(1, len(target.attrs))
self.assertEqual(expected_res, target.attrs[0].restrictions)

Expand Down
2 changes: 1 addition & 1 deletion tests/codegen/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_copy_group_attributes(self, mock_clone_attribute, mock_copy_inner_class
source = ClassFactory.elements(2)
source.inner.append(ClassFactory.create())
target = ClassFactory.elements(3)
attrs = list(target.attrs)
attrs = target.attrs.copy()
attrs[1].name = "bar"
attr = target.attrs[1]

Expand Down
7 changes: 5 additions & 2 deletions xsdata/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def generate(**kwargs: Any):
handler.emit_warnings()


_SUPPORTED_EXTENSIONS = ("wsdl", "xsd", "dtd", "xml", "json")


def resolve_source(source: str, recursive: bool) -> Iterator[str]:
"""Yields all supported resource URIs."""
if source.find("://") > -1 and not source.startswith("file://"):
Expand All @@ -151,9 +154,9 @@ def resolve_source(source: str, recursive: bool) -> Iterator[str]:
path = Path(source).resolve()
match = "**/*" if recursive else "*"
if path.is_dir():
for ext in ["wsdl", "xsd", "dtd", "xml", "json"]:
for ext in _SUPPORTED_EXTENSIONS:
yield from (x.as_uri() for x in path.glob(f"{match}.{ext}"))
else: # is file
else: # is a file
yield path.as_uri()


Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/handlers/add_attribute_substitutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def process(self, target: Class):
if self.substitutions is None:
self.create_substitutions()

for attr in list(target.attrs):
for attr in target.attrs.copy():
if not (attr.is_enumeration or attr.is_wildcard):
self.process_attribute(target, attr)

Expand Down
10 changes: 6 additions & 4 deletions xsdata/codegen/handlers/detect_circular_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,18 @@ def is_circular(self, start: int, stop: int) -> bool:
"""
path = set()
stack = [start]
while len(stack) != 0:
while stack:
if stop in path:
return True

ref = stack.pop()
path.add(ref)

for tp in self.reference_types[ref]:
if not tp.circular and tp.reference not in path:
stack.append(tp.reference)
stack.extend(
tp.reference
for tp in self.reference_types[ref]
if not tp.circular and tp.reference not in path
)

return stop in path

Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/handlers/flatten_attribute_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def process(self, target: Class):
target: The target class instance to inspect and process
"""
repeat = False
for attr in list(target.attrs):
for attr in target.attrs.copy():
if attr.is_group:
repeat = True
self.process_attribute(target, attr)
Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/handlers/flatten_class_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def process(self, target: Class):
Args:
target: The target class instance
"""
for extension in list(target.extensions):
for extension in target.extensions.copy():
self.process_extension(target, extension)

def process_extension(self, target: Class, extension: Extension):
Expand Down
4 changes: 2 additions & 2 deletions xsdata/codegen/handlers/process_attributes_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def process(self, target: Class):
Args:
target: The target class instance
"""
for attr in list(target.attrs):
for attr in target.attrs.copy():
self.process_types(target, attr)
self.cascade_properties(target, attr)

Expand All @@ -46,7 +46,7 @@ def process_types(self, target: Class, attr: Attr):
if self.container.config.output.ignore_patterns:
attr.restrictions.pattern = None

for attr_type in list(attr.types):
for attr_type in attr.types.copy():
self.process_type(target, attr, attr_type)

attr.types = ClassUtils.filter_types(attr.types)
Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/handlers/process_mixed_content_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def process(self, target: Class):

attrs = []
choices = []
for attr in list(target.attrs):
for attr in target.attrs.copy():
if attr.is_attribute:
attrs.append(attr)
elif not attr.is_any_type:
Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/handlers/unnest_inner_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def process(self, target: Class):
Args:
target: The target class instance to inspect
"""
for inner in list(target.inner):
for inner in target.inner.copy():
if inner.is_enumeration or self.container.config.output.unnest_classes:
self.promote(target, inner)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,4 @@ def group_repeating_attrs(cls, target: Class) -> List[List[int]]:
if not attr.is_attribute:
counters[attr.key].append(index)

groups = []
for x in counters.values():
if len(x) > 1:
groups.append(list(range(x[0], x[-1] + 1)))

return groups
return [list(range(x[0], x[-1] + 1)) for x in counters.values() if len(x) > 1]
2 changes: 1 addition & 1 deletion xsdata/codegen/handlers/vacuum_inner_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def process(self, target: Class):
target: The target class instance
"""
target.inner = collections.unique_sequence(target.inner, key="qname")
for inner in list(target.inner):
for inner in target.inner.copy():
if not inner.attrs and len(inner.extensions) < 2:
self.remove_inner(target, inner)
elif inner.qname == target.qname:
Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/handlers/validate_attributes_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def validate_attrs(cls, target: Class, base_attrs_map: Dict[str, List[Attr]]):
target: The target class instance
base_attrs_map: A mapping of qualified names to lists of parent attrs
"""
for attr in list(target.attrs):
for attr in target.attrs.copy():
base_attrs = base_attrs_map.get(attr.slug)

if base_attrs:
Expand Down
9 changes: 4 additions & 5 deletions xsdata/codegen/mappers/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,9 @@ def group_repeating_attrs(cls, element: AnyElement) -> List[List[int]]:
if isinstance(child, AnyElement) and child.qname:
counters[child.qname].append(index)

groups = []
groups: List[List[int]] = []
if len(counters) > 1:
for x in counters.values():
if len(x) > 1:
groups.append(list(range(x[0], x[-1] + 1)))

groups.extend(
list(range(x[0], x[-1] + 1)) for x in counters.values() if len(x) > 1
)
return groups
2 changes: 1 addition & 1 deletion xsdata/codegen/mappers/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def build_attr_types(cls, target: Class, obj: ElementBase) -> List[AttrType]:
target.inner.append(inner)
types.append(AttrType(qname=inner.qname, forward=True))

if len(types) == 0:
if not types:
types.append(cls.build_attr_type(target, name=obj.default_type))

return collections.unique_sequence(types)
Expand Down
3 changes: 1 addition & 2 deletions xsdata/codegen/mixins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
from abc import ABCMeta
from typing import Callable, Dict, Iterator, List, Optional

from xsdata.codegen.models import Attr, Class
Expand Down Expand Up @@ -129,7 +128,7 @@ def process(self, target: Class):
"""


class RelativeHandlerInterface(HandlerInterface, metaclass=ABCMeta):
class RelativeHandlerInterface(HandlerInterface, abc.ABC):
"""An interface for codegen handlers with class container access.
Args:
Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/parsers/dtd.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def build_ns_map(cls, prefix: str, attributes: List[DtdAttribute]) -> Dict:
"""
ns_map = {ns.prefix: ns.uri for ns in Namespace.common()}

for attribute in list(attributes):
for attribute in attributes.copy():
if not attribute.default_value:
continue

Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def convert_definitions(self, definitions: Definitions):
def generate_classes(self, schema: Schema) -> List[Class]:
"""Convert the given schema instance to a list of classes."""
uri = schema.location
logger.info("Compiling schema %s", uri if uri else "...")
logger.info("Compiling schema %s", uri or "...")
classes = SchemaMapper.map(schema)

class_num, inner_num = self.count_classes(classes)
Expand Down
2 changes: 1 addition & 1 deletion xsdata/codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def clean_inner_classes(cls, target: Class):
Args:
target: The target class instance to inspect.
"""
for inner in list(target.inner):
for inner in target.inner.copy():
if cls.is_orphan_inner(target, inner):
target.inner.remove(inner)

Expand Down
4 changes: 2 additions & 2 deletions xsdata/codegen/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def is_invalid(ext: Extension) -> bool:
"""Check if given type declaration is not native and is missing."""
return not ext.type.native and ext.type.qname not in self.container.data

for target in list(classes):
for target in classes.copy():
if any(is_invalid(extension) for extension in target.extensions):
classes.remove(target)

Expand All @@ -77,7 +77,7 @@ def handle_duplicate_types(cls, classes: List[Class]):
if len(items) == 1:
continue

index = cls.select_winner(list(items))
index = cls.select_winner(items.copy())

if index == -1:
logger.warning(
Expand Down
13 changes: 5 additions & 8 deletions xsdata/formats/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import binascii
import math
import re
from contextlib import suppress
from datetime import date, datetime, time
from decimal import Decimal, InvalidOperation
from enum import Enum, EnumMeta
Expand Down Expand Up @@ -99,11 +100,9 @@ def deserialize(self, value: Any, types: Sequence[Type], **kwargs: Any) -> Any:
The converted value
"""
for data_type in types:
try:
with suppress(ConverterError):
instance = self.type_converter(data_type)
return instance.deserialize(value, data_type=data_type, **kwargs)
except ConverterError:
pass

type_names = " | ".join(tp.__name__ for tp in types)
raise ConverterError(f"`{value}` is not a valid `{type_names}`")
Expand Down Expand Up @@ -205,11 +204,9 @@ def type_converter(self, data_type: Type) -> Converter:
Returns:
A converter instance
"""
try:
with suppress(KeyError):
# Quick in and out, without checking the whole mro.
return self.registry[data_type]
except KeyError:
pass

# We tested the first, ignore the object
for mro in data_type.__mro__[1:-1]:
Expand Down Expand Up @@ -677,7 +674,7 @@ def match(
Whether the value or values matches the enumeration member value.
"""
if isinstance(value, str) and isinstance(real, str):
return value == real or " ".join(values) == real
return real in (value, " ".join(values))

if isinstance(real, (tuple, list)) and not hasattr(real, "_fields"):
if len(real) == length and cls._match_list(values, real, **kwargs):
Expand Down Expand Up @@ -708,7 +705,7 @@ def _match_atomic(cls, raw: Any, real: Any, **kwargs: Any) -> bool:
return cmp == real


class DateTimeBase(Converter, metaclass=abc.ABCMeta):
class DateTimeBase(Converter, abc.ABC):
"""An abstract datetime converter."""

@classmethod
Expand Down
7 changes: 3 additions & 4 deletions xsdata/formats/dataclass/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
from collections import defaultdict
from contextlib import suppress
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Type

from xsdata.exceptions import XmlContextError
Expand Down Expand Up @@ -192,7 +193,7 @@ def get_field_diff(clazz: Type) -> int:
]

choices.sort(key=lambda x: (x[1], x[0].__name__))
return choices[0][0] if len(choices) > 0 else None
return choices[0][0] if choices else None

def find_subclass(self, clazz: Type, qname: str) -> Optional[Type]:
"""Find a subclass for the given clazz and xsi:type qname.
Expand Down Expand Up @@ -301,9 +302,7 @@ def is_derived(cls, obj: Any, clazz: Type) -> bool:
@classmethod
def get_subclasses(cls, clazz: Type) -> Iterator[Type]:
"""Return an iterator of the given class subclasses."""
try:
with suppress(TypeError):
for subclass in clazz.__subclasses__():
yield from cls.get_subclasses(subclass)
yield subclass
except TypeError:
pass
8 changes: 3 additions & 5 deletions xsdata/formats/dataclass/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ def class_name(self, name: str) -> str:

def class_bases(self, obj: Class, class_name: str) -> List[str]:
"""Return a list of base class names."""
bases = []
for obj_ext in obj.extensions:
bases.append(self.type_name(obj_ext.type))
bases = [self.type_name(x.type) for x in obj.extensions]

derived = len(obj.extensions) > 0
for ext in self.extensions[ExtensionType.CLASS]:
Expand Down Expand Up @@ -238,7 +236,7 @@ def class_annotations(self, obj: Class, class_name: str) -> List[str]:
def apply_substitutions(self, name: str, obj_type: ObjectType) -> str:
"""Apply name substitutions by obj type."""
for search, replace in self.substitutions[obj_type].items():
name = re.sub(rf"{search}", rf"{replace}", name)
name = re.sub(search, replace, name)

return name

Expand Down Expand Up @@ -573,7 +571,7 @@ def format_string(self, data: str, indent: int, key: str = "", pad: int = 0) ->
if key == "pattern":
# escape double quotes because double quotes surround the regex string
# in the rendered output
value = re.sub(self.UNESCAPED_DBL_QUOTE_REGEX, r'\1\\"', data)
value = self.UNESCAPED_DBL_QUOTE_REGEX.sub(r'\1\\"', data)
return f'r"{value}"'

if data == "":
Expand Down
3 changes: 1 addition & 2 deletions xsdata/formats/dataclass/models/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,7 @@ def _match_namespace(self, qname: str) -> bool:
for check in self.namespaces:
if (
(not check and uri is None)
or check == uri
or check == NamespaceType.ANY_NS
or check in (uri, NamespaceType.ANY_NS)
or (check and check[0] == "!" and check[1:] != uri)
):
return True
Expand Down
Loading

0 comments on commit 584374d

Please sign in to comment.