diff --git a/litestar/dto/_backend.py b/litestar/dto/_backend.py index f98b4607c6..c3de87e9e1 100644 --- a/litestar/dto/_backend.py +++ b/litestar/dto/_backend.py @@ -3,7 +3,6 @@ """ from __future__ import annotations -import secrets from dataclasses import replace from typing import TYPE_CHECKING, AbstractSet, Any, Callable, ClassVar, Collection, Final, Mapping, Union, cast @@ -27,6 +26,7 @@ from litestar.serialization import decode_json, decode_msgpack from litestar.types import Empty from litestar.typing import FieldDefinition +from litestar.utils import unique_name_for_scope from litestar.utils.typing import safe_generic_origin_map if TYPE_CHECKING: @@ -154,6 +154,18 @@ def parse_model( defined_fields.append(transfer_field_definition) return tuple(defined_fields) + def _create_transfer_model_name(self, model_name: str) -> str: + long_name_prefix = self.handler_id.split("::")[0] + short_name_prefix = _camelize(long_name_prefix.split(".")[-1], True) + + name_suffix = "RequestBody" if self.is_data_field else "ResponseBody" + + if (short_name := f"{short_name_prefix}{model_name}{name_suffix}") not in self._seen_model_names: + return short_name + if (long_name := f"{long_name_prefix}{model_name}{name_suffix}") not in self._seen_model_names: + return long_name + return unique_name_for_scope(long_name, self._seen_model_names) + def create_transfer_model_type( self, model_name: str, field_definitions: tuple[TransferDTOFieldDefinition, ...] ) -> type[Struct]: @@ -166,19 +178,9 @@ def create_transfer_model_type( Returns: A ``BackendT`` class. """ - long_name_prefix = self.handler_id.split("::")[0] - short_name_prefix = _camelize(long_name_prefix.split(".")[-1], True) - - name_suffix = "RequestBody" if self.is_data_field else "ResponseBody" - - if f"{short_name_prefix}{model_name}{name_suffix}" not in self._seen_model_names: - struct_name = f"{short_name_prefix}{model_name}{name_suffix}" - elif f"{long_name_prefix}{model_name}{name_suffix}" not in self._seen_model_names: - struct_name = f"{long_name_prefix}{model_name}{name_suffix}" - else: - struct_name = f"{long_name_prefix}{model_name}{name_suffix}{secrets.token_hex(8)}" - + struct_name = self._create_transfer_model_name(model_name) self._seen_model_names.add(struct_name) + struct = _create_struct_for_field_definitions(struct_name, field_definitions) setattr(struct, "__schema_name__", struct_name) return struct diff --git a/litestar/utils/__init__.py b/litestar/utils/__init__.py index 3714ce8bfe..a59f56268f 100644 --- a/litestar/utils/__init__.py +++ b/litestar/utils/__init__.py @@ -1,6 +1,6 @@ from litestar.utils.deprecation import deprecated, warn_deprecation -from .helpers import Ref, get_enum_string_value, get_name, url_quote +from .helpers import Ref, get_enum_string_value, get_name, unique_name_for_scope, url_quote from .path import join_paths, normalize_path from .predicates import ( is_annotated_type, @@ -75,6 +75,7 @@ "normalize_path", "set_litestar_scope_state", "unique", + "unique_name_for_scope", "url_quote", "warn_deprecation", ) diff --git a/litestar/utils/helpers.py b/litestar/utils/helpers.py index d26602fe0d..a917051e3b 100644 --- a/litestar/utils/helpers.py +++ b/litestar/utils/helpers.py @@ -6,6 +6,7 @@ from urllib.parse import quote if TYPE_CHECKING: + from collections.abc import Container from typing import Iterable from litestar.datastructures import Cookie @@ -18,6 +19,7 @@ "get_name", "unwrap_partial", "url_quote", + "unique_name_for_scope", ) T = TypeVar("T") @@ -99,3 +101,12 @@ def url_quote(value: str | bytes) -> str: A quoted URL. """ return quote(value, safe="/#%[]=:;$&()+,!?*@'~") + + +def unique_name_for_scope(base_name: str, scope: Container[str]) -> str: + """Create a name derived from ``base_name`` that's unique within ``scope``""" + i = 0 + while True: + if (unique_name := f"{base_name}_{i}") not in scope: + return unique_name + i += 1 diff --git a/tests/unit/test_utils/test_helpers.py b/tests/unit/test_utils/test_helpers.py index f636d4b033..c8995aaf2f 100644 --- a/tests/unit/test_utils/test_helpers.py +++ b/tests/unit/test_utils/test_helpers.py @@ -1,6 +1,6 @@ from functools import partial -from litestar.utils.helpers import unwrap_partial +from litestar.utils.helpers import unique_name_for_scope, unwrap_partial def test_unwrap_partial() -> None: @@ -11,3 +11,11 @@ def func(*args: int) -> int: assert wrapped() == 3 assert unwrap_partial(wrapped) is func + + +def test_unique_name_for_scope() -> None: + assert unique_name_for_scope("a", []) == "a_0" + + assert unique_name_for_scope("a", ["a", "a_0", "b"]) == "a_1" + + assert unique_name_for_scope("b", ["a", "a_0", "b"]) == "b_0"