From 5c4da0aa3a5ee67cfdc3aed0c2c2a7ee7cc99d91 Mon Sep 17 00:00:00 2001 From: Roman Right Date: Wed, 28 Feb 2024 09:06:20 -0600 Subject: [PATCH] Sync with beanie 24 02 (#20) * the logic is synced * tests synced * version: 1.3.0 * fix tests --- bunnet/__init__.py | 9 +- bunnet/exceptions.py | 4 + bunnet/executors/migrate.py | 18 +- bunnet/migrations/controllers/base.py | 3 + bunnet/migrations/controllers/iterative.py | 4 +- bunnet/migrations/runner.py | 78 ++++-- bunnet/odm/documents.py | 245 ++++++++---------- bunnet/odm/fields.py | 36 ++- bunnet/odm/interfaces/find.py | 104 +++++--- bunnet/odm/interfaces/update.py | 6 +- bunnet/odm/operators/update/array.py | 6 +- bunnet/odm/operators/update/general.py | 10 +- bunnet/odm/queries/delete.py | 3 +- bunnet/odm/queries/find.py | 126 ++++++--- bunnet/odm/queries/update.py | 4 +- bunnet/odm/settings/document.py | 3 + bunnet/odm/settings/timeseries.py | 10 +- bunnet/odm/settings/view.py | 5 + bunnet/odm/utils/encoder.py | 21 +- bunnet/odm/utils/find.py | 96 ++++++- bunnet/odm/utils/init.py | 115 ++++---- bunnet/odm/utils/parsing.py | 54 +++- bunnet/odm/utils/pydantic.py | 6 +- bunnet/odm/utils/relations.py | 15 +- bunnet/odm/utils/state.py | 10 - bunnet/odm/utils/typing.py | 58 ++++- docs/changelog.md | 8 + pyproject.toml | 5 +- .../break/20210413211219_break.py | 4 +- tests/migrations/test_break.py | 2 +- tests/migrations/test_free_fall.py | 23 ++ tests/odm/conftest.py | 16 ++ tests/odm/documents/test_inheritance.py | 1 - tests/odm/documents/test_init.py | 46 ++++ tests/odm/documents/test_revision.py | 34 ++- tests/odm/documents/test_sync.py | 54 ++++ tests/odm/documents/test_update.py | 2 + .../odm/documents/test_validation_on_save.py | 29 ++- tests/odm/models.py | 116 +++++++-- tests/odm/query/test_aggregate.py | 83 +++++- tests/odm/query/test_delete.py | 9 +- tests/odm/query/test_find.py | 17 ++ tests/odm/test_cursor.py | 2 +- tests/odm/test_encoder.py | 17 ++ tests/odm/test_fields.py | 29 ++- tests/odm/test_relations.py | 210 ++++++++++++++- tests/odm/test_state_management.py | 36 ++- tests/odm/test_typing_utils.py | 29 ++- 48 files changed, 1395 insertions(+), 426 deletions(-) create mode 100644 tests/odm/documents/test_sync.py diff --git a/bunnet/__init__.py b/bunnet/__init__.py index c7742df..1b2f2fb 100644 --- a/bunnet/__init__.py +++ b/bunnet/__init__.py @@ -16,9 +16,11 @@ from bunnet.odm.bulk import BulkWriter from bunnet.odm.custom_types import DecimalAnnotation from bunnet.odm.custom_types.bson.binary import BsonBinary -from bunnet.odm.documents import Document +from bunnet.odm.documents import Document, MergeStrategy +from bunnet.odm.enums import SortDirection from bunnet.odm.fields import ( BackLink, + BunnetObjectId, DeleteRules, Indexed, Link, @@ -31,7 +33,7 @@ from bunnet.odm.utils.init import init_bunnet from bunnet.odm.views import View -__version__ = "1.2.0" +__version__ = "1.3.0" __all__ = [ # ODM "Document", @@ -39,9 +41,12 @@ "UnionDoc", "init_bunnet", "PydanticObjectId", + "BunnetObjectId", "Indexed", "TimeSeriesConfig", "Granularity", + "SortDirection", + "MergeStrategy", # Actions "before_event", "after_event", diff --git a/bunnet/exceptions.py b/bunnet/exceptions.py index 4e22b78..9273994 100644 --- a/bunnet/exceptions.py +++ b/bunnet/exceptions.py @@ -68,3 +68,7 @@ class DocWasNotRegisteredInUnionClass(Exception): class Deprecation(Exception): pass + + +class ApplyChangesException(Exception): + pass diff --git a/bunnet/executors/migrate.py b/bunnet/executors/migrate.py index 16940e9..bd98d78 100644 --- 
a/bunnet/executors/migrate.py +++ b/bunnet/executors/migrate.py @@ -52,6 +52,7 @@ def __init__(self, **kwargs): or self.get_from_toml("allow_index_dropping") or False ) + self.use_transaction = bool(kwargs.get("use_transaction")) @staticmethod def get_env_value(field_name) -> Any: @@ -109,7 +110,11 @@ def run_migrate(settings: MigrationSettings): mode = RunningMode( direction=settings.direction, distance=settings.distance ) - root.run(mode=mode, allow_index_dropping=settings.allow_index_dropping) + root.run( + mode=mode, + allow_index_dropping=settings.allow_index_dropping, + use_transaction=settings.use_transaction, + ) @migrations.command() @@ -157,6 +162,15 @@ def run_migrate(settings: MigrationSettings): default=False, help="if allow-index-dropping is set, Beanie will drop indexes from your collection", ) +@click.option( + "--use-transaction/--no-use-transaction", + required=False, + default=True, + help="Enable or disable the use of transactions during migration. " + "When enabled (--use-transaction), Bunnet uses transactions for migration, " + "which necessitates a replica set. When disabled (--no-use-transaction), " + "migrations occur without transactions.", +) def migrate( direction, distance, @@ -164,6 +178,7 @@ def migrate( database_name, path, allow_index_dropping, + use_transaction, ): settings_kwargs = {} if direction: @@ -178,6 +193,7 @@ def migrate( settings_kwargs["path"] = path if allow_index_dropping: settings_kwargs["allow_index_dropping"] = allow_index_dropping + settings_kwargs["use_transaction"] = use_transaction settings = MigrationSettings(**settings_kwargs) run_migrate(settings) diff --git a/bunnet/migrations/controllers/base.py b/bunnet/migrations/controllers/base.py index ed4f520..3414de8 100644 --- a/bunnet/migrations/controllers/base.py +++ b/bunnet/migrations/controllers/base.py @@ -5,6 +5,9 @@ class BaseMigrationController(ABC): + def __init__(self, function): + self.function = function + @abstractmethod def run(self, session): pass diff --git a/bunnet/migrations/controllers/iterative.py b/bunnet/migrations/controllers/iterative.py index fb318d8..f5d6ebe 100644 --- a/bunnet/migrations/controllers/iterative.py +++ b/bunnet/migrations/controllers/iterative.py @@ -91,7 +91,9 @@ def models(self) -> List[Type[Document]]: def run(self, session): output_documents = [] - for input_document in self.input_document_model.find_all(): + for input_document in self.input_document_model.find_all( + session=session + ): output = DummyOutput() function_kwargs = { "input_document": input_document, diff --git a/bunnet/migrations/runner.py b/bunnet/migrations/runner.py index 6c3a9b6..1a5b5df 100644 --- a/bunnet/migrations/runner.py +++ b/bunnet/migrations/runner.py @@ -1,7 +1,10 @@ import logging from importlib.machinery import SourceFileLoader from pathlib import Path -from typing import Optional, Type +from typing import List, Optional, Type + +from pymongo.client_session import ClientSession +from pymongo.database import Database from bunnet.migrations.controllers.iterative import BaseMigrationController from bunnet.migrations.database import DBHandler @@ -55,7 +58,12 @@ def update_current_migration(self): self.clean_current_migration() MigrationLog(is_current=True, name=self.name).insert() - def run(self, mode: RunningMode, allow_index_dropping: bool): + def run( + self, + mode: RunningMode, + allow_index_dropping: bool, + use_transaction: bool, + ): """ Migrate @@ -71,7 +79,8 @@ def run(self, mode: RunningMode, allow_index_dropping: bool): logger.info("Running migrations 
forward without limit") while True: migration_node.run_forward( - allow_index_dropping=allow_index_dropping + allow_index_dropping=allow_index_dropping, + use_transaction=use_transaction, ) migration_node = migration_node.next_migration if migration_node is None: @@ -80,7 +89,8 @@ def run(self, mode: RunningMode, allow_index_dropping: bool): logger.info(f"Running {mode.distance} migrations forward") for i in range(mode.distance): migration_node.run_forward( - allow_index_dropping=allow_index_dropping + allow_index_dropping=allow_index_dropping, + use_transaction=use_transaction, ) migration_node = migration_node.next_migration if migration_node is None: @@ -91,7 +101,8 @@ def run(self, mode: RunningMode, allow_index_dropping: bool): logger.info("Running migrations backward without limit") while True: migration_node.run_backward( - allow_index_dropping=allow_index_dropping + allow_index_dropping=allow_index_dropping, + use_transaction=use_transaction, ) migration_node = migration_node.prev_migration if migration_node is None: @@ -100,30 +111,37 @@ def run(self, mode: RunningMode, allow_index_dropping: bool): logger.info(f"Running {mode.distance} migrations backward") for i in range(mode.distance): migration_node.run_backward( - allow_index_dropping=allow_index_dropping + allow_index_dropping=allow_index_dropping, + use_transaction=use_transaction, ) migration_node = migration_node.prev_migration if migration_node is None: break - def run_forward(self, allow_index_dropping): + def run_forward(self, allow_index_dropping: bool, use_transaction: bool): if self.forward_class is not None: self.run_migration_class( - self.forward_class, allow_index_dropping=allow_index_dropping + self.forward_class, + allow_index_dropping=allow_index_dropping, + use_transaction=use_transaction, ) self.update_current_migration() - def run_backward(self, allow_index_dropping): + def run_backward(self, allow_index_dropping: bool, use_transaction: bool): if self.backward_class is not None: self.run_migration_class( - self.backward_class, allow_index_dropping=allow_index_dropping + self.backward_class, + allow_index_dropping=allow_index_dropping, + use_transaction=use_transaction, ) if self.prev_migration is not None: self.prev_migration.update_current_migration() else: self.clean_current_migration() - def run_migration_class(self, cls: Type, allow_index_dropping: bool): + def run_migration_class( + self, cls: Type, allow_index_dropping: bool, use_transaction: bool + ): """ Run Backward or Forward migration class @@ -142,19 +160,33 @@ def run_migration_class(self, cls: Type, allow_index_dropping: bool): if client is None: raise RuntimeError("client must not be None") with client.start_session() as s: - with s.start_transaction(): - for migration in migrations: - for model in migration.models: - init_bunnet( - database=db, - document_models=[model], # type: ignore - allow_index_dropping=allow_index_dropping, - ) # TODO this is slow - logger.info( - f"Running migration {migration.function.__name__} " - f"from module {self.name}" + if use_transaction: + with s.start_transaction(): + self.run_migrations( + migrations, db, allow_index_dropping, s ) - migration.run(session=s) + else: + self.run_migrations(migrations, db, allow_index_dropping, s) + + def run_migrations( + self, + migrations: List[BaseMigrationController], + db: Database, + allow_index_dropping: bool, + session: ClientSession, + ) -> None: + for migration in migrations: + for model in migration.models: + init_bunnet( + database=db, + document_models=[model], # 
type: ignore + allow_index_dropping=allow_index_dropping, + ) # TODO this is slow + logger.info( + f"Running migration {migration.function.__name__} " + f"from module {self.name}" + ) + migration.run(session=session) @classmethod def build(cls, path: Path): diff --git a/bunnet/odm/documents.py b/bunnet/odm/documents.py index 57bfe62..35bc522 100644 --- a/bunnet/odm/documents.py +++ b/bunnet/odm/documents.py @@ -1,6 +1,6 @@ +import warnings +from enum import Enum from typing import ( - TYPE_CHECKING, - AbstractSet, Any, ClassVar, Dict, @@ -8,7 +8,6 @@ List, Mapping, Optional, - Set, Type, TypeVar, Union, @@ -82,7 +81,7 @@ from bunnet.odm.queries.update import UpdateMany, UpdateResponse from bunnet.odm.settings.document import DocumentSettings from bunnet.odm.utils.dump import get_dict, get_top_level_nones -from bunnet.odm.utils.parsing import merge_models +from bunnet.odm.utils.parsing import apply_changes, merge_models from bunnet.odm.utils.pydantic import ( IS_PYDANTIC_V2, get_extra_field_info, @@ -97,26 +96,28 @@ previous_saved_state_needed, save_state_after, saved_state_needed, - swap_revision_after, ) from bunnet.odm.utils.typing import extract_id_class if IS_PYDANTIC_V2: from pydantic import model_validator -if TYPE_CHECKING: - from pydantic.typing import AbstractSetIntStr, DictStrAny, MappingIntStrAny - DocType = TypeVar("DocType", bound="Document") DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel) def json_schema_extra(schema: Dict[str, Any], model: Type["Document"]) -> None: - props = {} - for k, v in schema.get("properties", {}).items(): - if not v.get("hidden", False): - props[k] = v - schema["properties"] = props + # remove excluded fields from the json schema + properties = schema.get("properties") + if not properties: + return + for k, field in get_model_fields(model).items(): + k = field.alias or k + if k not in properties: + continue + field_info = field if IS_PYDANTIC_V2 else field.field_info + if field_info.exclude: + del properties[k] def document_alias_generator(s: str) -> str: @@ -125,6 +126,11 @@ def document_alias_generator(s: str) -> str: return s +class MergeStrategy(str, Enum): + local = "local" + remote = "remote" + + class Document( LazyModel, SettersInterface, @@ -151,34 +157,17 @@ class Document( else: class Config: - json_encoders = { - ObjectId: lambda v: str(v), - } + json_encoders = {ObjectId: str} allow_population_by_field_name = True fields = {"id": "_id"} - - @staticmethod - def schema_extra( - schema: Dict[str, Any], model: Type["Document"] - ) -> None: - props = {} - for k, v in schema.get("properties", {}).items(): - if not v.get("hidden", False): - props[k] = v - schema["properties"] = props + schema_extra = staticmethod(json_schema_extra) id: Optional[PydanticObjectId] = Field( default=None, description="MongoDB document ObjectID" ) # State - if IS_PYDANTIC_V2: - revision_id: Optional[UUID] = Field( - default=None, json_schema_extra={"hidden": True} - ) - else: - revision_id: Optional[UUID] = Field(default=None, hidden=True) # type: ignore - _previous_revision_id: Optional[UUID] = PrivateAttr(default=None) + revision_id: Optional[UUID] = Field(default=None, exclude=True) _saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None) _previous_saved_state: Optional[Dict[str, Any]] = PrivateAttr(default=None) @@ -194,15 +183,7 @@ def schema_extra( # Database _database_major_version: ClassVar[int] = 4 - # Other - _hidden_fields: ClassVar[Set[str]] = set() - - def _swap_revision(self): - if self.get_settings().use_revision: 
- self._previous_revision_id = self.revision_id - self.revision_id = uuid4() - - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(Document, self).__init__(*args, **kwargs) self.get_motor_collection() @@ -250,6 +231,8 @@ def get( ignore_cache: bool = False, fetch_links: bool = False, with_children: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> Optional["DocType"]: """ @@ -275,11 +258,47 @@ def get( ignore_cache=ignore_cache, fetch_links=fetch_links, with_children=with_children, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, **pymongo_kwargs, ) + def sync(self, merge_strategy: MergeStrategy = MergeStrategy.remote): + """ + Sync the document with the database + + :param merge_strategy: MergeStrategy - how to merge the document + :return: None + """ + if ( + merge_strategy == MergeStrategy.local + and self.get_settings().use_state_management is False + ): + raise ValueError( + "State management must be turned on to use local merge strategy" + ) + if self.id is None: + raise DocumentWasNotSaved + document = self.find_one({"_id": self.id}).run() + if document is None: + raise DocumentNotFound + + if merge_strategy == MergeStrategy.local: + original_changes = self.get_changes() + new_state = document.get_saved_state() + if new_state is None: + raise DocumentWasNotSaved + changes_to_apply = self._collect_updates( + new_state, original_changes + ) + merge_models(self, document) + apply_changes(changes_to_apply, self) + elif merge_strategy == MergeStrategy.remote: + merge_models(self, document) + else: + raise ValueError("Invalid merge strategy") + @wrap_with_actions(EventTypes.INSERT) - @swap_revision_after @save_state_after @validate_self_before def insert( @@ -414,7 +433,6 @@ def insert_many( ) @wrap_with_actions(EventTypes.REPLACE) - @swap_revision_after @save_state_after @validate_self_before def replace( @@ -478,7 +496,8 @@ def replace( find_query: Dict[str, Any] = {"_id": self.id} if use_revision_id and not ignore_revision: - find_query["revision_id"] = self._previous_revision_id + find_query["revision_id"] = self.revision_id + self.revision_id = uuid4() try: self.find_one(find_query).replace_one( self, @@ -662,10 +681,11 @@ def update( find_query = {"_id": PydanticObjectId()} if use_revision_id and not ignore_revision: - find_query["revision_id"] = self._previous_revision_id + find_query["revision_id"] = self.revision_id if use_revision_id: - arguments.append(SetRevisionId(self.revision_id)) + new_revision_id = uuid4() + arguments.append(SetRevisionId(new_revision_id)) try: result = ( self.find_one(find_query) @@ -926,7 +946,7 @@ def _save_state(self) -> None: self, to_db=True, keep_nulls=self.get_settings().keep_nulls, - exclude={"revision_id", "_previous_revision_id"}, + exclude={"revision_id"}, ) def get_saved_state(self) -> Optional[Dict[str, Any]]: @@ -950,7 +970,7 @@ def is_changed(self) -> bool: self, to_db=True, keep_nulls=self.get_settings().keep_nulls, - exclude={"revision_id", "_previous_revision_id"}, + exclude={"revision_id"}, ): return False return True @@ -1008,7 +1028,13 @@ def _collect_updates( @saved_state_needed def get_changes(self) -> Dict[str, Any]: return self._collect_updates( - self._saved_state, get_dict(self, to_db=True, keep_nulls=self.get_settings().keep_nulls) # type: ignore + self._saved_state, # type: ignore + get_dict( + self, + to_db=True, + keep_nulls=self.get_settings().keep_nulls, + 
exclude={"revision_id"}, + ), ) @saved_state_needed @@ -1018,7 +1044,8 @@ def get_previous_changes(self) -> Dict[str, Any]: return {} return self._collect_updates( - self._previous_saved_state, self._saved_state # type: ignore + self._previous_saved_state, + self._saved_state, # type: ignore ) @saved_state_needed @@ -1071,104 +1098,34 @@ def inspect_collection( return inspection_result @classmethod - def get_hidden_fields(cls): - return set( - attribute_name - for attribute_name, model_field in get_model_fields(cls).items() - if get_extra_field_info(model_field, "hidden") is True + def check_hidden_fields(cls): + hidden_fields = [ + (name, field) + for name, field in get_model_fields(cls).items() + if get_extra_field_info(field, "hidden") is True + ] + if not hidden_fields: + return + warnings.warn( + f"{cls.__name__}: 'hidden=True' is deprecated, please use 'exclude=True'", + DeprecationWarning, ) - - if IS_PYDANTIC_V2: - - def model_dump( - self, - *, - mode="python", - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - by_alias: bool = False, - exclude_hidden: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - round_trip: bool = False, - warnings: bool = True, - ) -> "DictStrAny": - """ - Overriding of the respective method from Pydantic - Hides fields, marked as "hidden - """ - if exclude_hidden: - if isinstance(exclude, AbstractSet): - exclude = {*self._hidden_fields, *exclude} - elif isinstance(exclude, Mapping): - exclude = dict( - {k: True for k in self._hidden_fields}, **exclude - ) # type: ignore - elif exclude is None: - exclude = self._hidden_fields - - kwargs = { - "include": include, - "exclude": exclude, - "by_alias": by_alias, - "exclude_unset": exclude_unset, - "exclude_defaults": exclude_defaults, - "exclude_none": exclude_none, - "round_trip": round_trip, - "warnings": warnings, - } - - return super().model_dump(**kwargs) - - else: - - def dict( - self, - *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - by_alias: bool = False, - skip_defaults: bool = False, - exclude_hidden: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - ) -> "DictStrAny": - """ - Overriding of the respective method from Pydantic - Hides fields, marked as "hidden - """ - if exclude_hidden: - if isinstance(exclude, AbstractSet): - exclude = {*self._hidden_fields, *exclude} - elif isinstance(exclude, Mapping): - exclude = dict( - {k: True for k in self._hidden_fields}, **exclude - ) # type: ignore - elif exclude is None: - exclude = self._hidden_fields - - kwargs = { - "include": include, - "exclude": exclude, - "by_alias": by_alias, - "exclude_unset": exclude_unset, - "exclude_defaults": exclude_defaults, - "exclude_none": exclude_none, - } - - # TODO: Remove this check when skip_defaults are no longer supported - if skip_defaults: - kwargs["skip_defaults"] = skip_defaults - - return super().dict(**kwargs) + if IS_PYDANTIC_V2: + for name, field in hidden_fields: + field.exclude = True + del field.json_schema_extra["hidden"] + cls.model_rebuild(force=True) + else: + for name, field in hidden_fields: + field.field_info.exclude = True + del field.field_info.extra["hidden"] + cls.__exclude_fields__[name] = True @wrap_with_actions(event_type=EventTypes.VALIDATE_ON_SAVE) def validate_self(self, *args, **kwargs): - # TODO it 
can be sync, but needs some actions controller improvements if self.get_settings().validate_on_save: - parse_model(self.__class__, get_model_dump(self)) + new_model = parse_model(self.__class__, get_model_dump(self)) + merge_models(self, new_model) def to_ref(self): if self.id is None: diff --git a/bunnet/odm/fields.py b/bunnet/odm/fields.py index f406f39..eb60275 100644 --- a/bunnet/odm/fields.py +++ b/bunnet/odm/fields.py @@ -1,5 +1,6 @@ import sys from collections import OrderedDict +from dataclasses import dataclass from enum import Enum from typing import ( TYPE_CHECKING, @@ -19,6 +20,7 @@ from typing_extensions import get_args from typing import OrderedDict as OrderedDictType +from typing import Tuple from bson import DBRef, ObjectId from bson.errors import InvalidId @@ -65,13 +67,32 @@ from bunnet.odm.documents import DocType -def Indexed(typ, index_type=ASCENDING, **kwargs): +@dataclass(frozen=True) +class IndexedAnnotation: + _indexed: Tuple[int, Dict[str, Any]] + + +def Indexed(typ=None, index_type=ASCENDING, **kwargs): """ - Returns a subclass of `typ` with an extra attribute `_indexed` as a tuple: + If `typ` is defined, returns a subclass of `typ` with an extra attribute + `_indexed` as a tuple: - Index 0: `index_type` such as `pymongo.ASCENDING` - Index 1: `kwargs` passed to `IndexModel` When instantiated the type of the result will actually be `typ`. + + When `typ` is not defined, returns an `IndexedAnnotation` instance, to be + used as metadata in `Annotated` fields. + + Example: + ```py + # Both fields would have the same behavior + class MyModel(BaseModel): + field1: Indexed(str, unique=True) + field2: Annotated[str, Indexed(unique=True)] + ``` """ + if typ is None: + return IndexedAnnotation(_indexed=(index_type, kwargs)) class NewType(typ): _indexed = (index_type, kwargs) @@ -85,6 +106,12 @@ def __new__(cls, *args, **kwargs): def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: + custom_type = getattr( + typ, "__get_pydantic_core_schema__", None + ) + if custom_type is not None: + return custom_type(_source_type, _handler) + return core_schema.no_info_after_validator_function( lambda v: v, simple_ser_schema(typ.__name__), @@ -111,7 +138,7 @@ def validate(cls, v, _: ValidationInfo): v = v.decode("utf-8") try: return PydanticObjectId(v) - except InvalidId: + except (InvalidId, TypeError): raise ValueError("Id must be of type PydanticObjectId") @classmethod @@ -163,6 +190,8 @@ def __modify_schema__(cls, field_schema): PydanticObjectId ] = str # it is a workaround to force pydantic make json schema for this field +BunnetObjectId = PydanticObjectId + class ExpressionField(str): def __getitem__(self, item): @@ -247,6 +276,7 @@ class LinkInfo(BaseModel): document_class: Type[BaseModel] # Document class link_type: LinkTypes nested_links: Optional[Dict] = None + is_fetchable: bool = True T = TypeVar("T") diff --git a/bunnet/odm/interfaces/find.py b/bunnet/odm/interfaces/find.py index 61a2214..372e6ac 100644 --- a/bunnet/odm/interfaces/find.py +++ b/bunnet/odm/interfaces/find.py @@ -26,9 +26,11 @@ from bunnet.odm.settings.base import ItemSettings if TYPE_CHECKING: - from bunnet.odm.documents import DocType + from bunnet.odm.documents import Document + from bunnet.odm.views import View DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel) +FindType = TypeVar("FindType", bound=Union["Document", "View"]) class FindInterface: @@ -54,45 +56,51 @@ def get_settings(cls) -> ItemSettings: @overload 
@classmethod def find_one( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: None = None, session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, with_children: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> FindOne["DocType"]: + ) -> FindOne[FindType]: ... @overload @classmethod def find_one( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: Type["DocumentProjectionType"], session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, with_children: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> FindOne["DocumentProjectionType"]: ... @classmethod # type: ignore def find_one( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: Optional[Type["DocumentProjectionType"]] = None, session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, with_children: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> "DocType": + ) -> Union[FindOne[FindType], FindOne["DocumentProjectionType"]]: """ Find one document by criteria. - Returns [FindOne](https://roman-right.github.io/bunnet/api/queries/#findone) query object. + Returns [FindOne](query.md#findone) query object. When awaited this will either return a document or None if no document exists for the search criteria. :param args: *Mapping[str, Any] - search criteria @@ -100,7 +108,7 @@ def find_one( # type: ignore :param session: Optional[ClientSession] - pymongo session instance :param ignore_cache: bool :param **pymongo_kwargs: pymongo native parameters for find operation (if Document class contains links, this parameter must fit the respective parameter of the aggregate MongoDB function) - :return: [FindOne](https://roman-right.github.io/bunnet/api/queries/#findone) - find query instance + :return: [FindOne](query.md#findone) - find query instance """ args = cls._add_class_id_filter(args, with_children) return cls._find_one_query_class(document_model=cls).find_one( @@ -109,13 +117,15 @@ def find_one( # type: ignore session=session, ignore_cache=ignore_cache, fetch_links=fetch_links, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, **pymongo_kwargs, ) @overload @classmethod def find_many( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: None = None, skip: Optional[int] = None, @@ -126,14 +136,16 @@ def find_many( # type: ignore fetch_links: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> FindMany["DocType"]: + ) -> FindMany[FindType]: ... 
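# ===========================================================================
# Editor's note: a minimal usage sketch for the Document.sync()/MergeStrategy
# API introduced by this patch in bunnet/odm/documents.py. The `Sample` model,
# database name, and field values are illustrative assumptions, not code from
# the patch itself.
from pymongo import MongoClient

from bunnet import Document, MergeStrategy, init_bunnet


class Sample(Document):
    count: int = 0

    class Settings:
        use_state_management = True  # required for MergeStrategy.local


init_bunnet(database=MongoClient()["db"], document_models=[Sample])

doc = Sample(count=1).insert()
doc.count = 5  # unsaved local change
# MergeStrategy.remote (the default) overwrites local state with the DB copy;
# MergeStrategy.local re-applies local changes on top of the fetched document.
doc.sync(merge_strategy=MergeStrategy.local)
assert doc.count == 5
# ===========================================================================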
@overload @classmethod def find_many( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: Optional[Type["DocumentProjectionType"]] = None, skip: Optional[int] = None, @@ -144,13 +156,15 @@ def find_many( # type: ignore fetch_links: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> FindMany["DocumentProjectionType"]: ... @classmethod def find_many( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: Optional[Type["DocumentProjectionType"]] = None, skip: Optional[int] = None, @@ -161,11 +175,13 @@ def find_many( # type: ignore fetch_links: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> Union[FindMany["DocType"], FindMany["DocumentProjectionType"]]: + ) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]: """ Find many documents by criteria. - Returns [FindMany](https://roman-right.github.io/bunnet/api/queries/#findmany) query object + Returns [FindMany](query.md#findmany) query object :param args: *Mapping[str, Any] - search criteria :param skip: Optional[int] - The number of documents to omit. @@ -176,7 +192,7 @@ def find_many( # type: ignore :param ignore_cache: bool :param lazy_parse: bool :param **pymongo_kwargs: pymongo native parameters for find operation (if Document class contains links, this parameter must fit the respective parameter of the aggregate MongoDB function) - :return: [FindMany](https://roman-right.github.io/bunnet/api/queries/#findmany) - query instance + :return: [FindMany](query.md#findmany) - query instance """ args = cls._add_class_id_filter(args, with_children) return cls._find_many_query_class(document_model=cls).find_many( @@ -189,13 +205,15 @@ def find_many( # type: ignore ignore_cache=ignore_cache, fetch_links=fetch_links, lazy_parse=lazy_parse, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, **pymongo_kwargs, ) @overload @classmethod def find( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: None = None, skip: Optional[int] = None, @@ -206,14 +224,16 @@ def find( # type: ignore fetch_links: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> FindMany["DocType"]: + ) -> FindMany[FindType]: ... @overload @classmethod def find( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: Type["DocumentProjectionType"], skip: Optional[int] = None, @@ -224,13 +244,15 @@ def find( # type: ignore fetch_links: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> FindMany["DocumentProjectionType"]: ... 
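# ===========================================================================
# Editor's note: a sketch of the new nesting_depth / nesting_depths_per_field
# parameters threaded through find_one()/find_many() above. They cap how many
# levels of linked documents get resolved when fetch_links=True. The
# Author/Book models and depth values are assumptions for illustration, and
# init_bunnet() is assumed to have been called for both models.
from typing import List

from bunnet import Document, Link


class Book(Document):
    title: str


class Author(Document):
    name: str
    books: List[Link[Book]]


# Resolve links at most two levels deep overall, but only one level for
# the `books` field specifically.
author = Author.find_one(
    Author.name == "Ann",
    fetch_links=True,
    nesting_depth=2,
    nesting_depths_per_field={"books": 1},
).run()
# ===========================================================================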
@classmethod def find( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], *args: Union[Mapping[str, Any], bool], projection_model: Optional[Type["DocumentProjectionType"]] = None, skip: Optional[int] = None, @@ -241,8 +263,10 @@ def find( # type: ignore fetch_links: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> Union[FindMany["DocType"], FindMany["DocumentProjectionType"]]: + ) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]: """ The same as find_many """ @@ -257,13 +281,15 @@ def find( # type: ignore fetch_links=fetch_links, with_children=with_children, lazy_parse=lazy_parse, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, **pymongo_kwargs, ) @overload @classmethod def find_all( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], skip: Optional[int] = None, limit: Optional[int] = None, sort: Union[None, str, List[Tuple[str, SortDirection]]] = None, @@ -272,14 +298,16 @@ def find_all( # type: ignore ignore_cache: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> FindMany["DocType"]: + ) -> FindMany[FindType]: ... @overload @classmethod def find_all( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], skip: Optional[int] = None, limit: Optional[int] = None, sort: Union[None, str, List[Tuple[str, SortDirection]]] = None, @@ -288,13 +316,15 @@ def find_all( # type: ignore ignore_cache: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> FindMany["DocumentProjectionType"]: ... 
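# ===========================================================================
# Editor's note: the fields.py change earlier in this patch lets Indexed() be
# called without a type, returning an IndexedAnnotation for use as Annotated
# metadata. Both forms below declare an index on the field, mirroring the
# example in the new Indexed() docstring; `Product` is an illustrative model.
import pymongo
from typing_extensions import Annotated

from bunnet import Document, Indexed


class Product(Document):
    sku: Indexed(str, unique=True)  # classic type-wrapper form
    description: Annotated[str, Indexed(index_type=pymongo.TEXT)]  # new form
# ===========================================================================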
@classmethod def find_all( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], skip: Optional[int] = None, limit: Optional[int] = None, sort: Union[None, str, List[Tuple[str, SortDirection]]] = None, @@ -303,8 +333,10 @@ def find_all( # type: ignore ignore_cache: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> Union[FindMany["DocType"], FindMany["DocumentProjectionType"]]: + ) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]: """ Get all the documents @@ -314,7 +346,7 @@ def find_all( # type: ignore :param projection_model: Optional[Type[BaseModel]] - projection model :param session: Optional[ClientSession] - pymongo session :param **pymongo_kwargs: pymongo native parameters for find operation (if Document class contains links, this parameter must fit the respective parameter of the aggregate MongoDB function) - :return: [FindMany](https://roman-right.github.io/bunnet/api/queries/#findmany) - query instance + :return: [FindMany](query.md#findmany) - query instance """ return cls.find_many( # type: ignore {}, @@ -326,13 +358,15 @@ def find_all( # type: ignore ignore_cache=ignore_cache, with_children=with_children, lazy_parse=lazy_parse, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, **pymongo_kwargs, ) @overload @classmethod def all( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], projection_model: None = None, skip: Optional[int] = None, limit: Optional[int] = None, @@ -341,14 +375,16 @@ def all( # type: ignore ignore_cache: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, - ) -> FindMany["DocType"]: + ) -> FindMany[FindType]: ... @overload @classmethod def all( # type: ignore - cls: Type["DocType"], + cls: Type[FindType], projection_model: Type["DocumentProjectionType"], skip: Optional[int] = None, limit: Optional[int] = None, @@ -357,13 +393,15 @@ def all( # type: ignore ignore_cache: bool = False, with_children: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> FindMany["DocumentProjectionType"]: ... 
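# ===========================================================================
# Editor's note: a sketch of driving the new use_transaction switch from
# bunnet/executors/migrate.py programmatically rather than via the CLI. The
# connection values are placeholders, and the keyword names are assumed from
# the settings_kwargs built by the migrate() command earlier in this patch.
from bunnet.executors.migrate import MigrationSettings, run_migrate

settings = MigrationSettings(
    connection_uri="mongodb://localhost:27017",
    database_name="my_db",
    path="migrations/",
    # Standalone server without a replica set: run without transactions,
    # mirroring the --no-use-transaction CLI flag.
    use_transaction=False,
)
run_migrate(settings)
# ===========================================================================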
     @classmethod
     def all(  # type: ignore
-        cls: Type["DocType"],
+        cls: Type[FindType],
         projection_model: Optional[Type["DocumentProjectionType"]] = None,
         skip: Optional[int] = None,
         limit: Optional[int] = None,
         sort: Union[None, str, List[Tuple[str, SortDirection]]] = None,
         session: Optional[ClientSession] = None,
         ignore_cache: bool = False,
         with_children: bool = False,
         lazy_parse: bool = False,
+        nesting_depth: Optional[int] = None,
+        nesting_depths_per_field: Optional[Dict[str, int]] = None,
         **pymongo_kwargs,
-    ) -> Union[FindMany["DocType"], FindMany["DocumentProjectionType"]]:
+    ) -> Union[FindMany[FindType], FindMany["DocumentProjectionType"]]:
         """
         the same as find_all
         """
@@ -386,6 +426,8 @@ def all(  # type: ignore
             ignore_cache=ignore_cache,
             with_children=with_children,
             lazy_parse=lazy_parse,
+            nesting_depth=nesting_depth,
+            nesting_depths_per_field=nesting_depths_per_field,
             **pymongo_kwargs,
         )
diff --git a/bunnet/odm/interfaces/update.py b/bunnet/odm/interfaces/update.py
index 7f5ce1c..6fb8c81 100644
--- a/bunnet/odm/interfaces/update.py
+++ b/bunnet/odm/interfaces/update.py
@@ -48,7 +48,7 @@ class Sample(Document):
 
         ```
 
-        Uses [Set operator](https://roman-right.github.io/bunnet/api/operators/update/#set)
+        Uses [Set operator](operators/update.md#set)
 
         :param expression: Dict[Union[ExpressionField, str], Any] - keys and
         values to set
@@ -70,7 +70,7 @@ def current_date(
         """
         Set current date
 
-        Uses [CurrentDate operator](https://roman-right.github.io/bunnet/api/operators/update/#currentdate)
+        Uses [CurrentDate operator](operators/update.md#currentdate)
 
         :param expression: Dict[Union[ExpressionField, str], Any]
         :param session: Optional[ClientSession] - pymongo session
@@ -105,7 +105,7 @@ class Sample(Document):
 
         ```
 
-        Uses [Inc operator](https://roman-right.github.io/bunnet/api/operators/update/#inc)
+        Uses [Inc operator](operators/update.md#inc)
 
         :param expression: Dict[Union[ExpressionField, str], Any]
         :param session: Optional[ClientSession] - pymongo session
diff --git a/bunnet/odm/operators/update/array.py b/bunnet/odm/operators/update/array.py
index 7566a36..fd1334a 100644
--- a/bunnet/odm/operators/update/array.py
+++ b/bunnet/odm/operators/update/array.py
@@ -22,7 +22,7 @@ class AddToSet(BaseUpdateArrayOperator):
     class Sample(Document):
         results: List[int]
 
-    AddToSet({Sample.results, 2})
+    AddToSet({Sample.results: 2})
     ```
 
     Will return query object like
@@ -48,7 +48,7 @@ class Pop(BaseUpdateArrayOperator):
     class Sample(Document):
         results: List[int]
 
-    Pop({Sample.results, 2})
+    Pop({Sample.results: 2})
     ```
 
     Will return query object like
@@ -74,7 +74,7 @@ class Pull(BaseUpdateArrayOperator):
     class Sample(Document):
         results: List[int]
 
-    Pull(In(Sample.result, [1,2,3,4,5])
+    Pull(In(Sample.results, [1,2,3,4,5]))
     ```
 
     Will return query object like
diff --git a/bunnet/odm/operators/update/general.py b/bunnet/odm/operators/update/general.py
index b684b31..1f684f8 100644
--- a/bunnet/odm/operators/update/general.py
+++ b/bunnet/odm/operators/update/general.py
@@ -81,7 +81,7 @@ class CurrentDate(BaseUpdateGeneralOperator):
     class Sample(Document):
         ts: datetime
 
-    CurrentDate({Sample.ts, True})
+    CurrentDate({Sample.ts: True})
     ```
 
     Will return query object like
@@ -107,7 +107,7 @@ class Inc(BaseUpdateGeneralOperator):
     class Sample(Document):
         one: int
 
-    Inc({Sample.one, 2})
+    Inc({Sample.one: 2})
     ```
 
     Will return query object like
@@ -133,7 +133,7 @@ class Min(BaseUpdateGeneralOperator):
     class Sample(Document):
         one: int
 
-    Min({Sample.one, 2})
+    Min({Sample.one: 2})
     ```
 
     Will return query object like
@@ -159,7 +159,7 @@ class Max(BaseUpdateGeneralOperator):
     class 
Sample(Document): one: int - Max({Sample.one, 2}) + Max({Sample.one: 2}) ``` Will return query object like @@ -185,7 +185,7 @@ class Mul(BaseUpdateGeneralOperator): class Sample(Document): one: int - Mul({Sample.one, 2}) + Mul({Sample.one: 2}) ``` Will return query object like diff --git a/bunnet/odm/queries/delete.py b/bunnet/odm/queries/delete.py index 7d9263b..8a8aeeb 100644 --- a/bunnet/odm/queries/delete.py +++ b/bunnet/odm/queries/delete.py @@ -10,6 +10,7 @@ from pymongo import DeleteMany as DeleteManyPyMongo from pymongo import DeleteOne as DeleteOnePyMongo +from pymongo.client_session import ClientSession from pymongo.results import DeleteResult from bunnet.odm.bulk import BulkWriter, Operation @@ -35,7 +36,7 @@ def __init__( ): self.document_model = document_model self.find_query = find_query - self.session = None + self.session: Optional[ClientSession] = None self.bulk_writer = bulk_writer self.pymongo_kwargs: Dict[str, Any] = pymongo_kwargs diff --git a/bunnet/odm/queries/find.py b/bunnet/odm/queries/find.py index af6a825..0c95359 100644 --- a/bunnet/odm/queries/find.py +++ b/bunnet/odm/queries/find.py @@ -44,7 +44,7 @@ ) from bunnet.odm.utils.dump import get_dict from bunnet.odm.utils.encoder import Encoder -from bunnet.odm.utils.find import construct_lookup_queries +from bunnet.odm.utils.find import construct_lookup_queries, split_text_query from bunnet.odm.utils.parsing import parse_obj from bunnet.odm.utils.projection import get_projection from bunnet.odm.utils.relations import convert_ids @@ -82,13 +82,15 @@ def __init__(self, document_model: Type["DocType"]): self.fetch_links: bool = False self.pymongo_kwargs: Dict[str, Any] = {} self.lazy_parse = False + self.nesting_depth: Optional[int] = None + self.nesting_depths_per_field: Optional[Dict[str, int]] = None def prepare_find_expressions(self): if self.document_model.get_link_fields() is not None: for i, query in enumerate(self.find_expressions): self.find_expressions[i] = convert_ids( query, - doc=self.document_model, + doc=self.document_model, # type: ignore fetch_links=self.fetch_links, ) @@ -187,6 +189,8 @@ def find_many( ignore_cache: bool = False, fetch_links: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> "FindMany[FindQueryResultType]": ... @@ -203,6 +207,8 @@ def find_many( ignore_cache: bool = False, fetch_links: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> "FindMany[FindQueryProjectionType]": ... @@ -218,6 +224,8 @@ def find_many( ignore_cache: bool = False, fetch_links: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> Union[ "FindMany[FindQueryResultType]", "FindMany[FindQueryProjectionType]" @@ -246,6 +254,8 @@ def find_many( self.ignore_cache = ignore_cache self.fetch_links = fetch_links self.pymongo_kwargs.update(pymongo_kwargs) + self.nesting_depth = nesting_depth + self.nesting_depths_per_field = nesting_depths_per_field if lazy_parse is True: self.lazy_parse = lazy_parse return self @@ -294,6 +304,8 @@ def find( ignore_cache: bool = False, fetch_links: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> "FindMany[FindQueryResultType]": ... 
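# ===========================================================================
# Editor's note: the operator docstrings fixed above all take dict literals
# ({field: value}), not set literals ({field, value}). A small sketch of that
# calling convention; `Sample` mirrors the docstring models and assumes
# init_bunnet() has already been called for it.
from typing import List

from bunnet import Document
from bunnet.odm.operators.update.array import AddToSet
from bunnet.odm.operators.update.general import Inc


class Sample(Document):
    one: int = 0
    results: List[int] = []


Sample.find(Sample.one < 10).update(Inc({Sample.one: 2})).run()
Sample.find_one(Sample.one == 2).update(AddToSet({Sample.results: 3})).run()
# ===========================================================================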
@@ -310,6 +322,8 @@ def find( ignore_cache: bool = False, fetch_links: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> "FindMany[FindQueryProjectionType]": ... @@ -325,6 +339,8 @@ def find( ignore_cache: bool = False, fetch_links: bool = False, lazy_parse: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> Union[ "FindMany[FindQueryResultType]", "FindMany[FindQueryProjectionType]" @@ -342,6 +358,8 @@ def find( ignore_cache=ignore_cache, fetch_links=fetch_links, lazy_parse=lazy_parse, + nesting_depth=nesting_depth, + nesting_depths_per_field=nesting_depths_per_field, **pymongo_kwargs, ) @@ -471,11 +489,11 @@ def update_many( ) -> UpdateMany: """ Provide search criteria to the - [UpdateMany](https://roman-right.github.io/bunnet/api/queries/#updatemany) query + [UpdateMany](query.md#updatemany) query :param args: *Mapping[str,Any] - the modifications to apply. :param session: Optional[ClientSession] - :return: [UpdateMany](https://roman-right.github.io/bunnet/api/queries/#updatemany) query + :return: [UpdateMany](query.md#updatemany) query """ return cast( UpdateMany, @@ -494,10 +512,10 @@ def delete_many( **pymongo_kwargs, ) -> DeleteMany: """ - Provide search criteria to the [DeleteMany](https://roman-right.github.io/bunnet/api/queries/#deletemany) query + Provide search criteria to the [DeleteMany](query.md#deletemany) query :param session: - :return: [DeleteMany](https://roman-right.github.io/bunnet/api/queries/#deletemany) query + :return: [DeleteMany](query.md#deletemany) query """ # We need to cast here to tell mypy that we are sure about the type. # This is because delete may also return a DeleteOne type in general, and mypy can not be sure in this case @@ -543,28 +561,21 @@ def aggregate( AggregationQuery[FindQueryProjectionType], ]: """ - Provide search criteria to the [AggregationQuery](https://roman-right.github.io/bunnet/api/queries/#aggregationquery) + Provide search criteria to the [AggregationQuery](query.md#aggregationquery) :param aggregation_pipeline: list - aggregation pipeline. 
MongoDB doc: :param projection_model: Type[BaseModel] - Projection Model :param session: Optional[ClientSession] - PyMongo session :param ignore_cache: bool - :return:[AggregationQuery](https://roman-right.github.io/bunnet/api/queries/#aggregationquery) + :return:[AggregationQuery](query.md#aggregationquery) """ self.set_session(session=session) - find_query = self.get_filter_query() - if self.fetch_links: - find_aggregation_pipeline = self.build_aggregation_pipeline() - aggregation_pipeline = ( - find_aggregation_pipeline + aggregation_pipeline - ) - find_query = {} return self.AggregationQueryType( - aggregation_pipeline=aggregation_pipeline, - document_model=self.document_model, + self.document_model, + self.build_aggregation_pipeline(*aggregation_pipeline), + find_query={}, projection_model=projection_model, - find_query=find_query, ignore_cache=ignore_cache, **pymongo_kwargs, ).set_session(session=self.session) @@ -602,18 +613,47 @@ def _set_cache(self, data): self._cache_key, data ) - def build_aggregation_pipeline(self): - aggregation_pipeline: List[Dict[str, Any]] = construct_lookup_queries( # type: ignore - self.document_model - ) + def build_aggregation_pipeline(self, *extra_stages): + if self.fetch_links: + aggregation_pipeline: List[ + Dict[str, Any] + ] = construct_lookup_queries( + self.document_model, + nesting_depth=self.nesting_depth, + nesting_depths_per_field=self.nesting_depths_per_field, + ) + else: + aggregation_pipeline = [] filter_query = self.get_filter_query() - if "$text" in filter_query: - text_query = filter_query["$text"] - aggregation_pipeline.insert(0, {"$match": {"$text": text_query}}) - del filter_query["$text"] - aggregation_pipeline.append({"$match": filter_query}) + if filter_query: + text_queries, non_text_queries = split_text_query(filter_query) + + if text_queries: + aggregation_pipeline.insert( + 0, + { + "$match": ( + {"$and": text_queries} + if len(text_queries) > 1 + else text_queries[0] + ) + }, + ) + + if non_text_queries: + aggregation_pipeline.append( + { + "$match": ( + {"$and": non_text_queries} + if len(non_text_queries) > 1 + else non_text_queries[0] + ) + } + ) + if extra_stages: + aggregation_pipeline.extend(extra_stages) sort_pipeline = {"$sort": {i[0]: i[1] for i in self.sort_expressions}} if sort_pipeline["$sort"]: aggregation_pipeline.append(sort_pipeline) @@ -668,14 +708,7 @@ def count(self) -> int: if self.fetch_links: aggregation_pipeline: List[ Dict[str, Any] - ] = construct_lookup_queries(self.document_model) - - aggregation_pipeline.append({"$match": self.get_filter_query()}) - - if self.skip_number != 0: - aggregation_pipeline.append({"$skip": self.skip_number}) - if self.limit_number != 0: - aggregation_pipeline.append({"$limit": self.limit_number}) + ] = self.build_aggregation_pipeline() aggregation_pipeline.append({"$count": "count"}) @@ -739,6 +772,8 @@ def find_one( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> "FindOne[FindQueryResultType]": ... @@ -751,6 +786,8 @@ def find_one( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> "FindOne[FindQueryProjectionType]": ... 
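# ===========================================================================
# Editor's note: with the reworked build_aggregation_pipeline() above, a
# $text criterion is hoisted into a leading $match stage while the remaining
# filters are applied after the link-lookup stages. A sketch under assumed
# models; `Product` and the pipeline are illustrative, and init_bunnet() plus
# a text index on `description` are presumed.
from bunnet import Document


class Product(Document):
    description: str
    price: float


avg_pipeline = [{"$group": {"_id": None, "avg_price": {"$avg": "$price"}}}]
result = (
    Product.find(
        {"$text": {"$search": "chocolate"}},
        Product.price < 10,
        fetch_links=True,
    )
    .aggregate(avg_pipeline)
    .to_list()
)
# ===========================================================================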
@@ -762,6 +799,8 @@ def find_one( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, **pymongo_kwargs, ) -> Union[ "FindOne[FindQueryResultType]", "FindOne[FindQueryProjectionType]" @@ -782,6 +821,8 @@ def find_one( self.ignore_cache = ignore_cache self.fetch_links = fetch_links self.pymongo_kwargs.update(pymongo_kwargs) + self.nesting_depth = nesting_depth + self.nesting_depths_per_field = nesting_depths_per_field return self def update( @@ -860,7 +901,7 @@ def update_one( **pymongo_kwargs, ) -> UpdateOne: """ - Create [UpdateOne](https://roman-right.github.io/bunnet/api/queries/#updateone) query using modifications and + Create [UpdateOne](query.md#updateone) query using modifications and provide search criteria there :param args: *Mapping[str,Any] - the modifications to apply :param session: Optional[ClientSession] - PyMongo sessions @@ -885,9 +926,9 @@ def delete_one( **pymongo_kwargs, ) -> DeleteOne: """ - Provide search criteria to the [DeleteOne](https://roman-right.github.io/bunnet/api/queries/#deleteone) query + Provide search criteria to the [DeleteOne](query.md#deleteone) query :param session: Optional[ClientSession] - PyMongo sessions - :return: [DeleteOne](https://roman-right.github.io/bunnet/api/queries/#deleteone) query + :return: [DeleteOne](query.md#deleteone) query """ # We need to cast here to tell mypy that we are sure about the type. # This is because delete may also return a DeleteOne type in general, and mypy can not be sure in this case @@ -935,7 +976,12 @@ def replace_one( Operation( operation=ReplaceOne, first_query=self.get_filter_query(), - second_query=Encoder(exclude={"_id"}).encode(document), + second_query=get_dict( + document, + to_db=True, + exclude={"_id"}, + keep_nulls=document.get_settings().keep_nulls, + ), object_class=self.document_model, pymongo_kwargs=self.pymongo_kwargs, ) @@ -949,6 +995,8 @@ def _find_one(self): session=self.session, fetch_links=self.fetch_links, projection_model=self.projection_model, + nesting_depth=self.nesting_depth, + nesting_depths_per_field=self.nesting_depths_per_field, **self.pymongo_kwargs, ).first_or_none() return self.document_model.get_motor_collection().find_one( diff --git a/bunnet/odm/queries/update.py b/bunnet/odm/queries/update.py index 17e5144..e678247 100644 --- a/bunnet/odm/queries/update.py +++ b/bunnet/odm/queries/update.py @@ -16,7 +16,7 @@ from pymongo import UpdateMany as UpdateManyPyMongo from pymongo import UpdateOne as UpdateOnePyMongo from pymongo.client_session import ClientSession -from pymongo.results import UpdateResult +from pymongo.results import InsertOneResult, UpdateResult from bunnet.odm.bulk import BulkWriter, Operation from bunnet.odm.interfaces.clone import CloneInterface @@ -335,7 +335,7 @@ def _update(self): def run( self, - ): + ) -> Union[UpdateResult, InsertOneResult, Optional["DocType"]]: """ Run the query :return: diff --git a/bunnet/odm/settings/document.py b/bunnet/odm/settings/document.py index da841a8..d4fd0b0 100644 --- a/bunnet/odm/settings/document.py +++ b/bunnet/odm/settings/document.py @@ -27,6 +27,9 @@ class DocumentSettings(ItemSettings): keep_nulls: bool = True + max_nesting_depths_per_field: dict = Field(default_factory=dict) + max_nesting_depth: int = 3 + if IS_PYDANTIC_V2: model_config = ConfigDict( arbitrary_types_allowed=True, diff --git a/bunnet/odm/settings/timeseries.py b/bunnet/odm/settings/timeseries.py index 
48c04a0..b630e6a 100644 --- a/bunnet/odm/settings/timeseries.py +++ b/bunnet/odm/settings/timeseries.py @@ -22,15 +22,21 @@ class TimeSeriesConfig(BaseModel): time_field: str meta_field: Optional[str] = None granularity: Optional[Granularity] = None - expire_after_seconds: Optional[float] = None + bucket_max_span_seconds: Optional[int] = None + bucket_rounding_second: Optional[int] = None + expire_after_seconds: Optional[int] = None def build_query(self, collection_name: str) -> Dict[str, Any]: res: Dict[str, Any] = {"name": collection_name} - timeseries = {"timeField": self.time_field} + timeseries: Dict[str, Any] = {"timeField": self.time_field} if self.meta_field is not None: timeseries["metaField"] = self.meta_field if self.granularity is not None: timeseries["granularity"] = self.granularity + if self.bucket_max_span_seconds is not None: + timeseries["bucketMaxSpanSeconds"] = self.bucket_max_span_seconds + if self.bucket_rounding_second is not None: + timeseries["bucketRoundingSeconds"] = self.bucket_rounding_second res["timeseries"] = timeseries if self.expire_after_seconds is not None: res["expireAfterSeconds"] = self.expire_after_seconds diff --git a/bunnet/odm/settings/view.py b/bunnet/odm/settings/view.py index c9e3d92..21dc076 100644 --- a/bunnet/odm/settings/view.py +++ b/bunnet/odm/settings/view.py @@ -1,8 +1,13 @@ from typing import Any, Dict, List, Type, Union +from pydantic import Field + from bunnet.odm.settings.base import ItemSettings class ViewSettings(ItemSettings): source: Union[str, Type] pipeline: List[Dict[str, Any]] + + max_nesting_depths_per_field: dict = Field(default_factory=dict) + max_nesting_depth: int = 3 diff --git a/bunnet/odm/utils/encoder.py b/bunnet/odm/utils/encoder.py index b6b5e7f..9fc8cfe 100644 --- a/bunnet/odm/utils/encoder.py +++ b/bunnet/odm/utils/encoder.py @@ -7,6 +7,7 @@ import pathlib import re import uuid +from enum import Enum from typing import ( Any, Callable, @@ -36,6 +37,7 @@ pathlib.PurePath: str, pydantic.SecretBytes: pydantic.SecretBytes.get_secret_value, pydantic.SecretStr: pydantic.SecretStr.get_secret_value, + datetime.date: lambda d: datetime.datetime.combine(d, datetime.time.min), datetime.timedelta: operator.methodcaller("total_seconds"), enum.Enum: operator.attrgetter("value"), Link: operator.attrgetter("ref"), @@ -58,6 +60,8 @@ bson.Binary, bson.DBRef, bson.Decimal128, + bson.MaxKey, + bson.MinKey, bson.ObjectId, ) @@ -128,21 +132,14 @@ def encode(self, obj: Any) -> Any: items = self._iter_model_items(obj) return {key: self.encode(value) for key, value in items} if isinstance(obj, Mapping): - return {key: self.encode(value) for key, value in obj.items()} + return { + key if isinstance(key, Enum) else str(key): self.encode(value) + for key, value in obj.items() + } if isinstance(obj, Iterable): return [self.encode(value) for value in obj] - errors = [] - try: - data = dict(obj) - except Exception as e: - errors.append(e) - try: - data = vars(obj) - except Exception as e: - errors.append(e) - raise ValueError(errors) - return self.encode(data) + raise ValueError(f"Cannot encode {obj!r}") def _iter_model_items( self, obj: pydantic.BaseModel diff --git a/bunnet/odm/utils/find.py b/bunnet/odm/utils/find.py index 148d589..c770eb1 100644 --- a/bunnet/odm/utils/find.py +++ b/bunnet/odm/utils/find.py @@ -1,24 +1,36 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type -from bunnet.exceptions import NotSupported from bunnet.odm.fields import LinkInfo, 
LinkTypes -from bunnet.odm.interfaces.detector import ModelType if TYPE_CHECKING: from bunnet import Document -def construct_lookup_queries(cls: Type["Document"]) -> List[Dict[str, Any]]: - if cls.get_model_type() == ModelType.UnionDoc: - raise NotSupported("UnionDoc doesn't support link fetching") +# TODO: check if this is the most efficient way for +# appending subqueries to the queries var + + +def construct_lookup_queries( + cls: Type["Document"], + nesting_depth: Optional[int] = None, + nesting_depths_per_field: Optional[Dict[str, int]] = None, +) -> List[Dict[str, Any]]: queries: List = [] link_fields = cls.get_link_fields() if link_fields is not None: for link_info in link_fields.values(): + final_nesting_depth = ( + nesting_depths_per_field.get(link_info.field_name, None) + if nesting_depths_per_field is not None + else None + ) + if final_nesting_depth is None: + final_nesting_depth = nesting_depth construct_query( link_info=link_info, queries=queries, database_major_version=cls._database_major_version, + current_depth=final_nesting_depth, ) return queries @@ -27,7 +39,12 @@ def construct_query( link_info: LinkInfo, queries: List, database_major_version: int, + current_depth: Optional[int] = None, ): + if link_info.is_fetchable is False or ( + current_depth is not None and current_depth <= 0 + ): + return if link_info.link_type in [ LinkTypes.DIRECT, LinkTypes.OPTIONAL_DIRECT, @@ -66,6 +83,9 @@ def construct_query( }, {"$unset": f"_link_{link_info.field_name}"}, ] # type: ignore + new_depth = ( + current_depth - 1 if current_depth is not None else None + ) if link_info.nested_links is not None: lookup_steps[0]["$lookup"]["pipeline"] = [] # type: ignore for nested_link in link_info.nested_links: @@ -73,6 +93,7 @@ def construct_query( link_info=link_info.nested_links[nested_link], queries=lookup_steps[0]["$lookup"]["pipeline"], # type: ignore database_major_version=database_major_version, + current_depth=new_depth, ) queries += lookup_steps @@ -118,11 +139,15 @@ def construct_query( }, {"$unset": f"_link_{link_info.field_name}"}, ] + new_depth = ( + current_depth - 1 if current_depth is not None else None + ) for nested_link in link_info.nested_links: construct_query( link_info=link_info.nested_links[nested_link], queries=lookup_steps[0]["$lookup"]["pipeline"], # type: ignore database_major_version=database_major_version, + current_depth=new_depth, ) queries += lookup_steps @@ -164,6 +189,9 @@ def construct_query( }, {"$unset": f"_link_{link_info.field_name}"}, ] # type: ignore + new_depth = ( + current_depth - 1 if current_depth is not None else None + ) if link_info.nested_links is not None: lookup_steps[0]["$lookup"]["pipeline"] = [] # type: ignore for nested_link in link_info.nested_links: @@ -171,6 +199,7 @@ def construct_query( link_info=link_info.nested_links[nested_link], queries=lookup_steps[0]["$lookup"]["pipeline"], # type: ignore database_major_version=database_major_version, + current_depth=new_depth, ) queries += lookup_steps @@ -219,11 +248,15 @@ def construct_query( }, {"$unset": f"_link_{link_info.field_name}"}, ] + new_depth = ( + current_depth - 1 if current_depth is not None else None + ) for nested_link in link_info.nested_links: construct_query( link_info=link_info.nested_links[nested_link], queries=lookup_steps[0]["$lookup"]["pipeline"], # type: ignore database_major_version=database_major_version, + current_depth=new_depth, ) queries += lookup_steps @@ -242,7 +275,9 @@ def construct_query( } } ) - + new_depth = ( + current_depth - 1 if current_depth is not 
None else None + ) if link_info.nested_links is not None: queries[-1]["$lookup"]["pipeline"] = [] for nested_link in link_info.nested_links: @@ -250,6 +285,7 @@ def construct_query( link_info=link_info.nested_links[nested_link], queries=queries[-1]["$lookup"]["pipeline"], database_major_version=database_major_version, + current_depth=new_depth, ) else: lookup_step = { @@ -262,12 +298,15 @@ def construct_query( ], } } - + new_depth = ( + current_depth - 1 if current_depth is not None else None + ) for nested_link in link_info.nested_links: construct_query( link_info=link_info.nested_links[nested_link], queries=lookup_step["$lookup"]["pipeline"], database_major_version=database_major_version, + current_depth=new_depth, ) queries.append(lookup_step) @@ -286,7 +325,9 @@ def construct_query( } } ) - + new_depth = ( + current_depth - 1 if current_depth is not None else None + ) if link_info.nested_links is not None: queries[-1]["$lookup"]["pipeline"] = [] for nested_link in link_info.nested_links: @@ -294,6 +335,7 @@ def construct_query( link_info=link_info.nested_links[nested_link], queries=queries[-1]["$lookup"]["pipeline"], database_major_version=database_major_version, + current_depth=new_depth, ) else: lookup_step = { @@ -315,13 +357,47 @@ def construct_query( ], } } - + new_depth = ( + current_depth - 1 if current_depth is not None else None + ) for nested_link in link_info.nested_links: construct_query( link_info=link_info.nested_links[nested_link], queries=lookup_step["$lookup"]["pipeline"], database_major_version=database_major_version, + current_depth=new_depth, ) queries.append(lookup_step) return queries + + +def split_text_query( + query: Dict[str, Any] +) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + """Divide query into text and non-text matches + + :param query: Dict[str, Any] - query dict + :return: Tuple[List[Dict[str, Any]], List[Dict[str, Any]]] - text and + non-text queries, respectively + """ + + root_text_query_args: Dict[str, Any] = query.get("$text", None) + root_non_text_queries: Dict[str, Any] = { + k: v for k, v in query.items() if k not in {"$text", "$and"} + } + + text_queries: List[Dict[str, Any]] = ( + [{"$text": root_text_query_args}] if root_text_query_args else [] + ) + non_text_queries: List[Dict[str, Any]] = ( + [root_non_text_queries] if root_non_text_queries else [] + ) + + for match_case in query.get("$and", []): + if "$text" in match_case: + text_queries.append(match_case) + else: + non_text_queries.append(match_case) + + return text_queries, non_text_queries diff --git a/bunnet/odm/utils/init.py b/bunnet/odm/utils/init.py index 6b3e30d..a7b38fa 100644 --- a/bunnet/odm/utils/init.py +++ b/bunnet/odm/utils/init.py @@ -5,10 +5,10 @@ from bunnet.odm.utils.pydantic import ( IS_PYDANTIC_V2, get_extra_field_info, - get_field_type, get_model_fields, parse_model, ) +from bunnet.odm.utils.typing import get_index_attributes if sys.version_info >= (3, 8): from typing import get_args, get_origin @@ -17,7 +17,6 @@ import importlib import inspect -from copy import copy from typing import ( # type: ignore List, Optional, @@ -65,6 +64,7 @@ def __init__( ] = None, allow_index_dropping: bool = False, recreate_views: bool = False, + multiprocessing_mode: bool = False, ): """ Bunnet initializer @@ -75,8 +75,12 @@ def __init__( or strings with dot separated paths :param allow_index_dropping: bool - if index dropping is allowed. Default False + :param recreate_views: bool - if views should be recreated.
Default False + :param multiprocessing_mode: bool - if multiprocessing mode is on + it will patch the motor client to use process's event loop. :return: None """ + self.inited_classes: List[Type] = [] self.allow_index_dropping = allow_index_dropping self.recreate_views = recreate_views @@ -302,6 +306,22 @@ def detect_link( ) return None + def check_nested_links(self, link_info: LinkInfo, current_depth: int): + if current_depth == 1: + return + for k, v in get_model_fields(link_info.document_class).items(): + nested_link_info = self.detect_link(v, k) + if nested_link_info is None: + continue + + if link_info.nested_links is None: + link_info.nested_links = {} + link_info.nested_links[k] = nested_link_info + new_depth = ( + current_depth - 1 if current_depth is not None else None + ) + self.check_nested_links(nested_link_info, current_depth=new_depth) + # Document @staticmethod @@ -340,27 +360,6 @@ def init_document_fields(self, cls) -> None: if not IS_PYDANTIC_V2: self.update_forward_refs(cls) - def check_nested_links( - link_info: LinkInfo, prev_models: List[Type[BaseModel]] - ): - if link_info.document_class in prev_models: - return - if not IS_PYDANTIC_V2: - self.update_forward_refs(link_info.document_class) - for k, v in get_model_fields(link_info.document_class).items(): - nested_link_info = self.detect_link(v, k) - if nested_link_info is None: - continue - - if link_info.nested_links is None: - link_info.nested_links = {} - link_info.nested_links[k] = nested_link_info - new_prev_models = copy(prev_models) - new_prev_models.append(link_info.document_class) - check_nested_links( - nested_link_info, prev_models=new_prev_models - ) - if cls._link_fields is None: cls._link_fields = {} for k, v in get_model_fields(cls).items(): @@ -368,11 +367,22 @@ def check_nested_links( setattr(cls, k, ExpressionField(path)) link_info = self.detect_link(v, k) + depth_level = cls.get_settings().max_nesting_depths_per_field.get( + k, None + ) + if depth_level is None: + depth_level = cls.get_settings().max_nesting_depth if link_info is not None: - cls._link_fields[k] = link_info - check_nested_links(link_info, prev_models=[]) + if depth_level > 0 or depth_level is None: + cls._link_fields[k] = link_info + self.check_nested_links( + link_info, current_depth=depth_level + ) + elif depth_level <= 0: + link_info.is_fetchable = False + cls._link_fields[k] = link_info - cls._hidden_fields = cls.get_hidden_fields() + cls.check_hidden_fields() @staticmethod def init_actions(cls): @@ -457,21 +467,24 @@ def init_indexes(self, cls, allow_index_dropping: bool = False): new_indexes = [] # Indexed field wrapped with Indexed() + indexed_fields = ( + (k, fvalue, get_index_attributes(fvalue)) + for k, fvalue in get_model_fields(cls).items() + ) found_indexes = [ IndexModelField( IndexModel( [ ( fvalue.alias or k, - fvalue.annotation._indexed[0], + indexed_attrs[0], ) ], - **fvalue.annotation._indexed[1], + **indexed_attrs[1], ) ) - for k, fvalue in get_model_fields(cls).items() - if hasattr(get_field_type(fvalue), "_indexed") - and get_field_type(fvalue)._indexed + for k, fvalue, indexed_attrs in indexed_fields + if indexed_attrs is not None ] if document_settings.merge_indexes: @@ -582,35 +595,26 @@ def init_view_fields(self, cls) -> None: :return: None """ - def check_nested_links( - link_info: LinkInfo, prev_models: List[Type[BaseModel]] - ): - if link_info.document_class in prev_models: - return - for k, v in get_model_fields(link_info.document_class).items(): - nested_link_info = self.detect_link(v, k) - if 
nested_link_info is None: - continue - - if link_info.nested_links is None: - link_info.nested_links = {} - link_info.nested_links[k] = nested_link_info - new_prev_models = copy(prev_models) - new_prev_models.append(link_info.document_class) - check_nested_links( - nested_link_info, prev_models=new_prev_models - ) - if cls._link_fields is None: cls._link_fields = {} for k, v in get_model_fields(cls).items(): path = v.alias or k setattr(cls, k, ExpressionField(path)) - link_info = self.detect_link(v, k) + depth_level = cls.get_settings().max_nesting_depths_per_field.get( + k, None + ) + if depth_level is None: + depth_level = cls.get_settings().max_nesting_depth if link_info is not None: - cls._link_fields[k] = link_info - check_nested_links(link_info, prev_models=[]) + if depth_level > 0: + cls._link_fields[k] = link_info + self.check_nested_links( + link_info, current_depth=depth_level + ) + elif depth_level <= 0: + link_info.is_fetchable = False + cls._link_fields[k] = link_info def init_view_collection(self, cls): """ @@ -721,6 +725,7 @@ def init_bunnet( ] = None, allow_index_dropping: bool = False, recreate_views: bool = False, + multiprocessing_mode: bool = False, ): """ Beanie initialization @@ -731,6 +736,9 @@ def init_bunnet( or strings with dot separated paths :param allow_index_dropping: bool - if index dropping is allowed. Default False + :param recreate_views: bool - if views should be recreated. Default False + :param multiprocessing_mode: bool - if multiprocessing mode is on + it will patch the motor client to use process's event loop. Default False :return: None """ @@ -740,4 +748,5 @@ def init_bunnet( document_models=document_models, allow_index_dropping=allow_index_dropping, recreate_views=recreate_views, + multiprocessing_mode=multiprocessing_mode, ).run() diff --git a/bunnet/odm/utils/parsing.py b/bunnet/odm/utils/parsing.py index 460e132..638432e 100644 --- a/bunnet/odm/utils/parsing.py +++ b/bunnet/odm/utils/parsing.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, Any, Type, Union +from typing import TYPE_CHECKING, Any, Dict, Type, Union from pydantic import BaseModel from bunnet.exceptions import ( + ApplyChangesException, DocWasNotRegisteredInUnionClass, UnionHasNoRegisteredDocs, ) @@ -22,10 +23,6 @@ def merge_models(left: BaseModel, right: BaseModel) -> None: """ from bunnet.odm.fields import Link - if hasattr(left, "_previous_revision_id") and hasattr( - right, "_previous_revision_id" - ): - left._previous_revision_id = right._previous_revision_id # type: ignore for k, right_value in right.__iter__(): left_value = getattr(left, k) if isinstance(right_value, BaseModel) and isinstance( @@ -49,11 +46,50 @@ def merge_models(left: BaseModel, right: BaseModel) -> None: left.__setattr__(k, right_value) -def save_state_swap_revision(item: BaseModel): +def apply_changes( + changes: Dict[str, Any], target: Union[BaseModel, Dict[str, Any]] +): + for key, value in changes.items(): + if "." 
in key: + key_parts = key.split(".") + current_target = target + try: + for part in key_parts[:-1]: + if isinstance(current_target, dict): + current_target = current_target[part] + elif isinstance(current_target, BaseModel): + current_target = getattr(current_target, part) + else: + raise ApplyChangesException( + f"Unexpected type of target: {type(target)}" + ) + final_key = key_parts[-1] + if isinstance(current_target, dict): + current_target[final_key] = value + elif isinstance(current_target, BaseModel): + setattr(current_target, final_key, value) + else: + raise ApplyChangesException( + f"Unexpected type of target: {type(target)}" + ) + except (KeyError, AttributeError) as e: + raise ApplyChangesException( + f"Failed to apply change for key '{key}': {e}" + ) + else: + if isinstance(target, dict): + target[key] = value + elif isinstance(target, BaseModel): + setattr(target, key, value) + else: + raise ApplyChangesException( + f"Unexpected type of target: {type(target)}" + ) + + +def save_state(item: BaseModel): if hasattr(item, "_save_state"): item._save_state() # type: ignore - if hasattr(item, "_swap_revision"): - item._swap_revision() # type: ignore def parse_obj( @@ -108,5 +144,5 @@ def parse_obj( o._saved_state = {"_id": o.id} return o result = parse_model(model, data) - save_state_swap_revision(result) + save_state(result) return result diff --git a/bunnet/odm/utils/pydantic.py b/bunnet/odm/utils/pydantic.py index f2446eb..2414e6a 100644 --- a/bunnet/odm/utils/pydantic.py +++ b/bunnet/odm/utils/pydantic.py @@ -55,8 +55,8 @@ def get_config_value(model, parameter: str): return getattr(model.Config, parameter, None) -def get_model_dump(model): +def get_model_dump(model, *args, **kwargs): if IS_PYDANTIC_V2: - return model.model_dump() + return model.model_dump(*args, **kwargs) else: - return model.dict() + return model.dict(*args, **kwargs) diff --git a/bunnet/odm/utils/relations.py b/bunnet/odm/utils/relations.py index 1d5de32..b65e0ae 100644 --- a/bunnet/odm/utils/relations.py +++ b/bunnet/odm/utils/relations.py @@ -1,4 +1,6 @@ +from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Dict +from typing import Mapping as MappingType from bunnet.odm.fields import ( ExpressionField, @@ -12,7 +14,7 @@ def convert_ids( - query: Dict[str, Any], doc: "Document", fetch_links: bool + query: MappingType[str, Any], doc: "Document", fetch_links: bool ) -> Dict[str, Any]: # TODO add all the cases new_query = {} @@ -31,9 +33,16 @@ def convert_ids( new_k = f"{k_splitted[0]}.$id" else: new_k = k - - if isinstance(v, dict): + new_v: Any + if isinstance(v, Mapping): new_v = convert_ids(v, doc, fetch_links) + elif isinstance(v, list): + new_v = [ + convert_ids(ele, doc, fetch_links) + if isinstance(ele, Mapping) + else ele + for ele in v + ] else: new_v = v diff --git a/bunnet/odm/utils/state.py b/bunnet/odm/utils/state.py index 7d9afaf..7aab502 100644 --- a/bunnet/odm/utils/state.py +++ b/bunnet/odm/utils/state.py @@ -53,13 +53,3 @@ def wrapper(self: "DocType", *args, **kwargs): return result return wrapper - - -def swap_revision_after(f: Callable): - @wraps(f) - def wrapper(self: "DocType", *args, **kwargs): - result = f(self, *args, **kwargs) - self._swap_revision() - return result - - return wrapper diff --git a/bunnet/odm/utils/typing.py b/bunnet/odm/utils/typing.py index 3aa3b8e..5bb8d4a 100644 --- a/bunnet/odm/utils/typing.py +++ b/bunnet/odm/utils/typing.py @@ -1,13 +1,16 @@ +import inspect import sys +from typing import Any, Dict, Optional, Tuple, Type + +from 
bunnet.odm.fields import IndexedAnnotation + +from .pydantic import IS_PYDANTIC_V2, get_field_type if sys.version_info >= (3, 8): from typing import get_args, get_origin else: from typing_extensions import get_args, get_origin -import inspect -from typing import Any, Type - def extract_id_class(annotation) -> Type[Any]: if get_origin(annotation) is not None: @@ -20,3 +23,52 @@ def extract_id_class(annotation) -> Type[Any]: if inspect.isclass(annotation): return annotation raise ValueError("Unknown annotation: {}".format(annotation)) + + +def get_index_attributes(field) -> Optional[Tuple[int, Dict[str, Any]]]: + """Gets the index attributes from the field, if it is indexed. + + :param field: The field to get the index attributes from. + + :return: The index attributes, if the field is indexed. Otherwise, None. + """ + # For fields that are directly typed with `Indexed()`, the type will have + # an `_indexed` attribute. + field_type = get_field_type(field) + if hasattr(field_type, "_indexed"): + return getattr(field_type, "_indexed", None) + + # For fields that use `Indexed` within `Annotated`, the field will have + # metadata that might contain an `IndexedAnnotation` instance. + if IS_PYDANTIC_V2: + # In Pydantic 2, the field has a `metadata` attribute with + # the annotations. + metadata = getattr(field, "metadata", None) + elif hasattr(field, "annotation") and hasattr( + field.annotation, "__metadata__" + ): + # In Pydantic 1, the field has an `annotation` attribute with the + # type assigned to the field. If the type is annotated, it will + # have a `__metadata__` attribute with the annotations. + metadata = field.annotation.__metadata__ + else: + return None + + if metadata is None: + return None + + try: + iter(metadata) + except TypeError: + return None + + indexed_annotation = next( + ( + annotation + for annotation in metadata + if isinstance(annotation, IndexedAnnotation) + ), + None, + ) + + return getattr(indexed_annotation, "_indexed", None) diff --git a/docs/changelog.md b/docs/changelog.md index 63d1629..f74df5b 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,14 @@ Bunnet project +## [1.3.0] - 2024-02-28 + +### Sync With Beanie | 2024.02 +- Author - [Roman Right](https://github.com/roman-right) +- PR + +[1.3.0]: https://pypi.org/project/bunnet/1.3.0 + ## [1.2.0] - 2023-10-14 ### Sync With Beanie | 2023.10 diff --git a/pyproject.toml b/pyproject.toml index 9ed0eb7..c0e1f97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "bunnet" -version = "1.2.0" +version = "1.3.0" description = "Synchronous Python ODM for MongoDB" readme = "README.md" requires-python = ">=3.7,<4.0" @@ -44,7 +44,8 @@ test = [ "httpx>=0.23.0", "fastapi>=0.100", "pydantic-settings>=2", - "pydantic-extra-types>=2" + "pydantic-extra-types>=2", + "pydantic[email]", ] doc = [ "Pygments>=2.8.0", diff --git a/tests/migrations/migrations_for_test/break/20210413211219_break.py b/tests/migrations/migrations_for_test/break/20210413211219_break.py index 501e7b2..3a2d4a3 100644 --- a/tests/migrations/migrations_for_test/break/20210413211219_break.py +++ b/tests/migrations/migrations_for_test/break/20210413211219_break.py @@ -32,8 +32,8 @@ class Forward: @iterative_migration(batch_size=2) def name_to_title(self, input_document: OldNote, output_document: Note): output_document.title = input_document.name - if output_document.title == "5": - raise Exception + if output_document.title > "5": + output_document.name = "5" class Backward:
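Note on the typing change above: with get_index_attributes in place, index metadata can be declared either by wrapping the field type in Indexed(...) or by attaching Indexed(...) as Annotated metadata, which keeps the declared type intact for static type checkers. A minimal sketch of both spellings, assuming a local MongoDB server (the Product model and its fields are invented for illustration, not part of this patch):

from uuid import UUID, uuid4

import pymongo
from pymongo import MongoClient
from typing_extensions import Annotated

from bunnet import Document, Indexed, init_bunnet


class Product(Document):  # hypothetical model, for illustration only
    # Classic spelling: Indexed() wraps the field type.
    sku: Indexed(str, unique=True)
    # Text index via the index_type argument.
    description: Indexed(str, index_type=pymongo.TEXT)
    # Annotated spelling: Indexed() rides along as metadata, so the
    # field keeps its plain UUID type for type checkers.
    serial: Annotated[UUID, Indexed(unique=True)]


client = MongoClient("mongodb://localhost:27017")  # assumed local instance
init_bunnet(database=client.example_db, document_models=[Product])
Product(sku="sku-1", description="demo", serial=uuid4()).insert()

Both forms end up as the same entries in the collection's index information; the Annotated form is the one exercised by the new tests below.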
diff --git a/tests/migrations/test_break.py b/tests/migrations/test_break.py index 3f0d599..a94de2f 100644 --- a/tests/migrations/test_break.py +++ b/tests/migrations/test_break.py @@ -52,7 +52,7 @@ def test_migration_break(settings, notes, db): with pytest.raises(Exception): run_migrate(migration_settings) - init_bunnet(database=db, document_models=[Note]) + init_bunnet(database=db, document_models=[OldNote]) inspection = OldNote.inspect_collection() assert inspection.status == InspectionStatuses.OK notes = OldNote.get_motor_collection().find().to_list(length=100) diff --git a/tests/migrations/test_free_fall.py b/tests/migrations/test_free_fall.py index 3190ae0..d573bc4 100644 --- a/tests/migrations/test_free_fall.py +++ b/tests/migrations/test_free_fall.py @@ -53,6 +53,29 @@ def test_migration_free_fall(settings, notes, db): ) run_migrate(migration_settings) + init_bunnet(database=db, document_models=[Note]) + inspection = Note.inspect_collection() + assert inspection.status == InspectionStatuses.OK + note = Note.find_one({}).run() + assert note.title == "0" + + migration_settings.direction = RunningDirections.BACKWARD + run_migrate(migration_settings) + inspection = OldNote.inspect_collection() + assert inspection.status == InspectionStatuses.OK + note = OldNote.find_one({}).run() + assert note.name == "0" + + +def test_migration_free_fall_no_use_transactions(settings, notes, db): + migration_settings = MigrationSettings( + connection_uri=settings.mongodb_dsn, + database_name=settings.mongodb_db_name, + path="tests/migrations/migrations_for_test/free_fall", + use_transaction=False, + ) + run_migrate(migration_settings) + init_bunnet(database=db, document_models=[Note]) inspection = Note.inspect_collection() assert inspection.status == InspectionStatuses.OK diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index 44de3f7..329b332 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -20,6 +20,7 @@ DocumentMultiModelTwo, DocumentTestModel, DocumentTestModelFailInspection, + DocumentTestModelIndexFlagsAnnotated, DocumentTestModelWithComplexIndex, DocumentTestModelWithCustomCollectionName, DocumentTestModelWithIndexFlags, @@ -27,23 +28,29 @@ DocumentTestModelWithLink, DocumentTestModelWithSimpleIndex, DocumentToBeLinked, + DocumentToTestSync, DocumentUnion, DocumentWithActions, DocumentWithActions2, DocumentWithBackLink, + DocumentWithBackLinkForNesting, DocumentWithBsonBinaryField, DocumentWithBsonEncodersFiledsTypes, + DocumentWithComplexDictKey, DocumentWithCustomFiledsTypes, DocumentWithCustomIdInt, DocumentWithCustomIdUUID, DocumentWithCustomInit, DocumentWithDecimalField, + DocumentWithDeprecatedHiddenField, DocumentWithExtras, DocumentWithHttpUrlField, + DocumentWithIndexedObjectId, DocumentWithIndexMerging1, DocumentWithIndexMerging2, DocumentWithKeepNullsFalse, DocumentWithLink, + DocumentWithLinkForNesting, DocumentWithList, DocumentWithListBackLink, DocumentWithListLink, @@ -71,6 +78,7 @@ LinkDocumentForTextSeacrh, Lock, LockWithRevision, + LongSelfLink, LoopedLinksA, LoopedLinksB, Nested, @@ -196,9 +204,11 @@ def init(db): DocumentTestModelWithSimpleIndex, DocumentTestModelWithIndexFlags, DocumentTestModelWithIndexFlagsAliases, + DocumentTestModelIndexFlagsAnnotated, DocumentTestModelWithComplexIndex, DocumentTestModelFailInspection, DocumentWithBsonEncodersFiledsTypes, + DocumentWithDeprecatedHiddenField, DocumentWithCustomFiledsTypes, DocumentWithCustomIdUUID, DocumentWithCustomIdInt, @@ -273,6 +283,12 @@ def init(db): DocWithCallWrapper, 
DocumentWithOptionalBackLink, DocumentWithOptionalListBackLink, + DocumentWithComplexDictKey, + DocumentWithIndexedObjectId, + DocumentToTestSync, + DocumentWithLinkForNesting, + DocumentWithBackLinkForNesting, + LongSelfLink, ] init_bunnet( database=db, diff --git a/tests/odm/documents/test_inheritance.py b/tests/odm/documents/test_inheritance.py index 0c6ee3e..3043141 100644 --- a/tests/odm/documents/test_inheritance.py +++ b/tests/odm/documents/test_inheritance.py @@ -53,7 +53,6 @@ def test_inheritance(self, db): assert isinstance(updated_bike, Bike) assert updated_bike.color == "yellow" - assert Car._parent is Vehicle assert Bus._parent is Car assert len(big_bicycles) == 1 diff --git a/tests/odm/documents/test_init.py b/tests/odm/documents/test_init.py index c5e74b5..4d4b408 100644 --- a/tests/odm/documents/test_init.py +++ b/tests/odm/documents/test_init.py @@ -7,7 +7,9 @@ from bunnet.odm.utils.init import init_bunnet from bunnet.odm.utils.projection import get_projection from tests.odm.models import ( + Color, DocumentTestModel, + DocumentTestModelIndexFlagsAnnotated, DocumentTestModelStringImport, DocumentTestModelWithComplexIndex, DocumentTestModelWithCustomCollectionName, @@ -106,6 +108,31 @@ def test_flagged_index_creation_with_alias(): } +def test_annotated_index_creation(): + collection = DocumentTestModelIndexFlagsAnnotated.get_motor_collection() + index_info = collection.index_information() + assert index_info["str_index_text"]["key"] == [ + ("_fts", "text"), + ("_ftsx", 1), + ] + assert index_info["str_index_annotated_1"] == { + "key": [("str_index_annotated", 1)], + "v": 2, + } + + assert index_info["uuid_index_annotated_1"] == { + "key": [("uuid_index_annotated", 1)], + "unique": True, + "v": 2, + } + if "uuid_index" in index_info: + assert index_info["uuid_index"] == { + "key": [("uuid_index", 1)], + "unique": True, + "v": 2, + } + + def test_complex_index_creation(): collection = DocumentTestModelWithComplexIndex.get_motor_collection() index_info = collection.index_information() @@ -124,6 +151,8 @@ def test_index_dropping_is_allowed(db): init_bunnet( database=db, document_models=[DocumentTestModelWithComplexIndex] ) + collection = DocumentTestModelWithComplexIndex.get_motor_collection() + init_bunnet( database=db, document_models=[DocumentTestModelWithDroppedIndex], @@ -275,3 +304,20 @@ def test_merge_indexes(): def test_custom_init(): assert DocumentWithCustomInit.s == "TEST2" + + +def test_index_on_custom_types(db): + class Sample1(Document): + name: Indexed(Color, unique=True) + + class Settings: + name = "sample" + + db.drop_collection("sample") + + init_bunnet( + database=db, + document_models=[Sample1], + ) + + db.drop_collection("sample") diff --git a/tests/odm/documents/test_revision.py b/tests/odm/documents/test_revision.py index 5bd33b0..7aad722 100644 --- a/tests/odm/documents/test_revision.py +++ b/tests/odm/documents/test_revision.py @@ -4,7 +4,11 @@ from bunnet import BulkWriter from bunnet.exceptions import RevisionIdWasChanged from bunnet.odm.operators.update.general import Inc -from tests.odm.models import DocumentWithRevisionTurnedOn +from tests.odm.models import ( + DocumentWithRevisionTurnedOn, + LockWithRevision, + WindowWithRevision, +) def test_replace(): @@ -22,7 +26,7 @@ def test_replace(): found_doc.num_1 += 1 found_doc.replace() - doc._previous_revision_id = "wrong" + doc.revision_id = "wrong" doc.num_1 = 4 with pytest.raises(RevisionIdWasChanged): doc.replace() @@ -43,7 +47,7 @@ def test_update(): found_doc = 
DocumentWithRevisionTurnedOn.get(doc.id).run() found_doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1})) - doc._previous_revision_id = "wrong" + doc.revision_id = "wrong" with pytest.raises(RevisionIdWasChanged): doc.update(Inc({DocumentWithRevisionTurnedOn.num_1: 1})) @@ -68,7 +72,7 @@ def test_save_changes(): found_doc.num_1 += 1 found_doc.save_changes() - doc._previous_revision_id = "wrong" + doc.revision_id = "wrong" doc.num_1 = 4 with pytest.raises(RevisionIdWasChanged): doc.save_changes() @@ -91,7 +95,7 @@ def test_save(): found_doc.num_1 += 1 found_doc.save() - doc._previous_revision_id = "wrong" + doc.revision_id = "wrong" doc.num_1 = 4 with pytest.raises(RevisionIdWasChanged): doc.save() @@ -122,7 +126,7 @@ def test_update_bulk_writer(): with BulkWriter() as bulk_writer: found_doc.save(bulk_writer=bulk_writer) - doc._previous_revision_id = "wrong" + doc.revision_id = "wrong" doc.num_1 = 4 with pytest.raises(BulkWriteError): with BulkWriter() as bulk_writer: @@ -144,11 +148,21 @@ def test_save_changes_when_there_were_no_changes(): doc = DocumentWithRevisionTurnedOn(num_1=1, num_2=2) doc.insert() revision = doc.revision_id - old_revision = doc._previous_revision_id doc.save_changes() assert doc.revision_id == revision - assert doc._previous_revision_id == old_revision - doc = DocumentWithRevisionTurnedOn.get(doc.id).run() - assert doc._previous_revision_id == old_revision + DocumentWithRevisionTurnedOn.get(doc.id).run() + assert doc.revision_id == revision + + +def test_revision_id_for_link(): + lock = LockWithRevision(k=1) + lock.insert() + + lock_rev_id = lock.revision_id + + window = WindowWithRevision(x=0, y=0, lock=lock) + + window.insert() + assert lock.revision_id == lock_rev_id diff --git a/tests/odm/documents/test_sync.py b/tests/odm/documents/test_sync.py new file mode 100644 index 0000000..d37796d --- /dev/null +++ b/tests/odm/documents/test_sync.py @@ -0,0 +1,54 @@ +import pytest + +from bunnet.exceptions import ApplyChangesException +from bunnet.odm.documents import MergeStrategy +from tests.odm.models import DocumentToTestSync + + +class TestSync: + def test_merge_remote(self): + doc = DocumentToTestSync() + doc.insert() + + doc2 = DocumentToTestSync.get(doc.id).run() + doc2.s = "foo" + + doc.i = 100 + doc.save() + + doc2.sync() + + assert doc2.s == "TEST" + assert doc2.i == 100 + + def test_merge_local(self): + doc = DocumentToTestSync(d={"option_1": {"s": "foo"}}) + doc.insert() + + doc2 = DocumentToTestSync.get(doc.id).run() + doc2.s = "foo" + doc2.n.option_1.s = "bar" + doc2.d["option_1"]["s"] = "bar" + + doc.i = 100 + doc.save() + + doc2.sync(merge_strategy=MergeStrategy.local) + + assert doc2.s == "foo" + assert doc2.n.option_1.s == "bar" + assert doc2.d["option_1"]["s"] == "bar" + + assert doc2.i == 100 + + def test_merge_local_impossible_apply_changes(self): + doc = DocumentToTestSync(d={"option_1": {"s": "foo"}}) + doc.insert() + + doc2 = DocumentToTestSync.get(doc.id).run() + doc2.d["option_1"]["s"] = {"foo": "bar"} + + doc.d = {"option_1": "nothing"} + doc.save() + with pytest.raises(ApplyChangesException): + doc2.sync(merge_strategy=MergeStrategy.local) diff --git a/tests/odm/documents/test_update.py b/tests/odm/documents/test_update.py index c28231c..3790748 100644 --- a/tests/odm/documents/test_update.py +++ b/tests/odm/documents/test_update.py @@ -63,6 +63,8 @@ def test_replace(document): new_doc = document.model_copy(update=update_data) else: new_doc = document.copy(update=update_data) + # pydantic v1 doesn't copy excluded fields + 
new_doc.test_list = document.test_list # document.test_str = "REPLACED_VALUE" new_doc.replace() new_document = DocumentTestModel.get(document.id).run() diff --git a/tests/odm/documents/test_validation_on_save.py b/tests/odm/documents/test_validation_on_save.py index 34bcd60..c6e3f23 100644 --- a/tests/odm/documents/test_validation_on_save.py +++ b/tests/odm/documents/test_validation_on_save.py @@ -1,6 +1,11 @@ +from typing import Optional + import pytest -from pydantic import ValidationError +from bson import ObjectId +from pydantic import BaseModel, ValidationError +from bunnet import PydanticObjectId +from bunnet.odm.utils.pydantic import IS_PYDANTIC_V2 from tests.odm.models import ( DocumentWithValidationOnSave, Lock, @@ -31,6 +36,28 @@ def test_validate_on_save_changes(): doc.save_changes() +def test_validate_on_save_keep_the_id_type(): + class UpdateModel(BaseModel): + num_1: Optional[int] = None + related: Optional[PydanticObjectId] = None + + doc = DocumentWithValidationOnSave(num_1=1, num_2=2) + doc.insert() + update = UpdateModel(related=PydanticObjectId()) + if IS_PYDANTIC_V2: + doc = doc.model_copy(update=update.model_dump(exclude_unset=True)) + else: + doc = doc.copy(update=update.dict(exclude_unset=True)) + doc.num_2 = 1000 + doc.save() + in_db = DocumentWithValidationOnSave.get_motor_collection().find_one( + {"_id": doc.id} + ) + assert isinstance(in_db["related"], ObjectId) + new_doc = DocumentWithValidationOnSave.get(doc.id).run() + assert isinstance(new_doc.related, PydanticObjectId) + + def test_validate_on_save_action(): doc = DocumentWithValidationOnSave(num_1=1, num_2=2) doc.insert() diff --git a/tests/odm/models.py b/tests/odm/models.py index 1ff404b..66d969f 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -24,8 +24,10 @@ import pymongo from pydantic import ( + UUID4, BaseModel, ConfigDict, + EmailStr, Field, HttpUrl, PrivateAttr, @@ -35,6 +37,7 @@ from pydantic.fields import FieldInfo from pydantic_core import core_schema from pymongo import IndexModel +from typing_extensions import Annotated from bunnet import ( DecimalAnnotation, @@ -146,13 +149,7 @@ class DocumentTestModel(Document): test_int: int test_doc: SubDocument test_str: str - - if IS_PYDANTIC_V2: - test_list: List[SubDocument] = Field( - json_schema_extra={"hidden": True} - ) - else: - test_list: List[SubDocument] = Field(hidden=True) + test_list: List[SubDocument] = Field(exclude=True) class Settings: use_cache = True @@ -199,6 +196,17 @@ class DocumentTestModelWithIndexFlagsAliases(Document): ) +class DocumentTestModelIndexFlagsAnnotated(Document): + str_index: Indexed(str, index_type=pymongo.TEXT) + str_index_annotated: Indexed(str, index_type=pymongo.ASCENDING) + uuid_index_annotated: Annotated[UUID4, Indexed(unique=True)] + + if not IS_PYDANTIC_V2: + # The UUID4 type raises a ValueError with the current + # implementation of Indexed when using Pydantic v2. 
+ uuid_index: Indexed(UUID4, unique=True) + + class DocumentTestModelWithComplexIndex(Document): test_int: int test_list: List[SubDocument] @@ -242,6 +250,13 @@ class Settings: name = "DocumentTestModel" +class DocumentWithDeprecatedHiddenField(Document): + if IS_PYDANTIC_V2: + test_hidden: List[str] = Field(json_schema_extra={"hidden": True}) + else: + test_hidden: List[str] = Field(hidden=True) + + class DocumentWithCustomIdUUID(Document): id: UUID = Field(default_factory=uuid4) name: str @@ -268,6 +283,9 @@ class DocumentWithCustomFiledsTypes(Document): tuple_type: Tuple[int, str] path: Path + class Settings: + bson_encoders = {Color: vars} + if IS_PYDANTIC_V2: model_config = ConfigDict( arbitrary_types_allowed=True, @@ -449,6 +467,7 @@ class DocumentWithTurnedOffStateManagement(Document): class DocumentWithValidationOnSave(Document): num_1: int num_2: int + related: PydanticObjectId = Field(default_factory=PydanticObjectId) @after_event(ValidateOnSave) def num_2_plus_1(self): @@ -534,10 +553,7 @@ class House(Document): roof: Optional[Link[Roof]] = None yards: Optional[List[Link[Yard]]] = None height: Indexed(int) = 2 - if IS_PYDANTIC_V2: - name: Indexed(str) = Field(json_schema_extra={"hidden": True}) - else: - name: Indexed(str) = Field(hidden=True) + name: Indexed(str) = Field(exclude=True) if IS_PYDANTIC_V2: model_config = ConfigDict( @@ -568,19 +584,6 @@ class DocumentWithStringField(Document): class DocumentForEncodingTestDate(Document): date_field: datetime.date = Field(default_factory=datetime.date.today) - class Settings: - name = "test_date" - bson_encoders = { - datetime.date: lambda dt: datetime.datetime( - year=dt.year, - month=dt.month, - day=dt.day, - hour=0, - minute=0, - second=0, - ) - } - class DocumentUnion(UnionDoc): class Settings: @@ -809,13 +812,21 @@ class SelfLinked(Document): item: Optional[Link["SelfLinked"]] = None s: str + class Settings: + max_nesting_depth = 2 + class LoopedLinksA(Document): - b: "LoopedLinksB" + b: Link["LoopedLinksB"] + s: str + + class Settings: + max_nesting_depths_per_field = {"b": 2} class LoopedLinksB(Document): - a: Optional[LoopedLinksA] = None + a: Optional[Link[LoopedLinksA]] = None + s: str class DocWithCollectionInnerClass(Document): @@ -1037,3 +1048,56 @@ def foo(self, bar: str) -> None: class DocumentWithHttpUrlField(Document): url_field: HttpUrl + + +class DocumentWithComplexDictKey(Document): + dict_field: Dict[UUID, datetime.datetime] + + +class DocumentWithIndexedObjectId(Document): + pyid: Indexed(PydanticObjectId) + uuid: Annotated[UUID4, Indexed(unique=True)] + email: Annotated[EmailStr, Indexed(unique=True)] + + +class DocumentToTestSync(Document): + s: str = "TEST" + i: int = 1 + n: Nested = Nested( + integer=1, option_1=Option1(s="test"), union=Option1(s="test") + ) + o: Optional[Option2] = None + d: Dict[str, Any] = {} + + class Settings: + use_state_management = True + + +class DocumentWithLinkForNesting(Document): + link: Link["DocumentWithBackLinkForNesting"] + s: str + + class Settings: + max_nesting_depths_per_field = {"link": 0} + + +class DocumentWithBackLinkForNesting(Document): + if IS_PYDANTIC_V2: + back_link: BackLink[DocumentWithLinkForNesting] = Field( + json_schema_extra={"original_field": "link"}, + ) + else: + back_link: BackLink[DocumentWithLinkForNesting] = Field( + original_field="link" + ) + i: int + + class Settings: + max_nesting_depths_per_field = {"back_link": 5} + + +class LongSelfLink(Document): + link: Optional[Link["LongSelfLink"]] = None + + class Settings: + max_nesting_depth = 50 
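The nesting-depth models added above (SelfLinked, LoopedLinksA/B, DocumentWithLinkForNesting, LongSelfLink) exercise the new limits on how deep linked documents are resolved when fetch_links=True: the class-level Settings cap the lookup depth, and individual queries can override it. A rough sketch under those assumptions, with a local MongoDB server and an invented Comment model (not part of this patch):

from typing import Optional

from pymongo import MongoClient

from bunnet import Document, Link, WriteRules, init_bunnet


class Comment(Document):  # hypothetical self-linked model
    text: str = ""
    reply_to: Optional[Link["Comment"]] = None

    class Settings:
        # Resolve at most two levels of reply_to when fetch_links=True;
        # anything deeper stays an unfetched Link object.
        max_nesting_depth = 2


client = MongoClient("mongodb://localhost:27017")  # assumed local instance
init_bunnet(database=client.example_db, document_models=[Comment])

root = Comment(text="root")
root.insert()
child = Comment(text="child", reply_to=root)
child.insert(link_rule=WriteRules.WRITE)

# Per-query overrides win over the Settings defaults:
Comment.find_one(
    Comment.text == "child", fetch_links=True, nesting_depth=4
).run()
Comment.find_one(
    Comment.text == "child",
    fetch_links=True,
    nesting_depths_per_field={"reply_to": 0},  # leave reply_to unfetched
).run()

Capping the depth (or setting it to 0 per field) keeps the generated $lookup pipeline finite even for self-referencing and mutually looped link graphs, which is what the looped-link tests below rely on.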
diff --git a/tests/odm/query/test_aggregate.py b/tests/odm/query/test_aggregate.py index 2af38b6..65dbde3 100644 --- a/tests/odm/query/test_aggregate.py +++ b/tests/odm/query/test_aggregate.py @@ -3,7 +3,9 @@ from pydantic.main import BaseModel from pymongo.errors import OperationFailure -from tests.odm.models import Sample +from bunnet.odm.enums import SortDirection +from bunnet.odm.utils.find import construct_lookup_queries +from tests.odm.models import DocumentWithTextIndexAndLink, Sample def test_aggregate(preset_documents): @@ -36,7 +38,37 @@ def test_aggregate_with_filter(preset_documents): assert {"_id": "test_3", "total": 3} in result -def test_aggregate_with_projection_model(preset_documents): +def test_aggregate_with_sort_skip(preset_documents): + q = Sample.find(sort="_id", skip=2).aggregate( + [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}] + ) + assert q.get_aggregation_pipeline() == [ + {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}, + {"$sort": {"_id": SortDirection.ASCENDING}}, + {"$skip": 2}, + ] + assert q.to_list() == [ + {"_id": "test_2", "total": 6}, + {"_id": "test_3", "total": 3}, + ] + + +def test_aggregate_with_sort_limit(preset_documents): + q = Sample.find(sort="_id", limit=2).aggregate( + [{"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}] + ) + assert q.get_aggregation_pipeline() == [ + {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}, + {"$sort": {"_id": SortDirection.ASCENDING}}, + {"$limit": 2}, + ] + assert q.to_list() == [ + {"_id": "test_0", "total": 0}, + {"_id": "test_1", "total": 3}, + ] + + +def test_aggregate_with_projection_model(preset_documents): class OutputItem(BaseModel): id: str = Field(None, alias="_id") total: int @@ -107,3 +139,50 @@ def test_clone(preset_documents): {"$group": {"_id": "$string", "total": {"$sum": "$integer"}}}, {"a": "b"}, ] + + +@pytest.mark.parametrize("text_query_count", [0, 1, 2]) +@pytest.mark.parametrize("non_text_query_count", [0, 1, 2]) +def test_with_text_queries(text_query_count: int, non_text_query_count: int): + text_query = {"$text": {"$search": "text_search"}} + non_text_query = {"s": "test_string"} + aggregation_pipeline = [{"$count": "count"}] + queries = [] + + if text_query_count: + queries.append(text_query) + if text_query_count > 1: + queries.append(text_query) + + if non_text_query_count: + queries.append(non_text_query) + if non_text_query_count > 1: + queries.append(non_text_query) + + query = DocumentWithTextIndexAndLink.find(*queries, fetch_links=True) + + expected_aggregation_pipeline = [] + if text_query_count: + expected_aggregation_pipeline.append( + {"$match": text_query} + if text_query_count == 1 + else {"$match": {"$and": [text_query, text_query]}} + ) + + expected_aggregation_pipeline.extend( + construct_lookup_queries(query.document_model) + ) + + if non_text_query_count: + expected_aggregation_pipeline.append( + {"$match": non_text_query} + if non_text_query_count == 1 + else {"$match": {"$and": [non_text_query, non_text_query]}} + ) + + expected_aggregation_pipeline.extend(aggregation_pipeline) + + assert ( + query.build_aggregation_pipeline(*aggregation_pipeline) + == expected_aggregation_pipeline + ) diff --git a/tests/odm/query/test_delete.py b/tests/odm/query/test_delete.py index 05779a5..f8b1748 100644 --- a/tests/odm/query/test_delete.py +++ b/tests/odm/query/test_delete.py @@ -88,8 +88,15 @@ def test_delete_many_with_session(preset_documents, session): q = ( Sample.find_many(Sample.integer > 1)
.find_many(Sample.nested.optional == None) - .set_session(session=session) + .delete(session=session) + ) # noqa + assert q.session == session + + q = ( + Sample.find_many(Sample.integer > 1) + .find_many(Sample.nested.optional == None) .delete() + .set_session(session=session) .run() ) # noqa diff --git a/tests/odm/query/test_find.py b/tests/odm/query/test_find.py index 7d8d929..4fa0deb 100644 --- a/tests/odm/query/test_find.py +++ b/tests/odm/query/test_find.py @@ -1,4 +1,5 @@ import datetime +from enum import Enum import pytest from pydantic import BaseModel @@ -398,3 +399,19 @@ def test_find_clone(): ("string", SortDirection.ASCENDING), ] assert new_q.limit_number == 10 + + +def test_find_many_with_enum_in_query(preset_documents): + class TestEnum(str, Enum): + INTEGER = Sample.integer + SAMPLE_NESTED_OPTIONAL = Sample.nested.optional + CONST = "const" + CONST_VALUE = "TEST" + + filter_query = { + TestEnum.INTEGER: {"$gt": 1}, + TestEnum.SAMPLE_NESTED_OPTIONAL: {"$type": "null"}, + TestEnum.CONST: TestEnum.CONST_VALUE, + } + result = Sample.find_many(filter_query).to_list() + assert len(result) == 2 diff --git a/tests/odm/test_cursor.py b/tests/odm/test_cursor.py index 4ed8e36..3232823 100644 --- a/tests/odm/test_cursor.py +++ b/tests/odm/test_cursor.py @@ -7,7 +7,7 @@ def test_to_list(documents): assert len(result) == 10 -def test__for(documents): +def test_for(documents): documents(10) for document in DocumentTestModel.find_all(): assert document.test_int in list(range(10)) diff --git a/tests/odm/test_encoder.py b/tests/odm/test_encoder.py index 45ab368..8cc9d0b 100644 --- a/tests/odm/test_encoder.py +++ b/tests/odm/test_encoder.py @@ -1,5 +1,6 @@ import re from datetime import date, datetime +from uuid import uuid4 import pytest from bson import Binary, Regex @@ -11,6 +12,7 @@ Child, DocumentForEncodingTest, DocumentForEncodingTestDate, + DocumentWithComplexDictKey, DocumentWithDecimalField, DocumentWithHttpUrlField, DocumentWithKeepNullsFalse, @@ -152,3 +154,18 @@ def test_should_be_able_to_save_retrieve_doc_with_url(): assert isinstance(new_doc.url_field, AnyUrl) assert new_doc.url_field == doc.url_field + + +def test_dict_with_complex_key(): + assert isinstance(Encoder().encode({uuid4(): datetime.now()}), dict) + + uuid = uuid4() + # reset microseconds, because MongoDB does not preserve them + dt = datetime.now().replace(microsecond=0) + + doc = DocumentWithComplexDictKey(dict_field={uuid: dt}) + doc.insert() + new_doc = DocumentWithComplexDictKey.get(doc.id).run() + + assert isinstance(new_doc.dict_field, dict) + assert new_doc.dict_field.get(uuid) == dt diff --git a/tests/odm/test_fields.py b/tests/odm/test_fields.py index 23838c2..fa00e66 100644 --- a/tests/odm/test_fields.py +++ b/tests/odm/test_fields.py @@ -2,6 +2,7 @@ from decimal import Decimal from pathlib import Path from typing import AbstractSet, Mapping +from uuid import uuid4 import pytest from pydantic import BaseModel, ValidationError @@ -14,8 +15,10 @@ from bunnet.odm.utils.pydantic import IS_PYDANTIC_V2 from tests.odm.models import ( DocumentTestModel, + DocumentTestModelIndexFlagsAnnotated, DocumentWithBsonEncodersFiledsTypes, DocumentWithCustomFiledsTypes, + DocumentWithDeprecatedHiddenField, Sample, ) @@ -106,7 +109,7 @@ def test_custom_filed_types(): ) -def test_hidden(document): +def test_excluded(document): document = DocumentTestModel.find_one().run() if IS_PYDANTIC_V2: assert "test_list" not in document.model_dump() @@ -114,6 +117,16 @@ def test_excluded(document): assert "test_list" not in document.dict() +def
test_hidden(): document = DocumentWithDeprecatedHiddenField(test_hidden=["abc", "def"]) document.insert() document = DocumentWithDeprecatedHiddenField.find_one().run() if IS_PYDANTIC_V2: assert "test_hidden" not in document.model_dump() else: assert "test_hidden" not in document.dict() + + def test_revision_id_not_in_schema(): """Check if there is a `revision_id` slipping into the schema.""" @@ -156,3 +169,17 @@ def test_param_exclude(document, exclude): def test_expression_fields(): assert Sample.nested.integer == "nested.integer" assert Sample.nested["integer"] == "nested.integer" + + +def test_indexed_field() -> None: + """Test that fields can be declared and instantiated with Indexed() + and Annotated[..., Indexed()].""" + + # No error should be raised if the document is properly initialized + # and `Indexed` is implemented correctly. + DocumentTestModelIndexFlagsAnnotated( + str_index="test", + str_index_annotated="test", + uuid_index=uuid4(), + uuid_index_annotated=uuid4(), + ) diff --git a/tests/odm/test_relations.py b/tests/odm/test_relations.py index ce66dee..c021361 100644 --- a/tests/odm/test_relations.py +++ b/tests/odm/test_relations.py @@ -11,13 +11,16 @@ get_model_fields, parse_model, ) +from bunnet.operators import In, Or from tests.odm.models import ( AddressView, ADocument, BDocument, DocumentToBeLinked, DocumentWithBackLink, + DocumentWithBackLinkForNesting, DocumentWithLink, + DocumentWithLinkForNesting, DocumentWithListBackLink, DocumentWithListLink, DocumentWithListOfLinks, @@ -26,6 +29,7 @@ House, LinkDocumentForTextSeacrh, Lock, + LongSelfLink, LoopedLinksA, LoopedLinksB, Region, @@ -172,12 +176,18 @@ def test_multi_insert_links(self): house.windows.append(window) house = house.insert(link_rule=WriteRules.WRITE) - new_window = Window(x=11, y=22) - house.windows.append(new_window) + new_window_1 = Window(x=11, y=22) + assert new_window_1.id is None + house.windows.append(new_window_1) + new_window_2 = Window(x=12, y=23) + assert new_window_2.id is None + house.windows.append(new_window_2) house.save(link_rule=WriteRules.WRITE) for win in house.windows: assert isinstance(win, Window) assert win.id + assert new_window_1.id is not None + assert new_window_2.id is not None def test_fetch_after_insert(self, house_not_inserted): house_not_inserted.fetch_all_links() @@ -309,7 +319,32 @@ def test_find_by_id_of_the_linked_docs(self, house): assert house_1 is not None assert house_2 is not None - def test_fetch_list_with_some_prefetched(self): + def test_find_by_id_list_of_the_linked_docs(self, houses): + items = ( + House.find(House.height < 3, fetch_links=True) + .sort(House.height) + .to_list() + ) + assert len(items) == 3 + + house_lst_1 = House.find( + Or( + House.door.id == items[0].door.id, + In(House.door.id, [items[1].door.id, items[2].door.id]), + ) + ).to_list() + house_lst_2 = House.find( + Or( + House.door.id == items[0].door.id, + In(House.door.id, [items[1].door.id, items[2].door.id]), + ), + fetch_links=True, + ).to_list() + + assert len(house_lst_1) == 3 + assert len(house_lst_2) == 3 + + def test_fetch_list_with_some_prefetched(self): docs = [] for i in range(10): doc = DocumentToBeLinked() doc.insert() docs.append(doc) @@ -351,6 +386,37 @@ def test_text_search(self): ).to_list() assert len(docs) == 1 + def test_self_nesting_find_parameters(self): + self_linked_doc = LongSelfLink() + self_linked_doc.insert(link_rule=WriteRules.WRITE) + self_linked_doc.link = self_linked_doc + self_linked_doc.save() + + self_linked_doc = LongSelfLink.find_one( + nesting_depth=4,
fetch_links=True ).run() assert self_linked_doc.link.link.link.link.id == self_linked_doc.id assert isinstance(self_linked_doc.link.link.link.link.link, Link) + + self_linked_doc = LongSelfLink.find_one( + nesting_depth=0, fetch_links=True + ).run() + assert isinstance(self_linked_doc.link, Link) + + def test_nesting_find_parameters(self): + back_link_doc = DocumentWithBackLinkForNesting(i=1) + back_link_doc.insert() + link_doc = DocumentWithLinkForNesting(link=back_link_doc, s="TEST") + link_doc.insert() + + doc = DocumentWithBackLinkForNesting.find_one( + DocumentWithBackLinkForNesting.i == 1, + fetch_links=True, + nesting_depths_per_field={"back_link": 2}, + ).run() + assert doc.back_link.link.id == doc.id + assert isinstance(doc.back_link.link.back_link, BackLink) + class TestReplace: def test_do_nothing(self, house): @@ -375,15 +441,20 @@ def test_do_nothing(self, house): def test_write(self, house): house.door.t = 100 - house.windows = [Window(x=100, y=100, lock=Lock(k=100))] + new_window = Window(x=100, y=100, lock=Lock(k=100)) + house.windows = [new_window] + assert new_window.id is None house.save(link_rule=WriteRules.WRITE) new_house = House.get(house.id, fetch_links=True).run() assert new_house.door.t == 100 for window in new_house.windows: assert window.x == 100 assert window.y == 100 + assert window.id is not None assert isinstance(window.lock, Lock) assert window.lock.k == 100 + assert window.lock.id is not None + assert new_window.id is not None class TestDelete: @@ -475,17 +546,38 @@ def test_self_linked(self): assert isinstance(res.item.item.item, Link) def test_looped_links(self): - LoopedLinksA(b=LoopedLinksB(a=LoopedLinksA(b=LoopedLinksB()))).insert( - link_rule=WriteRules.WRITE - ) - res = LoopedLinksA.find_one(fetch_links=True).run() + LoopedLinksA( + b=LoopedLinksB( + a=LoopedLinksA( + b=LoopedLinksB( + s="4", + ), + s="3", + ), + s="2", + ), + s="1", + ).insert(link_rule=WriteRules.WRITE) + res = LoopedLinksA.find_one( + LoopedLinksA.s == "1", fetch_links=True + ).run() assert isinstance(res, LoopedLinksA) assert isinstance(res.b, LoopedLinksB) assert isinstance(res.b.a, LoopedLinksA) - assert isinstance(res.b.a.b, LoopedLinksB) - assert res.b.a.b.a is None + assert isinstance(res.b.a.b, Link) - def test_with_chaining_aggregation(self): + LoopedLinksA( + b=LoopedLinksB(s="a2"), + s="a1", + ).insert(link_rule=WriteRules.WRITE) + res = LoopedLinksA.find_one( + LoopedLinksA.s == "a1", fetch_links=True + ).run() + assert isinstance(res, LoopedLinksA) + assert isinstance(res.b, LoopedLinksB) + assert res.b.a is None + + def test_with_chaining_aggregation(self): region = Region() region.insert() @@ -508,6 +600,45 @@ assert addresses_count[0] == {"count": 10} + def test_with_chaining_aggregation_and_text_search(self): + # ARRANGE + NUM_DOCS = 10 + NUM_WITH_LOWER = 5 + linked_document = LinkDocumentForTextSeacrh(i=1) + linked_document.insert() + + for i in range(NUM_DOCS): + DocumentWithTextIndexAndLink( + s="lower" if i < NUM_WITH_LOWER else "UPPER", + link=linked_document, + ).insert() + + linked_document_2 = LinkDocumentForTextSeacrh(i=2) + linked_document_2.insert() + + for i in range(NUM_DOCS): + DocumentWithTextIndexAndLink( + s="lower" if i < NUM_WITH_LOWER else "UPPER", + link=linked_document_2, + ).insert() + + # ACT + query = DocumentWithTextIndexAndLink.find( + {"$text": {"$search": "lower"}}, + DocumentWithTextIndexAndLink.link.i == 1, + fetch_links=True, + ) + + # Test both aggregation and count methods +
document_count_aggregation = query.aggregate( + [{"$count": "count"}] + ).to_list() + document_count = query.count() + + # ASSERT + assert document_count_aggregation[0] == {"count": NUM_WITH_LOWER} + assert document_count == NUM_WITH_LOWER + def test_with_extra_allow(self, houses): res = House.find(fetch_links=True).to_list() assert get_model_fields(res[0]).keys() == { @@ -569,6 +700,30 @@ def test_prefetch_list(self, list_link_and_list_backlink_doc_pair): assert back_link_doc.back_link[0].id == link_doc.id assert back_link_doc.back_link[0].link[0].id == back_link_doc.id + def test_nesting(self): + back_link_doc = DocumentWithBackLinkForNesting(i=1) + back_link_doc.insert() + link_doc = DocumentWithLinkForNesting(link=back_link_doc, s="TEST") + link_doc.insert() + + doc = DocumentWithLinkForNesting.get( + link_doc.id, fetch_links=True + ).run() + assert isinstance(doc.link, Link) + doc.link = doc.link.fetch() + assert doc.link.i == 1 + + back_link_doc = DocumentWithBackLinkForNesting.get( + back_link_doc.id, fetch_links=True + ).run() + assert ( + back_link_doc.back_link.link.back_link.link.back_link.id + == link_doc.id + ) + assert isinstance( + back_link_doc.back_link.link.back_link.link.back_link.link, Link + ) + class TestReplaceBackLinks: def test_do_nothing(self, link_and_backlink_doc_pair): @@ -757,3 +912,36 @@ def test_init_reversed_order(self, db): PersonForReversedOrderInit, ], ) + + +class TestBuildAggregations: + def test_find_aggregate_without_fetch_links(self, houses): + door = Door.find_one().run() + aggregation = House.find(House.door.id == door.id).aggregate( + [ + {"$group": {"_id": "$height", "count": {"$sum": 1}}}, + ] + ) + assert aggregation.get_aggregation_pipeline() == [ + {"$match": {"door.$id": door.id}}, + {"$group": {"_id": "$height", "count": {"$sum": 1}}}, + ] + result = aggregation.to_list() + assert result == [{"_id": 0, "count": 1}] + + def test_find_aggregate_with_fetch_links(self, houses): + door = Door.find_one().run() + aggregation = House.find( + House.door.id == door.id, fetch_links=True + ).aggregate( + [ + {"$group": {"_id": "$height", "count": {"$sum": 1}}}, + ] + ) + assert len(aggregation.get_aggregation_pipeline()) == 12 + assert aggregation.get_aggregation_pipeline()[10:] == [ + {"$match": {"door._id": door.id}}, + {"$group": {"_id": "$height", "count": {"$sum": 1}}}, + ] + result = aggregation.to_list() + assert result == [{"_id": 0, "count": 1}] diff --git a/tests/odm/test_state_management.py b/tests/odm/test_state_management.py index 59337f2..3bb7f62 100644 --- a/tests/odm/test_state_management.py +++ b/tests/odm/test_state_management.py @@ -4,7 +4,7 @@ from bunnet import PydanticObjectId, WriteRules from bunnet.exceptions import StateManagementIsTurnedOff, StateNotSaved from bunnet.odm.utils.parsing import parse_obj -from bunnet.odm.utils.pydantic import IS_PYDANTIC_V2 +from bunnet.odm.utils.pydantic import IS_PYDANTIC_V2, parse_model from tests.odm.models import ( DocumentWithTurnedOffStateManagement, DocumentWithTurnedOnReplaceObjects, @@ -407,8 +407,8 @@ def test_find_many(self): assert doc.get_previous_saved_state() is None def test_insert(self, state_without_id): - doc = DocumentWithTurnedOnStateManagement.parse_obj( - state_without_id + doc = parse_model( + DocumentWithTurnedOnStateManagement, state_without_id ) assert doc.get_saved_state() is None doc.insert() @@ -430,3 +430,33 @@ def test_replace_save_previous(self, saved_doc_previous): assert saved_doc_previous.get_saved_state()["num_1"] == 100 assert 
saved_doc_previous.get_previous_saved_state()["num_1"] == 1 + + def test_exclude_revision_id(self, saved_doc_previous): + saved_doc_previous.num_1 = 100 + saved_doc_previous.replace() + + assert saved_doc_previous.get_saved_state()["num_1"] == 100 + assert saved_doc_previous.get_previous_saved_state()["num_1"] == 1 + + assert ( + saved_doc_previous.get_saved_state().get("revision_id") is None + ) + assert ( + saved_doc_previous.get_saved_state().get( + "previous_revision_id" + ) + is None + ) + + assert ( + saved_doc_previous.get_previous_saved_state().get( + "revision_id" + ) + is None + ) + assert ( + saved_doc_previous.get_previous_saved_state().get( + "previous_revision_id" + ) + is None + ) diff --git a/tests/odm/test_typing_utils.py b/tests/odm/test_typing_utils.py index 7e78576..6175dbc 100644 --- a/tests/odm/test_typing_utils.py +++ b/tests/odm/test_typing_utils.py @@ -1,7 +1,13 @@ from typing import Optional, Union +import pytest +from pydantic import BaseModel +from typing_extensions import Annotated + from bunnet import Document, Link -from bunnet.odm.utils.typing import extract_id_class +from bunnet.odm.fields import Indexed +from bunnet.odm.utils.pydantic import get_model_fields +from bunnet.odm.utils.typing import extract_id_class, get_index_attributes class Lock(Document): @@ -18,3 +24,24 @@ def test_extract_id_class(self): assert extract_id_class(Optional[str]) == str # Link assert extract_id_class(Link[Lock]) == Lock + + @pytest.mark.parametrize( + "type,result", + ( + (str, None), + (Indexed(str), (1, {})), + (Indexed(str, "text", unique=True), ("text", {"unique": True})), + (Annotated[str, Indexed()], (1, {})), + ( + Annotated[str, "other metadata", Indexed(unique=True)], + (1, {"unique": True}), + ), + (Annotated[str, "other metadata"], None), + ), + ) + def test_get_index_attributes(self, type, result): + class Foo(BaseModel): + bar: type + + field = get_model_fields(Foo)["bar"] + assert get_index_attributes(field) == result
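Closing note: the Document.sync() behavior covered by tests/odm/documents/test_sync.py above is driven by the new MergeStrategy enum. A rough usage sketch under stated assumptions, with a local MongoDB server and an invented Counter model; state management is assumed to be required for the local strategy, as in DocumentToTestSync:

from pymongo import MongoClient

from bunnet import Document, MergeStrategy, init_bunnet


class Counter(Document):  # hypothetical model, for illustration only
    label: str = "initial"
    value: int = 0

    class Settings:
        use_state_management = True  # mirrors DocumentToTestSync above


client = MongoClient("mongodb://localhost:27017")  # assumed local instance
init_bunnet(database=client.example_db, document_models=[Counter])

doc = Counter()
doc.insert()

stale = Counter.get(doc.id).run()
stale.label = "edited locally"  # local, unsaved change

doc.value = 100
doc.save()  # a concurrent change lands in the database

# MergeStrategy.remote (the default) would drop the local edit;
# MergeStrategy.local keeps locally changed fields and refreshes the rest.
stale.sync(merge_strategy=MergeStrategy.local)
assert stale.label == "edited locally"
assert stale.value == 100

As the test_merge_local_impossible_apply_changes test shows, the local strategy raises ApplyChangesException when the local edits can no longer be applied to the refreshed remote state.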