Skip to content

Commit

Permalink
Sync with beanie 24 02 (#20)
Browse files Browse the repository at this point in the history
* the logic is synced

* tests synced

* version: 1.3.0

* fix tests
  • Loading branch information
roman-right authored Feb 28, 2024
1 parent a7acb0e commit 5c4da0a
Show file tree
Hide file tree
Showing 48 changed files with 1,395 additions and 426 deletions.
9 changes: 7 additions & 2 deletions bunnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
from bunnet.odm.bulk import BulkWriter
from bunnet.odm.custom_types import DecimalAnnotation
from bunnet.odm.custom_types.bson.binary import BsonBinary
from bunnet.odm.documents import Document
from bunnet.odm.documents import Document, MergeStrategy
from bunnet.odm.enums import SortDirection
from bunnet.odm.fields import (
BackLink,
BunnetObjectId,
DeleteRules,
Indexed,
Link,
Expand All @@ -31,17 +33,20 @@
from bunnet.odm.utils.init import init_bunnet
from bunnet.odm.views import View

__version__ = "1.2.0"
__version__ = "1.3.0"
__all__ = [
# ODM
"Document",
"View",
"UnionDoc",
"init_bunnet",
"PydanticObjectId",
"BunnetObjectId",
"Indexed",
"TimeSeriesConfig",
"Granularity",
"SortDirection",
"MergeStrategy",
# Actions
"before_event",
"after_event",
Expand Down
4 changes: 4 additions & 0 deletions bunnet/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,7 @@ class DocWasNotRegisteredInUnionClass(Exception):

class Deprecation(Exception):
pass


class ApplyChangesException(Exception):
pass
18 changes: 17 additions & 1 deletion bunnet/executors/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, **kwargs):
or self.get_from_toml("allow_index_dropping")
or False
)
self.use_transaction = bool(kwargs.get("use_transaction"))

@staticmethod
def get_env_value(field_name) -> Any:
Expand Down Expand Up @@ -109,7 +110,11 @@ def run_migrate(settings: MigrationSettings):
mode = RunningMode(
direction=settings.direction, distance=settings.distance
)
root.run(mode=mode, allow_index_dropping=settings.allow_index_dropping)
root.run(
mode=mode,
allow_index_dropping=settings.allow_index_dropping,
use_transaction=settings.use_transaction,
)


@migrations.command()
Expand Down Expand Up @@ -157,13 +162,23 @@ def run_migrate(settings: MigrationSettings):
default=False,
help="if allow-index-dropping is set, Beanie will drop indexes from your collection",
)
@click.option(
"--use-transaction/--no-use-transaction",
required=False,
default=True,
help="Enable or disable the use of transactions during migration. "
"When enabled (--use-transaction), Bunnet uses transactions for migration, "
"which necessitates a replica set. When disabled (--no-use-transaction), "
"migrations occur without transactions.",
)
def migrate(
direction,
distance,
connection_uri,
database_name,
path,
allow_index_dropping,
use_transaction,
):
settings_kwargs = {}
if direction:
Expand All @@ -178,6 +193,7 @@ def migrate(
settings_kwargs["path"] = path
if allow_index_dropping:
settings_kwargs["allow_index_dropping"] = allow_index_dropping
settings_kwargs["use_transaction"] = use_transaction
settings = MigrationSettings(**settings_kwargs)

run_migrate(settings)
Expand Down
3 changes: 3 additions & 0 deletions bunnet/migrations/controllers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@


class BaseMigrationController(ABC):
def __init__(self, function):
self.function = function

@abstractmethod
def run(self, session):
pass
Expand Down
4 changes: 3 additions & 1 deletion bunnet/migrations/controllers/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ def models(self) -> List[Type[Document]]:

def run(self, session):
output_documents = []
for input_document in self.input_document_model.find_all():
for input_document in self.input_document_model.find_all(
session=session
):
output = DummyOutput()
function_kwargs = {
"input_document": input_document,
Expand Down
78 changes: 55 additions & 23 deletions bunnet/migrations/runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import logging
from importlib.machinery import SourceFileLoader
from pathlib import Path
from typing import Optional, Type
from typing import List, Optional, Type

from pymongo.client_session import ClientSession
from pymongo.database import Database

from bunnet.migrations.controllers.iterative import BaseMigrationController
from bunnet.migrations.database import DBHandler
Expand Down Expand Up @@ -55,7 +58,12 @@ def update_current_migration(self):
self.clean_current_migration()
MigrationLog(is_current=True, name=self.name).insert()

def run(self, mode: RunningMode, allow_index_dropping: bool):
def run(
self,
mode: RunningMode,
allow_index_dropping: bool,
use_transaction: bool,
):
"""
Migrate
Expand All @@ -71,7 +79,8 @@ def run(self, mode: RunningMode, allow_index_dropping: bool):
logger.info("Running migrations forward without limit")
while True:
migration_node.run_forward(
allow_index_dropping=allow_index_dropping
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
migration_node = migration_node.next_migration
if migration_node is None:
Expand All @@ -80,7 +89,8 @@ def run(self, mode: RunningMode, allow_index_dropping: bool):
logger.info(f"Running {mode.distance} migrations forward")
for i in range(mode.distance):
migration_node.run_forward(
allow_index_dropping=allow_index_dropping
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
migration_node = migration_node.next_migration
if migration_node is None:
Expand All @@ -91,7 +101,8 @@ def run(self, mode: RunningMode, allow_index_dropping: bool):
logger.info("Running migrations backward without limit")
while True:
migration_node.run_backward(
allow_index_dropping=allow_index_dropping
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
migration_node = migration_node.prev_migration
if migration_node is None:
Expand All @@ -100,30 +111,37 @@ def run(self, mode: RunningMode, allow_index_dropping: bool):
logger.info(f"Running {mode.distance} migrations backward")
for i in range(mode.distance):
migration_node.run_backward(
allow_index_dropping=allow_index_dropping
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
migration_node = migration_node.prev_migration
if migration_node is None:
break

def run_forward(self, allow_index_dropping):
def run_forward(self, allow_index_dropping: bool, use_transaction: bool):
if self.forward_class is not None:
self.run_migration_class(
self.forward_class, allow_index_dropping=allow_index_dropping
self.forward_class,
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
self.update_current_migration()

def run_backward(self, allow_index_dropping):
def run_backward(self, allow_index_dropping: bool, use_transaction: bool):
if self.backward_class is not None:
self.run_migration_class(
self.backward_class, allow_index_dropping=allow_index_dropping
self.backward_class,
allow_index_dropping=allow_index_dropping,
use_transaction=use_transaction,
)
if self.prev_migration is not None:
self.prev_migration.update_current_migration()
else:
self.clean_current_migration()

def run_migration_class(self, cls: Type, allow_index_dropping: bool):
def run_migration_class(
self, cls: Type, allow_index_dropping: bool, use_transaction: bool
):
"""
Run Backward or Forward migration class
Expand All @@ -142,19 +160,33 @@ def run_migration_class(self, cls: Type, allow_index_dropping: bool):
if client is None:
raise RuntimeError("client must not be None")
with client.start_session() as s:
with s.start_transaction():
for migration in migrations:
for model in migration.models:
init_bunnet(
database=db,
document_models=[model], # type: ignore
allow_index_dropping=allow_index_dropping,
) # TODO this is slow
logger.info(
f"Running migration {migration.function.__name__} "
f"from module {self.name}"
if use_transaction:
with s.start_transaction():
self.run_migrations(
migrations, db, allow_index_dropping, s
)
migration.run(session=s)
else:
self.run_migrations(migrations, db, allow_index_dropping, s)

def run_migrations(
self,
migrations: List[BaseMigrationController],
db: Database,
allow_index_dropping: bool,
session: ClientSession,
) -> None:
for migration in migrations:
for model in migration.models:
init_bunnet(
database=db,
document_models=[model], # type: ignore
allow_index_dropping=allow_index_dropping,
) # TODO this is slow
logger.info(
f"Running migration {migration.function.__name__} "
f"from module {self.name}"
)
migration.run(session=session)

@classmethod
def build(cls, path: Path):
Expand Down
Loading

0 comments on commit 5c4da0a

Please sign in to comment.