From 7b44e9d7de453759860d28f1781764e069f25eec Mon Sep 17 00:00:00 2001 From: Goradi Date: Tue, 19 Nov 2024 17:18:23 +0500 Subject: [PATCH 1/2] tests: to upcoming changes now fails --- tests/test_sa2/__init__.py | 5 ++ tests/test_sa2/test_view_sync.py | 67 +++++++++++++++++++++++++++ tests/test_sa2/test_views_async.py | 74 ++++++++++++++++++++++++++++++ 3 files changed, 146 insertions(+) create mode 100644 tests/test_sa2/__init__.py create mode 100644 tests/test_sa2/test_view_sync.py create mode 100644 tests/test_sa2/test_views_async.py diff --git a/tests/test_sa2/__init__.py b/tests/test_sa2/__init__.py new file mode 100644 index 00000000..9e754a1c --- /dev/null +++ b/tests/test_sa2/__init__.py @@ -0,0 +1,5 @@ +import pytest +from sqlalchemy import __version__ as __sa_version__ + +if __sa_version__.startswith("1."): + pytest.skip("SQLAlchemy 1.4 does not support this api", allow_module_level=True) diff --git a/tests/test_sa2/test_view_sync.py b/tests/test_sa2/test_view_sync.py new file mode 100644 index 00000000..5327c46a --- /dev/null +++ b/tests/test_sa2/test_view_sync.py @@ -0,0 +1,67 @@ +from typing import Any, Generator + +import pytest +from sqlalchemy import ( + Integer, + String, + func, + select, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + MappedAsDataclass, + mapped_column, + sessionmaker, +) +from starlette.applications import Starlette +from starlette.testclient import TestClient + +from sqladmin import Admin, ModelView +from tests.common import sync_engine + +session_maker = sessionmaker(bind=sync_engine) +pytestmark = pytest.mark.anyio + + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False) + name: Mapped[str] = mapped_column(String(length=16), init=True) + email: Mapped[str] = mapped_column(String, unique=True) + + +class UserAdmin(ModelView, model=User): + column_list = ["name", "email"] + column_labels = {"name": "Name", "email": "Email"} + + +app = Starlette() +admin = Admin(app=app, engine=sync_engine) +admin.add_model_view(UserAdmin) + + +@pytest.fixture +def prepare_database() -> Generator[None, None, None]: + Base.metadata.create_all(sync_engine) + yield + Base.metadata.drop_all(sync_engine) + + +@pytest.fixture +def client(prepare_database: Any) -> Generator[TestClient, None, None]: + with TestClient(app=app, base_url="http://testserver") as c: + yield c + + +def test_sync_create_dataclass(client: TestClient) -> None: + client.post("/admin/user/create", data={"name": "foo", "email": "bar"}) + stmt = select(func.count(User.id)) + with session_maker() as s: + result = s.execute(stmt) + assert result.scalar_one() == 1 diff --git a/tests/test_sa2/test_views_async.py b/tests/test_sa2/test_views_async.py new file mode 100644 index 00000000..d2c609b9 --- /dev/null +++ b/tests/test_sa2/test_views_async.py @@ -0,0 +1,74 @@ +from typing import Any, AsyncGenerator + +import pytest +from httpx import AsyncClient +from sqlalchemy import Integer, String, func, select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + MappedAsDataclass, + mapped_column, +) +from starlette.applications import Starlette + +from sqladmin import Admin +from sqladmin.models import ModelView +from tests.common import async_engine + +async_session_maker = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, +) + +pytestmark = pytest.mark.anyio + + +class Base(MappedAsDataclass, DeclarativeBase): + pass + + +class User(Base): + __tablename__ = "users" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False) + name: Mapped[str] = mapped_column(String(length=16), init=True) + email: Mapped[str] = mapped_column(String, unique=True) + + +class UserAdmin(ModelView, model=User): + column_list = ["name", "email"] + column_labels = {"name": "Name", "email": "Email"} + + +app = Starlette() +async_admin = Admin(app=app, engine=async_engine) +async_admin.add_view(UserAdmin) + + +@pytest.fixture +async def async_prepare_database() -> AsyncGenerator[None, None]: + async with async_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with async_engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await async_engine.dispose() + + +@pytest.fixture +async def async_client( + async_prepare_database: Any, +) -> AsyncGenerator[AsyncClient, None]: + async with AsyncClient(app=app, base_url="http://testserver") as c: + yield c + + +async def test_async_create_dataclass(async_client: AsyncClient) -> None: + await async_client.post("/admin/user/create", data={"name": "foo", "email": "bar"}) + stmt = select(func.count(User.id)) + async with async_session_maker() as s: + result = await s.execute(stmt) + assert result.scalar_one() == 1 From cc97280a9521f0d6763fbc11defb025d5b6025ba Mon Sep 17 00:00:00 2001 From: Goradi Date: Tue, 19 Nov 2024 17:18:33 +0500 Subject: [PATCH 2/2] fix: insert dataclass support --- sqladmin/_queries.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/sqladmin/_queries.py b/sqladmin/_queries.py index 44ae2b9e..fcd596cf 100644 --- a/sqladmin/_queries.py +++ b/sqladmin/_queries.py @@ -189,8 +189,20 @@ async def _delete_async(self, pk: str, request: Request) -> None: await session.commit() await self.model_view.after_model_delete(obj, request) + def _prepare_insert_dataclass(self, data: dict[str, Any]) -> dict[str, Any]: + try: + init = { + k: v + for k, v in data.items() + if self.model_view.model.__dataclass_fields__[k].init # type: ignore[attr-defined] # caught in except block + } + except AttributeError: + return {} + else: + return init + def _insert_sync(self, data: dict[str, Any], request: Request) -> Any: - obj = self.model_view.model() + obj = self.model_view.model(**self._prepare_insert_dataclass(data)) with self.model_view.session_maker(expire_on_commit=False) as session: anyio.from_thread.run( @@ -205,7 +217,7 @@ def _insert_sync(self, data: dict[str, Any], request: Request) -> Any: return obj async def _insert_async(self, data: dict[str, Any], request: Request) -> Any: - obj = self.model_view.model() + obj = self.model_view.model(**self._prepare_insert_dataclass(data)) async with self.model_view.session_maker(expire_on_commit=False) as session: await self.model_view.on_model_change(data, obj, True, request)