Skip to content

Commit

Permalink
feat: generalize unit of work to ny type of repo
Browse files Browse the repository at this point in the history
  • Loading branch information
raf-nr committed Sep 28, 2024
1 parent e5e1c21 commit 3439ac5
Show file tree
Hide file tree
Showing 33 changed files with 376 additions and 214 deletions.
20 changes: 12 additions & 8 deletions internal/infrastructure/background_task/celery/task/di.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from internal.repository.flat import FileRepository
from internal.repository.relational.file import DatasetRepository
from internal.repository.relational.task import TaskRepository
from internal.uow import UnitOfWork
from internal.infrastructure.data_storage.uow import UnitOfWork
from internal.usecase.task.profile_task import ProfileTask
from internal.usecase.task.update_task_info import UpdateTaskInfo

Expand All @@ -23,9 +23,13 @@ def get_task_repo() -> TaskRepository:


def get_update_task_info_use_case():
context_maker = get_postgres_context_maker_without_pool()
postgres_context_maker = get_postgres_context_maker_without_pool()
flat_context_maker = get_flat_context_maker()

unit_of_work = UnitOfWork(context_maker)
unit_of_work = UnitOfWork(
postgres_context_maker=postgres_context_maker,
flat_context_maker=flat_context_maker,
)
task_repo = get_task_repo()

return UpdateTaskInfo(
Expand All @@ -37,15 +41,15 @@ def get_update_task_info_use_case():
def get_profile_task_use_case():
postgres_context_maker = get_postgres_context_maker_without_pool()
flat_context_maker = get_flat_context_maker()

file_unit_of_work = UnitOfWork(flat_context_maker)
dataset_unit_of_work = UnitOfWork(postgres_context_maker)
unit_of_work = UnitOfWork(
postgres_context_maker=postgres_context_maker,
flat_context_maker=flat_context_maker,
)
file_repo = get_file_repo()
dataset_repo = get_dataset_repo()

return ProfileTask(
file_unit_of_work=file_unit_of_work,
dataset_unit_of_work=dataset_unit_of_work,
unit_of_work=unit_of_work,
file_repo=file_repo, # type: ignore
dataset_repo=dataset_repo, # type: ignore
)
2 changes: 2 additions & 0 deletions internal/infrastructure/data_storage/flat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
FlatContext,
get_flat_context_maker,
)

from internal.infrastructure.data_storage.uow import UnitOfWork # noqa: F401
35 changes: 27 additions & 8 deletions internal/infrastructure/data_storage/flat/context.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,44 @@
import os
from pathlib import Path

from internal.infrastructure.data_storage import settings
from internal.uow.exception import NotActiveContextException


class FlatContext:

def __init__(self, upload_directory_path: Path):
self._upload_directory_path = upload_directory_path
self._change: list[Path] | None = []

@property
def upload_directory_path(self) -> Path:
return self._upload_directory_path

# This context implementation does not support transactions
def flush(self) -> None: ...

def rollback(self) -> None: ...

def commit(self) -> None: ...

def close(self) -> None: ... # TODO: implement flat context closing.
def add(self, file_path: Path) -> None:
if self._change is None:
raise NotActiveContextException()
else:
self._change.append(file_path)

def rollback(self) -> None:
if self._change is None:
raise NotActiveContextException()
for file_path in self._change:
if os.path.exists(file_path):
os.remove(file_path)
self._change.clear()

def commit(self) -> None:
# This context implementation does not this method.
# Changes are saved automatically.
# However, in case of rollback they will be deleted.
pass

def close(self) -> None:
if self._change:
self.rollback()
self._change = None


class FlatContextMaker:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from internal.infrastructure.data_storage.relational.context import ( # noqa: F401
RelationalContextType,
RelationalContextMakerType,
)
6 changes: 4 additions & 2 deletions internal/infrastructure/data_storage/relational/context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, sessionmaker

RelationalContextType = Session
type RelationalContextType = Session

type RelationalContextMakerType = sessionmaker
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
PostgresContextMakerWithoutPool = sessionmaker(bind=engine_without_pool)


