Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MappedAsDataclass #857

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions sqladmin/_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,20 @@ async def _delete_async(self, pk: str, request: Request) -> None:
await session.commit()
await self.model_view.after_model_delete(obj, request)

def _prepare_insert_dataclass(self, data: dict[str, Any]) -> dict[str, Any]:
try:
init = {
k: v
for k, v in data.items()
if self.model_view.model.__dataclass_fields__[k].init # type: ignore[attr-defined] # caught in except block
}
except AttributeError:
return {}
else:
return init

def _insert_sync(self, data: dict[str, Any], request: Request) -> Any:
obj = self.model_view.model()
obj = self.model_view.model(**self._prepare_insert_dataclass(data))

with self.model_view.session_maker(expire_on_commit=False) as session:
anyio.from_thread.run(
Expand All @@ -205,7 +217,7 @@ def _insert_sync(self, data: dict[str, Any], request: Request) -> Any:
return obj

async def _insert_async(self, data: dict[str, Any], request: Request) -> Any:
obj = self.model_view.model()
obj = self.model_view.model(**self._prepare_insert_dataclass(data))

async with self.model_view.session_maker(expire_on_commit=False) as session:
await self.model_view.on_model_change(data, obj, True, request)
Expand Down
5 changes: 5 additions & 0 deletions tests/test_sa2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import pytest
from sqlalchemy import __version__ as __sa_version__

if __sa_version__.startswith("1."):
pytest.skip("SQLAlchemy 1.4 does not support this api", allow_module_level=True)
67 changes: 67 additions & 0 deletions tests/test_sa2/test_view_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Any, Generator

import pytest
from sqlalchemy import (
Integer,
String,
func,
select,
)
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
mapped_column,
sessionmaker,
)
from starlette.applications import Starlette
from starlette.testclient import TestClient

from sqladmin import Admin, ModelView
from tests.common import sync_engine

session_maker = sessionmaker(bind=sync_engine)
pytestmark = pytest.mark.anyio


class Base(MappedAsDataclass, DeclarativeBase):
pass


class User(Base):
__tablename__ = "users"

id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False)
name: Mapped[str] = mapped_column(String(length=16), init=True)
email: Mapped[str] = mapped_column(String, unique=True)


class UserAdmin(ModelView, model=User):
column_list = ["name", "email"]
column_labels = {"name": "Name", "email": "Email"}


app = Starlette()
admin = Admin(app=app, engine=sync_engine)
admin.add_model_view(UserAdmin)


@pytest.fixture
def prepare_database() -> Generator[None, None, None]:
Base.metadata.create_all(sync_engine)
yield
Base.metadata.drop_all(sync_engine)


@pytest.fixture
def client(prepare_database: Any) -> Generator[TestClient, None, None]:
with TestClient(app=app, base_url="http://testserver") as c:
yield c


def test_sync_create_dataclass(client: TestClient) -> None:
client.post("/admin/user/create", data={"name": "foo", "email": "bar"})
stmt = select(func.count(User.id))
with session_maker() as s:
result = s.execute(stmt)
assert result.scalar_one() == 1
74 changes: 74 additions & 0 deletions tests/test_sa2/test_views_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from typing import Any, AsyncGenerator

import pytest
from httpx import AsyncClient
from sqlalchemy import Integer, String, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
MappedAsDataclass,
mapped_column,
)
from starlette.applications import Starlette

from sqladmin import Admin
from sqladmin.models import ModelView
from tests.common import async_engine

async_session_maker = async_sessionmaker(
bind=async_engine,
class_=AsyncSession,
expire_on_commit=False,
)

pytestmark = pytest.mark.anyio


class Base(MappedAsDataclass, DeclarativeBase):
pass


class User(Base):
__tablename__ = "users"

id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False)
name: Mapped[str] = mapped_column(String(length=16), init=True)
email: Mapped[str] = mapped_column(String, unique=True)


class UserAdmin(ModelView, model=User):
column_list = ["name", "email"]
column_labels = {"name": "Name", "email": "Email"}


app = Starlette()
async_admin = Admin(app=app, engine=async_engine)
async_admin.add_view(UserAdmin)


@pytest.fixture
async def async_prepare_database() -> AsyncGenerator[None, None]:
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with async_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)

await async_engine.dispose()


@pytest.fixture
async def async_client(
async_prepare_database: Any,
) -> AsyncGenerator[AsyncClient, None]:
async with AsyncClient(app=app, base_url="http://testserver") as c:
yield c


async def test_async_create_dataclass(async_client: AsyncClient) -> None:
await async_client.post("/admin/user/create", data={"name": "foo", "email": "bar"})
stmt = select(func.count(User.id))
async with async_session_maker() as s:
result = await s.execute(stmt)
assert result.scalar_one() == 1
Loading