diff --git a/internal/infrastructure/background_task/celery/task/di.py b/internal/infrastructure/background_task/celery/task/di.py index 2f0b54f0..41d30aae 100644 --- a/internal/infrastructure/background_task/celery/task/di.py +++ b/internal/infrastructure/background_task/celery/task/di.py @@ -5,7 +5,7 @@ from internal.repository.flat import FileRepository from internal.repository.relational.file import DatasetRepository from internal.repository.relational.task import TaskRepository -from internal.uow import UnitOfWork +from internal.infrastructure.data_storage.uow import UnitOfWork from internal.usecase.task.profile_task import ProfileTask from internal.usecase.task.update_task_info import UpdateTaskInfo @@ -23,9 +23,13 @@ def get_task_repo() -> TaskRepository: def get_update_task_info_use_case(): - context_maker = get_postgres_context_maker_without_pool() + postgres_context_maker = get_postgres_context_maker_without_pool() + flat_context_maker = get_flat_context_maker() - unit_of_work = UnitOfWork(context_maker) + unit_of_work = UnitOfWork( + postgres_context_maker=postgres_context_maker, + flat_context_maker=flat_context_maker, + ) task_repo = get_task_repo() return UpdateTaskInfo( @@ -37,15 +41,15 @@ def get_update_task_info_use_case(): def get_profile_task_use_case(): postgres_context_maker = get_postgres_context_maker_without_pool() flat_context_maker = get_flat_context_maker() - - file_unit_of_work = UnitOfWork(flat_context_maker) - dataset_unit_of_work = UnitOfWork(postgres_context_maker) + unit_of_work = UnitOfWork( + postgres_context_maker=postgres_context_maker, + flat_context_maker=flat_context_maker, + ) file_repo = get_file_repo() dataset_repo = get_dataset_repo() return ProfileTask( - file_unit_of_work=file_unit_of_work, - dataset_unit_of_work=dataset_unit_of_work, + unit_of_work=unit_of_work, file_repo=file_repo, # type: ignore dataset_repo=dataset_repo, # type: ignore ) diff --git a/internal/infrastructure/data_storage/flat/context.py b/internal/infrastructure/data_storage/flat/context.py index 1f8b679f..7f7c5abe 100644 --- a/internal/infrastructure/data_storage/flat/context.py +++ b/internal/infrastructure/data_storage/flat/context.py @@ -1,25 +1,44 @@ +import os from pathlib import Path from internal.infrastructure.data_storage import settings +from internal.uow.exception import NotActiveContextException class FlatContext: def __init__(self, upload_directory_path: Path): self._upload_directory_path = upload_directory_path + self._change: list[Path] | None = [] @property def upload_directory_path(self) -> Path: return self._upload_directory_path - # This context implementation does not support transactions - def flush(self) -> None: ... - - def rollback(self) -> None: ... - - def commit(self) -> None: ... - - def close(self) -> None: ... # TODO: implement flat context closing. 
+ def add(self, file_path: Path) -> None: + if self._change is None: + raise NotActiveContextException() + else: + self._change.append(file_path) + + def rollback(self) -> None: + if self._change is None: + raise NotActiveContextException() + for file_path in self._change: + if os.path.exists(file_path): + os.remove(file_path) + self._change.clear() + + def commit(self) -> None: + # Changes are already on disk; clear the journal so that close() does not roll them back. + if self._change is None: + raise NotActiveContextException() + self._change.clear() + + def close(self) -> None: + if self._change: + self.rollback() + self._change = None class FlatContextMaker: diff --git a/internal/infrastructure/data_storage/relational/__init__.py b/internal/infrastructure/data_storage/relational/__init__.py index bb68759c..6788b8c2 100644 --- a/internal/infrastructure/data_storage/relational/__init__.py +++ b/internal/infrastructure/data_storage/relational/__init__.py @@ -1,3 +1,4 @@ from internal.infrastructure.data_storage.relational.context import ( # noqa: F401 RelationalContextType, + RelationalContextMakerType, ) diff --git a/internal/infrastructure/data_storage/relational/context.py b/internal/infrastructure/data_storage/relational/context.py index 07284b33..790da11f 100644 --- a/internal/infrastructure/data_storage/relational/context.py +++ b/internal/infrastructure/data_storage/relational/context.py @@ -1,3 +1,5 @@ -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker -RelationalContextType = Session +type RelationalContextType = Session + +type RelationalContextMakerType = sessionmaker diff --git a/internal/infrastructure/data_storage/relational/postgres/context.py b/internal/infrastructure/data_storage/relational/postgres/context.py index 7231e0f4..0c7cfdfe 100644 --- a/internal/infrastructure/data_storage/relational/postgres/context.py +++ b/internal/infrastructure/data_storage/relational/postgres/context.py @@ -14,9 +14,9 @@ PostgresContextMakerWithoutPool = sessionmaker(bind=engine_without_pool) -def get_postgres_context_maker() -> sessionmaker[Session]: +def get_postgres_context_maker() -> sessionmaker: return PostgresContextMaker -def get_postgres_context_maker_without_pool() -> sessionmaker[Session]: +def get_postgres_context_maker_without_pool() -> sessionmaker: return PostgresContextMakerWithoutPool diff --git a/internal/infrastructure/data_storage/universal_context.py b/internal/infrastructure/data_storage/universal_context.py new file mode 100644 index 00000000..91aa50cb --- /dev/null +++ b/internal/infrastructure/data_storage/universal_context.py @@ -0,0 +1,45 @@ +from internal.infrastructure.data_storage.flat import FlatContext, FlatContextMaker +from internal.infrastructure.data_storage.relational import ( + RelationalContextType, + RelationalContextMakerType, +) + + +class UniversalDataStorageContext: + + def __init__( + self, relational_context: RelationalContextType, flat_context: FlatContext + ): + self._relational_context = relational_context + self._flat_context = flat_context + + @property + def relational_context(self) -> RelationalContextType: + return self._relational_context + + @property + def flat_context(self) -> FlatContext: + return self._flat_context + + def commit(self) -> None: + self._relational_context.commit() + self._flat_context.commit() + + def rollback(self) -> None: + self._relational_context.rollback() + self._flat_context.rollback() + + def close(self) -> None: + self._relational_context.close() + self._flat_context.close() + + +class 
UniversalDataStorageContextMaker: + + def __init__( + self, + relational_context_maker: RelationalContextMakerType, + flat_context_maker: FlatContextMaker, + ): + self.relational_context_maker = relational_context_maker + self.flat_context_maker = flat_context_maker diff --git a/internal/infrastructure/data_storage/uow.py b/internal/infrastructure/data_storage/uow.py new file mode 100644 index 00000000..62cf1d68 --- /dev/null +++ b/internal/infrastructure/data_storage/uow.py @@ -0,0 +1,24 @@ +from internal.infrastructure.data_storage.flat import FlatContextMaker +from internal.infrastructure.data_storage.relational import RelationalContextMakerType +from internal.infrastructure.data_storage.universal_context import ( + UniversalDataStorageContext, +) +from internal.uow import AbstractUnitOfWork + + +class UnitOfWork(AbstractUnitOfWork[UniversalDataStorageContext]): + + def __init__( + self, + postgres_context_maker: RelationalContextMakerType, + flat_context_maker: FlatContextMaker, + ): + self.postgres_context_maker = postgres_context_maker + self.flat_context_maker = flat_context_maker + self._context: UniversalDataStorageContext | None = None + + def __enter__(self) -> UniversalDataStorageContext: + db_session = self.postgres_context_maker() + file_storage = self.flat_context_maker() + self._context = UniversalDataStorageContext(db_session, file_storage) + return self._context diff --git a/internal/repository/flat/file.py b/internal/repository/flat/file.py index b9d669d7..f51f9b94 100644 --- a/internal/repository/flat/file.py +++ b/internal/repository/flat/file.py @@ -9,7 +9,9 @@ CSVFileResponseSchema, ) from internal.dto.repository.file import File, FileCreateSchema -from internal.infrastructure.data_storage.flat import FlatContext +from internal.infrastructure.data_storage.universal_context import ( + UniversalDataStorageContext, +) CHUNK_SIZE = 1024 @@ -21,11 +23,11 @@ async def create( self, file: File, file_info: FileCreateSchema, - context: FlatContext, + context: UniversalDataStorageContext, ) -> None: path_to_file = Path.joinpath( - context.upload_directory_path, str(file_info.file_name) + context.flat_context.upload_directory_path, str(file_info.file_name) ) try: async with aiofiles.open(path_to_file, "wb") as out_file: # !!! 
@@ -37,10 +39,12 @@ async def create( def find( self, file_info: CSVFileFindSchema, - context: FlatContext, + context: UniversalDataStorageContext, ) -> CSVFileResponseSchema: - path_to_file = Path(context.upload_directory_path, str(file_info.file_name)) + path_to_file = Path( + context.flat_context.upload_directory_path, str(file_info.file_name) + ) return pd.read_csv( path_to_file, diff --git a/internal/repository/relational/crud.py b/internal/repository/relational/crud.py index 73fad976..504e4c11 100644 --- a/internal/repository/relational/crud.py +++ b/internal/repository/relational/crud.py @@ -9,8 +9,8 @@ BaseFindSchema, BaseResponseSchema, ) -from internal.infrastructure.data_storage.relational.context import ( - RelationalContextType, +from internal.infrastructure.data_storage.universal_context import ( + UniversalDataStorageContext, ) @@ -30,24 +30,26 @@ def __init__( self._response_schema: Type[ResponseSchema] = response_schema def create( - self, create_schema: CreateSchema, context: RelationalContextType + self, create_schema: CreateSchema, context: UniversalDataStorageContext ) -> ResponseSchema: create_schema_dict = create_schema.model_dump() db_model_instance = self._orm_model(**create_schema_dict) - context.add(db_model_instance) - context.flush() + context.relational_context.add(db_model_instance) + context.relational_context.flush() return self._response_schema.model_validate(db_model_instance) def _find( - self, find_schema: FindSchema, context: RelationalContextType + self, find_schema: FindSchema, context: UniversalDataStorageContext ) -> ORMModel | None: find_schema_dict = find_schema.model_dump() stmt = select(self._orm_model).filter_by(**find_schema_dict) - db_model_instance = context.execute(stmt).scalars().one_or_none() + db_model_instance = ( + context.relational_context.execute(stmt).scalars().one_or_none() + ) return db_model_instance def find( - self, find_schema: FindSchema, context: RelationalContextType + self, find_schema: FindSchema, context: UniversalDataStorageContext ) -> ResponseSchema | None: db_model_instance = self._find(find_schema, context) response = ( @@ -61,7 +63,7 @@ def find_or_create( self, find_schema: FindSchema, create_schema: CreateSchema, - context: RelationalContextType, + context: UniversalDataStorageContext, ) -> ResponseSchema: db_model_instance = self._find(find_schema, context) @@ -74,7 +76,7 @@ def update( find_schema: FindSchema, update_schema: UpdateSchema, fields_to_update_if_none: set[str] | None, - context: RelationalContextType, + context: UniversalDataStorageContext, ) -> ResponseSchema: db_model_instance = self._find(find_schema, context) @@ -87,17 +89,17 @@ def update( if value is not None or key in fields_to_update_if_none: setattr(db_model_instance, key, value) - context.add(db_model_instance) - context.flush() + context.relational_context.add(db_model_instance) + context.relational_context.flush() return self._response_schema.model_validate(db_model_instance) def delete( - self, find_schema: FindSchema, context: RelationalContextType + self, find_schema: FindSchema, context: UniversalDataStorageContext ) -> ResponseSchema | None: db_model_instance = self._find(find_schema, context) if not db_model_instance: return None - context.delete(db_model_instance) - context.flush() + context.relational_context.delete(db_model_instance) + context.relational_context.flush() return self._response_schema.model_validate(db_model_instance) diff --git a/internal/repository/relational/file/dataset.py 
b/internal/repository/relational/file/dataset.py index 8b7898e7..f04a2294 100644 --- a/internal/repository/relational/file/dataset.py +++ b/internal/repository/relational/file/dataset.py @@ -1,8 +1,8 @@ from sqlalchemy import select from sqlalchemy.orm import joinedload -from internal.infrastructure.data_storage.relational.context import ( - RelationalContextType, +from internal.infrastructure.data_storage.universal_context import ( + UniversalDataStorageContext, ) from internal.infrastructure.data_storage.relational.model.file import DatasetORM from internal.repository.relational import CRUD @@ -32,7 +32,7 @@ def __init__(self): def find_with_file_metadata( self, dataset_info: DatasetFindSchema, - context: RelationalContextType, + context: UniversalDataStorageContext, ) -> tuple[DatasetResponseSchema, FileMetadataResponseSchema]: dataset_find_dict = dataset_info.model_dump() @@ -41,7 +41,9 @@ def find_with_file_metadata( .options(joinedload(DatasetORM.file_metadata)) .filter_by(**dataset_find_dict) ) - dataset_orm_instance = context.execute(stmt).scalars().one_or_none() + dataset_orm_instance = ( + context.relational_context.execute(stmt).scalars().one_or_none() + ) if not dataset_orm_instance: raise DatasetNotFoundException() diff --git a/internal/rest/http/di.py b/internal/rest/http/di.py index 4c59417a..778ad596 100644 --- a/internal/rest/http/di.py +++ b/internal/rest/http/di.py @@ -1,6 +1,6 @@ from fastapi import Depends -from internal.infrastructure.data_storage.flat import FlatContextMaker +from internal.infrastructure.data_storage.flat import get_flat_context_maker from internal.infrastructure.data_storage.relational.postgres.context import ( get_postgres_context_maker, get_postgres_context_maker_without_pool, @@ -11,24 +11,23 @@ DatasetRepository, ) from internal.repository.relational.task import TaskRepository -from internal.uow import UnitOfWork +from internal.infrastructure.data_storage.uow import UnitOfWork -def get_unit_of_work(context_maker=Depends(get_postgres_context_maker)) -> UnitOfWork: +def get_unit_of_work( + postgres_context_maker=Depends(get_postgres_context_maker), + flat_context_maker=Depends(get_flat_context_maker), +) -> UnitOfWork: - return UnitOfWork(context_maker) + return UnitOfWork(postgres_context_maker, flat_context_maker) def get_unit_of_work_without_pool( - context_maker=Depends(get_postgres_context_maker_without_pool), + postgres_context_maker=Depends(get_postgres_context_maker_without_pool), + flat_context_maker=Depends(get_flat_context_maker), ) -> UnitOfWork: - return UnitOfWork(context_maker) - - -def get_flat_unit_of_work(context_maker: FlatContextMaker = Depends()) -> UnitOfWork: - - return UnitOfWork(context_maker) + return UnitOfWork(postgres_context_maker, flat_context_maker) def get_file_repo() -> FileRepository: diff --git a/internal/rest/http/file/di.py b/internal/rest/http/file/di.py index 6f25b502..77d40b21 100644 --- a/internal/rest/http/file/di.py +++ b/internal/rest/http/file/di.py @@ -5,9 +5,8 @@ get_file_repo, get_file_metadata_repo, get_dataset_repo, - get_flat_unit_of_work, ) -from internal.uow import UnitOfWork +from internal.uow import AbstractUnitOfWork from internal.usecase.file import SaveFile, SaveDataset, CheckContentType from internal.usecase.file.retrieve_dataset import RetrieveDataset from internal.usecase.file.save_dataset import DatasetRepo as SaveDatasetRepo @@ -16,21 +15,19 @@ def get_save_file_use_case( - unit_of_work: UnitOfWork = Depends(get_unit_of_work), - flat_unit_of_work: UnitOfWork = 
Depends(get_flat_unit_of_work), + unit_of_work: AbstractUnitOfWork = Depends(get_unit_of_work), file_repo: FileRepo = Depends(get_file_repo), file_metadata_repo: FileMetadataRepo = Depends(get_file_metadata_repo), ) -> SaveFile: return SaveFile( - file_info_unit_of_work=unit_of_work, - file_unit_of_work=flat_unit_of_work, + unit_of_work=unit_of_work, file_repo=file_repo, file_metadata_repo=file_metadata_repo, ) def get_save_dataset_use_case( - unit_of_work: UnitOfWork = Depends(get_unit_of_work), + unit_of_work: AbstractUnitOfWork = Depends(get_unit_of_work), dataset_repo: SaveDatasetRepo = Depends(get_dataset_repo), ) -> SaveDataset: return SaveDataset( @@ -44,7 +41,7 @@ def get_check_content_type_use_case() -> CheckContentType: def get_retrieve_dataset_use_case( - unit_of_work: UnitOfWork = Depends(get_unit_of_work), + unit_of_work: AbstractUnitOfWork = Depends(get_unit_of_work), dataset_repo: RetrieveDatasetRepo = Depends(get_dataset_repo), ) -> RetrieveDataset: return RetrieveDataset( diff --git a/internal/rest/http/task/di.py b/internal/rest/http/task/di.py index 0138f6ae..f8aa0539 100644 --- a/internal/rest/http/task/di.py +++ b/internal/rest/http/task/di.py @@ -1,6 +1,6 @@ from fastapi import Depends -from internal.uow import UnitOfWork +from internal.uow import AbstractUnitOfWork from internal.rest.http.di import get_unit_of_work, get_task_repo, get_dataset_repo from internal.usecase.task import RetrieveTask, SetTask from internal.usecase.task.retrieve_task import TaskRepo as RetrieveTaskRepo @@ -16,7 +16,7 @@ def get_profiling_task_worker() -> ProfilingTaskWorker: def get_retrieve_task_use_case( - unit_of_work: UnitOfWork = Depends(get_unit_of_work), + unit_of_work: AbstractUnitOfWork = Depends(get_unit_of_work), task_repo: RetrieveTaskRepo = Depends(get_task_repo), ) -> RetrieveTask: @@ -27,7 +27,7 @@ def get_retrieve_task_use_case( def get_set_task_use_case( - unit_of_work: UnitOfWork = Depends(get_unit_of_work), + unit_of_work: AbstractUnitOfWork = Depends(get_unit_of_work), task_repo: SetTaskRepo = Depends(get_task_repo), dataset_repo: SetDatasetRepo = Depends(get_dataset_repo), profiling_task_worker: ProfilingTaskWorker = Depends(get_profiling_task_worker), diff --git a/internal/uow/__init__.py b/internal/uow/__init__.py index 0ea6a6a2..7b4ffd24 100644 --- a/internal/uow/__init__.py +++ b/internal/uow/__init__.py @@ -1 +1,5 @@ -from internal.uow.uow import DataStorageContext, UnitOfWork # noqa: F401 +from internal.uow.uow import ( + DataStorageContext, + DataStorageContextMaker, + AbstractUnitOfWork, +) # noqa: F401 diff --git a/internal/uow/exception.py b/internal/uow/exception.py new file mode 100644 index 00000000..3cbd2e1e --- /dev/null +++ b/internal/uow/exception.py @@ -0,0 +1,4 @@ +class NotActiveContextException(Exception): + + def __init__(self): + super().__init__("The context has been closed or was never activated") diff --git a/internal/uow/uow.py b/internal/uow/uow.py index 1a1e8656..6c28a8ac 100644 --- a/internal/uow/uow.py +++ b/internal/uow/uow.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import Protocol, runtime_checkable @@ -6,8 +7,6 @@ class DataStorageContext(Protocol): def commit(self) -> None: ... - def flush(self) -> None: ... - def rollback(self) -> None: ... def close(self) -> None: ... @@ -18,15 +17,12 @@ class DataStorageContextMaker(Protocol): def __call__(self) -> DataStorageContext: ... 
-class UnitOfWork: +class AbstractUnitOfWork[C: DataStorageContext](ABC): - def __init__(self, context_maker: DataStorageContextMaker): - self._context_maker: DataStorageContextMaker = context_maker - self._context: DataStorageContext | None = None + _context: C | None - def __enter__(self) -> DataStorageContext: - self._context = self._context_maker() - return self._context + @abstractmethod + def __enter__(self) -> C: ... def __exit__(self, exc_type, exc_val, exc_tb) -> None: if self._context is not None: diff --git a/internal/usecase/file/retrieve_dataset.py b/internal/usecase/file/retrieve_dataset.py index b252864c..ea3f8bc8 100644 --- a/internal/usecase/file/retrieve_dataset.py +++ b/internal/usecase/file/retrieve_dataset.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from internal.dto.repository.file import DatasetFindSchema, DatasetResponseSchema -from internal.uow import DataStorageContext, UnitOfWork +from internal.uow import DataStorageContext, AbstractUnitOfWork as UnitOfWork from internal.usecase.file.exception import DatasetNotFoundException diff --git a/internal/usecase/file/save_dataset.py b/internal/usecase/file/save_dataset.py index 476e57fc..fddbc61d 100644 --- a/internal/usecase/file/save_dataset.py +++ b/internal/usecase/file/save_dataset.py @@ -2,7 +2,7 @@ from uuid import UUID from internal.dto.repository.file import DatasetCreateSchema, DatasetResponseSchema -from internal.uow import DataStorageContext, UnitOfWork +from internal.uow import DataStorageContext, AbstractUnitOfWork as UnitOfWork class DatasetRepo(Protocol): diff --git a/internal/usecase/file/save_file.py b/internal/usecase/file/save_file.py index 82c2cb2b..57ee3b33 100644 --- a/internal/usecase/file/save_file.py +++ b/internal/usecase/file/save_file.py @@ -13,7 +13,7 @@ FileMetadataCreateSchema, FileMetadataResponseSchema, ) -from internal.uow import DataStorageContext, UnitOfWork +from internal.uow import DataStorageContext, AbstractUnitOfWork as UnitOfWork from internal.usecase.file.exception import FailedReadFileException @@ -44,17 +44,12 @@ class SaveFile: def __init__( self, - # It is assumed that the two repositories will be associated with different repositories. - # In order to support different repositories, different UoW will be needed. - # If both of your repositories are linked to the same repository, use only one of the UoW. 
- file_info_unit_of_work: UnitOfWork, - file_unit_of_work: UnitOfWork, + unit_of_work: UnitOfWork, file_repo: FileRepo, file_metadata_repo: FileMetadataRepo, ): - self.file_info_unit_of_work = file_info_unit_of_work - self.file_unit_of_work = file_unit_of_work + self.unit_of_work = unit_of_work self.file_repo = file_repo self.file_metadata_repo = file_metadata_repo @@ -68,17 +63,14 @@ async def __call__(self, *, upload_file: File) -> SaveFileUseCaseResult: mime_type=upload_file.content_type, ) - with self.file_unit_of_work as file_context: - with self.file_info_unit_of_work as file_info_context: - try: - response = self.file_metadata_repo.create( - file_metadata_create_schema, file_info_context - ) - await self.file_repo.create( - upload_file, create_file_schema, file_context - ) - except FailedFileReadingException as e: - raise FailedReadFileException(str(e)) + with self.unit_of_work as context: + try: + response = self.file_metadata_repo.create( + file_metadata_create_schema, context + ) + await self.file_repo.create(upload_file, create_file_schema, context) + except FailedFileReadingException as e: + raise FailedReadFileException(str(e)) return SaveFileUseCaseResult( id=response.id, diff --git a/internal/usecase/task/profile_task.py b/internal/usecase/task/profile_task.py index b42efdd3..e0479a30 100644 --- a/internal/usecase/task/profile_task.py +++ b/internal/usecase/task/profile_task.py @@ -19,7 +19,7 @@ from internal.usecase.file.exception import ( FileMetadataNotFoundException as FileMetadataNotFoundUseCaseException, ) -from internal.uow import UnitOfWork, DataStorageContext +from internal.uow import DataStorageContext, AbstractUnitOfWork as UnitOfWork class DatasetRepo(Protocol): @@ -40,40 +40,34 @@ class ProfileTask: def __init__( self, - # It is assumed that the two repositories will be associated with different repositories. - # In order to support different repositories, different UoW will be needed. - # If both of your repositories are linked to the same repository, use only one of the UoW. 
- file_unit_of_work: UnitOfWork, - dataset_unit_of_work: UnitOfWork, + unit_of_work: UnitOfWork, file_repo: FileRepo, dataset_repo: DatasetRepo, ): - self.file_unit_of_work = file_unit_of_work - self.dataset_unit_of_work = dataset_unit_of_work + self.unit_of_work = unit_of_work self.file_repo = file_repo self.dataset_repo = dataset_repo def __call__(self, *, dataset_id: UUID, config: OneOfTaskConfig) -> OneOfTaskResult: - with self.file_unit_of_work as file_context: - with self.dataset_unit_of_work as dataset_context: - try: - dataset, file_metadata = self.dataset_repo.find_with_file_metadata( - DatasetFindSchema(id=dataset_id), dataset_context - ) + with self.unit_of_work as context: + try: + dataset, file_metadata = self.dataset_repo.find_with_file_metadata( + DatasetFindSchema(id=dataset_id), context + ) - df = self.file_repo.find( - CSVFileFindSchema( - file_name=file_metadata.file_name, - separator=dataset.separator, - header=dataset.header, - ), - file_context, - ) - except DatasetNotFoundException: - raise DatasetNotFoundUseCaseException() - except FileMetadataNotFoundException: - raise FileMetadataNotFoundUseCaseException() + df = self.file_repo.find( + CSVFileFindSchema( + file_name=file_metadata.file_name, + separator=dataset.separator, + header=dataset.header, + ), + context, + ) + except DatasetNotFoundException: + raise DatasetNotFoundUseCaseException() + except FileMetadataNotFoundException: + raise FileMetadataNotFoundUseCaseException() task = match_task_by_primitive_name(primitive_name=config.primitive_name) result = task.execute(table=df, task_config=config) # type: ignore diff --git a/internal/usecase/task/retrieve_task.py b/internal/usecase/task/retrieve_task.py index 3317efd8..d873bf53 100644 --- a/internal/usecase/task/retrieve_task.py +++ b/internal/usecase/task/retrieve_task.py @@ -11,7 +11,7 @@ TaskFailureReason, ) from internal.dto.repository.task import TaskResponseSchema, TaskFindSchema -from internal.uow import DataStorageContext, UnitOfWork +from internal.uow import DataStorageContext, AbstractUnitOfWork as UnitOfWork from internal.usecase.task.exception import TaskNotFoundException diff --git a/internal/usecase/task/set_task.py b/internal/usecase/task/set_task.py index a8c13800..bc86d82b 100644 --- a/internal/usecase/task/set_task.py +++ b/internal/usecase/task/set_task.py @@ -5,7 +5,7 @@ from internal.dto.repository.file import DatasetResponseSchema, DatasetFindSchema from internal.dto.repository.task import TaskCreateSchema, TaskResponseSchema from internal.dto.worker.task import ProfilingTaskCreateSchema -from internal.uow import DataStorageContext, UnitOfWork +from internal.uow import DataStorageContext, AbstractUnitOfWork as UnitOfWork from internal.usecase.file.exception import DatasetNotFoundException diff --git a/internal/usecase/task/update_task_info.py b/internal/usecase/task/update_task_info.py index e8e9e10a..6cc220f8 100644 --- a/internal/usecase/task/update_task_info.py +++ b/internal/usecase/task/update_task_info.py @@ -2,7 +2,7 @@ from uuid import UUID -from internal.uow import DataStorageContext, UnitOfWork +from internal.uow import DataStorageContext, AbstractUnitOfWork as UnitOfWork from internal.domain.task.value_objects import TaskStatus, OneOfTaskResult from internal.dto.repository.task import ( TaskUpdateSchema, diff --git a/tests/conftest.py b/tests/conftest.py index 3c231387..093ca93a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,10 @@ import logging from internal.infrastructure.data_storage.relational.model import 
ORMBaseModel +from internal.infrastructure.data_storage.flat import FlatContextMaker +from internal.infrastructure.data_storage.universal_context import ( + UniversalDataStorageContext, +) from internal.infrastructure.data_storage import settings # https://stackoverflow.com/questions/61582142/test-pydantic-settings-in-fastapi @@ -28,6 +32,11 @@ def postgres_context_maker(): return sessionmaker(test_engine, expire_on_commit=False) +@pytest.fixture(scope="session") +def flat_context_maker(): + return FlatContextMaker() + + @pytest.fixture(scope="function") def postgres_context(postgres_context_maker): context = postgres_context_maker() @@ -37,6 +46,24 @@ def postgres_context(postgres_context_maker): context.close() +@pytest.fixture(scope="function") +def flat_context(flat_context_maker): + context = flat_context_maker() + + yield context + + context.close() + + +@pytest.fixture(scope="function") +def universal_context(postgres_context, flat_context): + context = UniversalDataStorageContext(postgres_context, flat_context) + + yield context + + # context.close() + + @pytest.fixture(autouse=True) def clean_tables(postgres_context): for table in reversed(ORMBaseModel.metadata.sorted_tables): diff --git a/tests/repository/flat/test_file.py b/tests/repository/flat/test_file.py index cf5f26c9..aa6e0762 100644 --- a/tests/repository/flat/test_file.py +++ b/tests/repository/flat/test_file.py @@ -22,6 +22,14 @@ def mock_flat_context(tmp_path, mocker: MockFixture): return context +@pytest.fixture +def mock_universal_context(mock_flat_context, mocker: MockFixture): + context = mocker.MagicMock(spec=FlatContext) + context.flat_context = mock_flat_context + context.relational_context = mocker.Mock() + return context + + @pytest.fixture def file_repository(): return FileRepository() @@ -29,7 +37,7 @@ def file_repository(): @pytest.mark.asyncio async def test_create_file_success( - mocker: MockFixture, file_repository, mock_flat_context + mocker: MockFixture, file_repository, mock_universal_context ): file_name = uuid4() file_content = b"Hello, World!" 
@@ -40,9 +48,11 @@ async def test_create_file_success( side_effect=[file_content, b""] ) # Read the file contents - await file_repository.create(mock_file, file_info, mock_flat_context) + await file_repository.create(mock_file, file_info, mock_universal_context) - created_file_path = mock_flat_context.upload_directory_path / str(file_name) + created_file_path = mock_universal_context.flat_context.upload_directory_path / str( + file_name + ) assert created_file_path.is_file() async with aiofiles.open(created_file_path, "rb") as f: @@ -50,17 +60,19 @@ assert content == file_content -def test_find_file_success(file_repository, mock_flat_context): +def test_find_file_success(file_repository, mock_universal_context): file_name = uuid4() file_content = "col1,col2\n1,2\n3,4" - file_path = mock_flat_context.upload_directory_path / str(file_name) + file_path = mock_universal_context.flat_context.upload_directory_path / str( + file_name + ) with open(file_path, "w") as f: f.write(file_content) file_info = CSVFileFindSchema(file_name=file_name, separator=",", header=[0]) - result = file_repository.find(file_info, mock_flat_context) + result = file_repository.find(file_info, mock_universal_context) expected_df = pd.DataFrame({"col1": [1, 3], "col2": [2, 4]}) pd.testing.assert_frame_equal(result, expected_df) @@ -68,7 +80,7 @@ @pytest.mark.asyncio async def test_create_file_failure( - mocker: MockFixture, file_repository, mock_flat_context + mocker: MockFixture, file_repository, mock_universal_context ): file_name = uuid4() file_info = FileCreateSchema(file_name=file_name) @@ -79,11 +91,11 @@ with pytest.raises( FailedFileReadingException, match="The sent file could not be read." 
): - await file_repository.create(mock_file, file_info, mock_flat_context) + await file_repository.create(mock_file, file_info, mock_universal_context) -def test_find_file_failure(file_repository, mock_flat_context): +def test_find_file_failure(file_repository, mock_universal_context): file_info = CSVFileFindSchema(file_name=uuid4(), separator=",", header=[0]) with pytest.raises(FileNotFoundError): - file_repository.find(file_info, mock_flat_context) + file_repository.find(file_info, mock_universal_context) diff --git a/tests/repository/postgres/test_dataset.py b/tests/repository/postgres/test_dataset.py index c9abd4a6..768a0f62 100644 --- a/tests/repository/postgres/test_dataset.py +++ b/tests/repository/postgres/test_dataset.py @@ -23,9 +23,9 @@ def file_create_schema(): @pytest.fixture -def file_id(file_create_schema, postgres_context): +def file_id(file_create_schema, universal_context): file_metadata_repo = FileMetadataRepository() - response = file_metadata_repo.create(file_create_schema, postgres_context) + response = file_metadata_repo.create(file_create_schema, universal_context) return response.id @@ -45,8 +45,8 @@ def update_schema(): class TestDatasetRepository: - def test_create(self, repo, create_schema, postgres_context): - response = repo.create(create_schema, postgres_context) + def test_create(self, repo, create_schema, universal_context): + response = repo.create(create_schema, universal_context) assert response is not None assert response.file_id == create_schema.file_id @@ -57,40 +57,40 @@ def test_create_and_find( self, repo, create_schema, - postgres_context, + universal_context, ): - empty_response = repo.find(DatasetFindSchema(id=uuid4()), postgres_context) + empty_response = repo.find(DatasetFindSchema(id=uuid4()), universal_context) assert empty_response is None - created_response = repo.create(create_schema, postgres_context) + created_response = repo.create(create_schema, universal_context) response = repo.find( - DatasetFindSchema(id=created_response.id), postgres_context + DatasetFindSchema(id=created_response.id), universal_context ) assert response is not None assert response.file_id == create_schema.file_id assert response.separator == create_schema.separator assert response.header == create_schema.header - def test_update(self, repo, create_schema, update_schema, postgres_context): - created_response = repo.create(create_schema, postgres_context) + def test_update(self, repo, create_schema, update_schema, universal_context): + created_response = repo.create(create_schema, universal_context) find_schema = DatasetFindSchema(id=created_response.id) - repo.update(find_schema, update_schema, None, postgres_context) + repo.update(find_schema, update_schema, None, universal_context) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is not None assert response.file_id == create_schema.file_id assert response.separator == update_schema.separator assert response.header == update_schema.header - def test_delete(self, repo, create_schema, postgres_context): - created_response = repo.create(create_schema, postgres_context) + def test_delete(self, repo, create_schema, universal_context): + created_response = repo.create(create_schema, universal_context) find_schema = DatasetFindSchema(id=created_response.id) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is not None - repo.delete(find_schema, postgres_context) + 
repo.delete(find_schema, universal_context) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is None diff --git a/tests/repository/postgres/test_file_metadata.py b/tests/repository/postgres/test_file_metadata.py index 91dfd13d..47d6963c 100644 --- a/tests/repository/postgres/test_file_metadata.py +++ b/tests/repository/postgres/test_file_metadata.py @@ -31,8 +31,8 @@ def update_schema(): class TestFileMetadataRepository: - def test_create(self, repo, create_schema, postgres_context): - response = repo.create(create_schema, postgres_context) + def test_create(self, repo, create_schema, universal_context): + response = repo.create(create_schema, universal_context) assert response is not None assert response.file_name == create_schema.file_name @@ -43,40 +43,42 @@ def test_create_and_find( self, repo, create_schema, - postgres_context, + universal_context, ): - empty_response = repo.find(FileMetadataFindSchema(id=uuid4()), postgres_context) + empty_response = repo.find( + FileMetadataFindSchema(id=uuid4()), universal_context + ) assert empty_response is None - created_response = repo.create(create_schema, postgres_context) + created_response = repo.create(create_schema, universal_context) response = repo.find( - FileMetadataFindSchema(id=created_response.id), postgres_context + FileMetadataFindSchema(id=created_response.id), universal_context ) assert response is not None assert response.file_name == create_schema.file_name assert response.original_file_name == create_schema.original_file_name assert response.mime_type == create_schema.mime_type - def test_update(self, repo, create_schema, update_schema, postgres_context): - created_response = repo.create(create_schema, postgres_context) + def test_update(self, repo, create_schema, update_schema, universal_context): + created_response = repo.create(create_schema, universal_context) find_schema = FileMetadataFindSchema(id=created_response.id) - repo.update(find_schema, update_schema, None, postgres_context) + repo.update(find_schema, update_schema, None, universal_context) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is not None assert response.file_name == create_schema.file_name assert response.original_file_name == update_schema.original_file_name assert response.mime_type == create_schema.mime_type - def test_delete(self, repo, create_schema, postgres_context): - created_response = repo.create(create_schema, postgres_context) + def test_delete(self, repo, create_schema, universal_context): + created_response = repo.create(create_schema, universal_context) find_schema = FileMetadataFindSchema(id=created_response.id) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is not None - repo.delete(find_schema, postgres_context) + repo.delete(find_schema, universal_context) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is None diff --git a/tests/repository/postgres/test_task.py b/tests/repository/postgres/test_task.py index 0ac36a21..183ba638 100644 --- a/tests/repository/postgres/test_task.py +++ b/tests/repository/postgres/test_task.py @@ -30,9 +30,9 @@ def file_create_schema(): @pytest.fixture -def file_id(file_create_schema, postgres_context): +def file_id(file_create_schema, universal_context): file_metadata_repo = FileMetadataRepository() 
- response = file_metadata_repo.create(file_create_schema, postgres_context) + response = file_metadata_repo.create(file_create_schema, universal_context) return response.id @@ -42,9 +42,9 @@ def dataset_create_schema(file_id): @pytest.fixture -def dataset_id(dataset_create_schema, postgres_context): +def dataset_id(dataset_create_schema, universal_context): dataset_repo = DatasetRepository() - response = dataset_repo.create(dataset_create_schema, postgres_context) + response = dataset_repo.create(dataset_create_schema, universal_context) return response.id @@ -76,8 +76,8 @@ def update_schema(): class TestDatasetRepository: - def test_create(self, repo, create_schema, postgres_context): - response = repo.create(create_schema, postgres_context) + def test_create(self, repo, create_schema, universal_context): + response = repo.create(create_schema, universal_context) assert response is not None assert response.dataset_id == create_schema.dataset_id @@ -92,39 +92,39 @@ def test_create_and_find( self, repo, create_schema, - postgres_context, + universal_context, ): - empty_response = repo.find(TaskFindSchema(id=uuid4()), postgres_context) + empty_response = repo.find(TaskFindSchema(id=uuid4()), universal_context) assert empty_response is None - created_response = repo.create(create_schema, postgres_context) - response = repo.find(TaskFindSchema(id=created_response.id), postgres_context) + created_response = repo.create(create_schema, universal_context) + response = repo.find(TaskFindSchema(id=created_response.id), universal_context) assert response is not None assert response.dataset_id == create_schema.dataset_id assert response.status == create_schema.status assert response.config == create_schema.config - def test_update(self, repo, create_schema, update_schema, postgres_context): - created_response = repo.create(create_schema, postgres_context) + def test_update(self, repo, create_schema, update_schema, universal_context): + created_response = repo.create(create_schema, universal_context) find_schema = TaskFindSchema(id=created_response.id) - repo.update(find_schema, update_schema, None, postgres_context) + repo.update(find_schema, update_schema, None, universal_context) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is not None assert response.dataset_id == create_schema.dataset_id assert response.status == update_schema.status assert response.config == create_schema.config assert response.failure_reason == update_schema.failure_reason - def test_delete(self, repo, create_schema, postgres_context): - created_response = repo.create(create_schema, postgres_context) + def test_delete(self, repo, create_schema, universal_context): + created_response = repo.create(create_schema, universal_context) find_schema = TaskFindSchema(id=created_response.id) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is not None - repo.delete(find_schema, postgres_context) + repo.delete(find_schema, universal_context) - response = repo.find(find_schema, postgres_context) + response = repo.find(find_schema, universal_context) assert response is None diff --git a/tests/uow/test_unit_of_work.py b/tests/uow/test_unit_of_work.py index 542172af..49b81111 100644 --- a/tests/uow/test_unit_of_work.py +++ b/tests/uow/test_unit_of_work.py @@ -1,40 +1,68 @@ import pytest from pytest_mock 
import MockerFixture -from internal.uow import DataStorageContext, UnitOfWork -from internal.uow.uow import DataStorageContextMaker +from internal.uow import DataStorageContext, DataStorageContextMaker +from internal.infrastructure.data_storage.uow import UnitOfWork @pytest.fixture -def context_mock(mocker: MockerFixture): +def postgres_context_mock(mocker: MockerFixture): return mocker.Mock(spec=DataStorageContext) @pytest.fixture -def context_maker_mock(mocker: MockerFixture, context_mock): - return mocker.Mock(spec=DataStorageContextMaker, return_value=context_mock) +def postgres_context_maker_mock(mocker: MockerFixture, postgres_context_mock): + return mocker.Mock(spec=DataStorageContextMaker, return_value=postgres_context_mock) -def test_unit_of_work_commit_on_success(context_maker_mock, context_mock) -> None: - uow = UnitOfWork(context_maker_mock) +@pytest.fixture +def flat_context_mock(mocker: MockerFixture): + return mocker.Mock(spec=DataStorageContext) + + +@pytest.fixture +def flat_context_maker_mock(mocker: MockerFixture, flat_context_mock): + return mocker.Mock(spec=DataStorageContextMaker, return_value=flat_context_mock) + + +def test_unit_of_work_commit_on_success( + postgres_context_mock, + postgres_context_maker_mock, + flat_context_mock, + flat_context_maker_mock, +) -> None: + uow = UnitOfWork(postgres_context_maker_mock, flat_context_maker_mock) with uow as context: assert isinstance(context, DataStorageContext) pass - context_mock.commit.assert_called_once() - context_mock.rollback.assert_not_called() - context_mock.close.assert_called_once() + postgres_context_mock.commit.assert_called_once() + postgres_context_mock.rollback.assert_not_called() + postgres_context_mock.close.assert_called_once() + flat_context_mock.commit.assert_called_once() + flat_context_mock.rollback.assert_not_called() + flat_context_mock.close.assert_called_once() -def test_unit_of_work_rollback_on_failure(context_maker_mock, context_mock) -> None: - uow = UnitOfWork(context_maker_mock) + +def test_unit_of_work_rollback_on_failure( + postgres_context_mock, + postgres_context_maker_mock, + flat_context_mock, + flat_context_maker_mock, +) -> None: + uow = UnitOfWork(postgres_context_maker_mock, flat_context_maker_mock) with pytest.raises(ValueError): with uow as context: assert isinstance(context, DataStorageContext) raise ValueError("Test error") - context_mock.commit.assert_not_called() - context_mock.rollback.assert_called_once() - context_mock.close.assert_called_once() + postgres_context_mock.commit.assert_not_called() + postgres_context_mock.rollback.assert_called_once() + postgres_context_mock.close.assert_called_once() + + flat_context_mock.commit.assert_not_called() + flat_context_mock.rollback.assert_called_once() + flat_context_mock.close.assert_called_once() diff --git a/tests/usecase/test_profile_task.py b/tests/usecase/test_profile_task.py index 99ef932a..b9f5f43b 100644 --- a/tests/usecase/test_profile_task.py +++ b/tests/usecase/test_profile_task.py @@ -61,8 +61,7 @@ def profile_task_use_case( file_repo_mock, ) -> ProfileTask: return ProfileTask( - file_unit_of_work=unit_of_work_mock, - dataset_unit_of_work=unit_of_work_mock, + unit_of_work=unit_of_work_mock, dataset_repo=dataset_repo_mock, file_repo=file_repo_mock, ) @@ -145,8 +144,8 @@ def test_profile_task_use_case_success( ) # Check that UnitOfWork was entered and exited correctly - assert unit_of_work_mock.__enter__.call_count == 2 - assert unit_of_work_mock.__exit__.call_count == 2 + 
unit_of_work_mock.__enter__.assert_called_once() + unit_of_work_mock.__exit__.assert_called_once() @pytest.mark.parametrize( @@ -186,5 +185,5 @@ def test_profile_task_use_case_dataset_not_found( assert not file_repo_mock.find.called # Check that UnitOfWork was entered and exited correctly - assert unit_of_work_mock.__enter__.call_count == 2 - assert unit_of_work_mock.__exit__.call_count == 2 + unit_of_work_mock.__enter__.assert_called_once() + unit_of_work_mock.__exit__.assert_called_once() diff --git a/tests/usecase/test_save_file.py b/tests/usecase/test_save_file.py index 11cba635..7b6eb97c 100644 --- a/tests/usecase/test_save_file.py +++ b/tests/usecase/test_save_file.py @@ -71,8 +71,7 @@ def save_file( "internal.usecase.file.save_file.FileEntity", return_value=file_entity_mock ) return SaveFile( - file_unit_of_work=unit_of_work_mock, - file_info_unit_of_work=unit_of_work_mock, + unit_of_work=unit_of_work_mock, file_repo=file_repo_mock, file_metadata_repo=file_metadata_repo_mock, ) @@ -132,8 +131,8 @@ async def test_save_file( ) # Verify that UnitOfWork was used correctly - assert unit_of_work_mock.__enter__.call_count == 2 - assert unit_of_work_mock.__exit__.call_count == 2 + unit_of_work_mock.__enter__.assert_called_once() + unit_of_work_mock.__exit__.assert_called_once() # Verify that the result matches the expected SaveFileUseCaseResult assert result == SaveFileUseCaseResult( @@ -173,5 +172,5 @@ async def test_save_file_failed_read_file_exception( file_repo_mock.create.assert_called_once() # Verify that UnitOfWork was used correctly - assert unit_of_work_mock.__enter__.call_count == 2 - assert unit_of_work_mock.__exit__.call_count == 2 + unit_of_work_mock.__enter__.assert_called_once() + unit_of_work_mock.__exit__.assert_called_once()
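
Reviewer note: below is an illustrative usage sketch, not part of the patch, showing how the combined UnitOfWork introduced in this diff is wired and used. It relies only on names defined above (the context makers, UnitOfWork, and UniversalDataStorageContext properties).

from internal.infrastructure.data_storage.flat import get_flat_context_maker
from internal.infrastructure.data_storage.relational.postgres.context import (
    get_postgres_context_maker,
)
from internal.infrastructure.data_storage.uow import UnitOfWork

# A single UnitOfWork now spans both storages, replacing the former pair of
# file/dataset units of work.
uow = UnitOfWork(
    postgres_context_maker=get_postgres_context_maker(),
    flat_context_maker=get_flat_context_maker(),
)

with uow as context:
    # Repositories receive the UniversalDataStorageContext and pick the
    # sub-context they need.
    session = context.relational_context  # SQLAlchemy Session
    files = context.flat_context  # FlatContext for uploaded files
# On a clean exit both contexts are committed and then closed; if the body
# raises, both are rolled back (the flat context deletes any files recorded
# via FlatContext.add).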