def get_postgres_context_maker() -> sessionmaker[Session]:
def get_postgres_context_maker() -> sessionmaker:
return PostgresContextMaker


def get_postgres_context_maker_without_pool() -> sessionmaker[Session]:
def get_postgres_context_maker_without_pool() -> sessionmaker:
return PostgresContextMakerWithoutPool
45 changes: 45 additions & 0 deletions internal/infrastructure/data_storage/universal_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from internal.infrastructure.data_storage.flat import FlatContext, FlatContextMaker
from internal.infrastructure.data_storage.relational import (
RelationalContextType,
RelationalContextMakerType,
)


class UniversalDataStorageContext:

def __init__(
self, relational_context: RelationalContextType, flat_context: FlatContext
):
self._relational_context = relational_context
self._flat_context = flat_context

@property
def relational_context(self) -> RelationalContextType:
return self._relational_context

@property
def flat_context(self) -> FlatContext:
return self._flat_context

def commit(self) -> None:
self._relational_context.commit()
self._flat_context.commit()

def rollback(self) -> None:
self._relational_context.rollback()
self._flat_context.rollback()

def close(self) -> None:
self._relational_context.close()
self._flat_context.close()


class UniversalDataStorageContextMaker:

def __init__(
self,
relational_context_maker: RelationalContextMakerType,
flat_context_maker: FlatContextMaker,
):
self.relational_context_maker = relational_context_maker
self.flat_context_maker = flat_context_maker
28 changes: 28 additions & 0 deletions internal/infrastructure/data_storage/uow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from internal.infrastructure.data_storage.flat import FlatContextMaker
from internal.infrastructure.data_storage.relational import RelationalContextMakerType
from internal.infrastructure.data_storage.universal_context import (
UniversalDataStorageContext,
)
from internal.uow import AbstractUnitOfWork
from internal.uow.exception import NotActiveContextException


class UnitOfWork(AbstractUnitOfWork[UniversalDataStorageContext]):

def __init__(
self,
postgres_context_maker: RelationalContextMakerType,
flat_context_maker: FlatContextMaker,
):
self.postgres_context_maker = postgres_context_maker
self.flat_context_maker = flat_context_maker
self._context: UniversalDataStorageContext | None = None

def __enter__(self) -> UniversalDataStorageContext:
db_session = self.postgres_context_maker()
file_storage = self.flat_context_maker()
self._context = UniversalDataStorageContext(db_session, file_storage)
if self._context is None:
raise NotActiveContextException()
else:
return self._context
14 changes: 9 additions & 5 deletions internal/repository/flat/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
CSVFileResponseSchema,
)
from internal.dto.repository.file import File, FileCreateSchema
from internal.infrastructure.data_storage.flat import FlatContext
from internal.infrastructure.data_storage.universal_context import (
UniversalDataStorageContext,
)

CHUNK_SIZE = 1024

Expand All @@ -21,11 +23,11 @@ async def create(
self,
file: File,
file_info: FileCreateSchema,
context: FlatContext,
context: UniversalDataStorageContext,
) -> None:

path_to_file = Path.joinpath(
context.upload_directory_path, str(file_info.file_name)
context.flat_context.upload_directory_path, str(file_info.file_name)
)
try:
async with aiofiles.open(path_to_file, "wb") as out_file: # !!!
Expand All @@ -37,10 +39,12 @@ async def create(
def find(
self,
file_info: CSVFileFindSchema,
context: FlatContext,
context: UniversalDataStorageContext,
) -> CSVFileResponseSchema:

path_to_file = Path(context.upload_directory_path, str(file_info.file_name))
path_to_file = Path(
context.flat_context.upload_directory_path, str(file_info.file_name)
)

return pd.read_csv(
path_to_file,
Expand Down
32 changes: 17 additions & 15 deletions internal/repository/relational/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
BaseFindSchema,
BaseResponseSchema,
)
from internal.infrastructure.data_storage.relational.context import (
RelationalContextType,
from internal.infrastructure.data_storage.universal_context import (
UniversalDataStorageContext,
)


Expand All @@ -30,24 +30,26 @@ def __init__(
self._response_schema: Type[ResponseSchema] = response_schema

def create(
self, create_schema: CreateSchema, context: RelationalContextType
self, create_schema: CreateSchema, context: UniversalDataStorageContext
) -> ResponseSchema:
create_schema_dict = create_schema.model_dump()
db_model_instance = self._orm_model(**create_schema_dict)
context.add(db_model_instance)
context.flush()
context.relational_context.add(db_model_instance)
context.relational_context.flush()
return self._response_schema.model_validate(db_model_instance)

def _find(
self, find_schema: FindSchema, context: RelationalContextType
self, find_schema: FindSchema, context: UniversalDataStorageContext
) -> ORMModel | None:
find_schema_dict = find_schema.model_dump()
stmt = select(self._orm_model).filter_by(**find_schema_dict)
db_model_instance = context.execute(stmt).scalars().one_or_none()
db_model_instance = (
context.relational_context.execute(stmt).scalars().one_or_none()
)
return db_model_instance

def find(
self, find_schema: FindSchema, context: RelationalContextType
self, find_schema: FindSchema, context: UniversalDataStorageContext
) -> ResponseSchema | None:
db_model_instance = self._find(find_schema, context)
response = (
Expand All @@ -61,7 +63,7 @@ def find_or_create(
self,
find_schema: FindSchema,
create_schema: CreateSchema,
context: RelationalContextType,
context: UniversalDataStorageContext,
) -> ResponseSchema:

db_model_instance = self._find(find_schema, context)
Expand All @@ -74,7 +76,7 @@ def update(
find_schema: FindSchema,
update_schema: UpdateSchema,
fields_to_update_if_none: set[str] | None,
context: RelationalContextType,
context: UniversalDataStorageContext,
) -> ResponseSchema:

db_model_instance = self._find(find_schema, context)
Expand All @@ -87,17 +89,17 @@ def update(
if value is not None or key in fields_to_update_if_none:
setattr(db_model_instance, key, value)

context.add(db_model_instance)
context.flush()
context.relational_context.add(db_model_instance)
context.relational_context.flush()

return self._response_schema.model_validate(db_model_instance)

def delete(
self, find_schema: FindSchema, context: RelationalContextType
self, find_schema: FindSchema, context: UniversalDataStorageContext
) -> ResponseSchema | None:
db_model_instance = self._find(find_schema, context)
if not db_model_instance:
return None
context.delete(db_model_instance)
context.flush()
context.relational_context.delete(db_model_instance)
context.relational_context.flush()
return self._response_schema.model_validate(db_model_instance)
10 changes: 6 additions & 4 deletions internal/repository/relational/file/dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from sqlalchemy import select
from sqlalchemy.orm import joinedload

from internal.infrastructure.data_storage.relational.context import (
RelationalContextType,
from internal.infrastructure.data_storage.universal_context import (
UniversalDataStorageContext,
)
from internal.infrastructure.data_storage.relational.model.file import DatasetORM
from internal.repository.relational import CRUD
Expand Down Expand Up @@ -32,7 +32,7 @@ def __init__(self):
def find_with_file_metadata(
self,
dataset_info: DatasetFindSchema,
context: RelationalContextType,
context: UniversalDataStorageContext,
) -> tuple[DatasetResponseSchema, FileMetadataResponseSchema]:

dataset_find_dict = dataset_info.model_dump()
Expand All @@ -41,7 +41,9 @@ def find_with_file_metadata(
.options(joinedload(DatasetORM.file_metadata))
.filter_by(**dataset_find_dict)
)
dataset_orm_instance = context.execute(stmt).scalars().one_or_none()
dataset_orm_instance = (
context.relational_context.execute(stmt).scalars().one_or_none()
)

if not dataset_orm_instance:
raise DatasetNotFoundException()
Expand Down
Loading

0 comments on commit 3439ac5

Please sign in to comment.