diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f3b7b1e..a0bf535f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,40 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## Version 0.18.0 - 2024-07-01 + +### Added + +* Add `form_rules`, `form_create_rules`, `form_edit_rules` by @aminalaee in https://github.com/aminalaee/sqladmin/pull/779 +* Add more docs for overriding default tempates by @jonocodes in https://github.com/aminalaee/sqladmin/pull/769 + +### Fixed +* Fix edit_form_query documentation example by @lukeclimen in https://github.com/aminalaee/sqladmin/pull/777 + +**Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.17.0...0.18.0 + +## Version 0.17.0 - 2024-05-13 + +### Added + +* Add field description to Create/Edit templates by @ngaranko in https://github.com/aminalaee/sqladmin/pull/722 +* Add edit_form_query method by @lukeclimen in https://github.com/aminalaee/sqladmin/pull/745 +* Validate page and pageSize query parameters by @BhuwanPandey in https://github.com/aminalaee/sqladmin/pull/752 + +### Fixed + +* Hide save and add another button from edit.html if can_create is False by @MaximZemskov in https://github.com/aminalaee/sqladmin/pull/742 +* Fix list page sort symbol by @aminalaee in https://github.com/aminalaee/sqladmin/pull/744 +* Move template files from `templates` to `templates/sqladmin` by @hasansezertasan in https://github.com/aminalaee/sqladmin/pull/748 +* Fix `form_args` default by @aminalaee in https://github.com/aminalaee/sqladmin/pull/756 +* Fix getting column python type by @aminalaee in https://github.com/aminalaee/sqladmin/pull/757 +* Fix File and Image fields checkbox and input by @aminalaee in https://github.com/aminalaee/sqladmin/pull/761 +* Switch relationship loading to selectionload by @aminalaee in https://github.com/aminalaee/sqladmin/pull/758 +* Fix DELETE call query params by @aminalaee in https://github.com/aminalaee/sqladmin/pull/763 + +**Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.16.1...0.17.0 + ## Version 0.16.1 - 2024-02-20 ### Fixed @@ -13,12 +47,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix sort by model attribute in https://github.com/aminalaee/sqladmin/pull/713 * Fix Category not respecting is_visible and is_accessible in https://github.com/aminalaee/sqladmin/pull/698 -## New Contributors -* @kostyaten made their first contribution in https://github.com/aminalaee/sqladmin/pull/677 -* @EnotShow made their first contribution in https://github.com/aminalaee/sqladmin/pull/703 -* @jonocodes made their first contribution in https://github.com/aminalaee/sqladmin/pull/707 -* @Neverfan1 made their first contribution in https://github.com/aminalaee/sqladmin/pull/698 - **Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.16.0...0.16.1 ## Version 0.16.0 - 2023-11-14 @@ -47,10 +75,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added * Add customized sort query signature (#624) by @YarLikviD in https://github.com/aminalaee/sqladmin/pull/625 -## New Contributors -* @Toshakins made their first contribution in https://github.com/aminalaee/sqladmin/pull/626 -* @YarLikviD made their first contribution in https://github.com/aminalaee/sqladmin/pull/625 - **Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.15.0...0.15.1 ## Version 0.15.0 - 2023-09-19 @@ -334,10 +358,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix missing browser tab title by @cuamckuu in https://github.com/aminalaee/sqladmin/pull/229 * Remove sourceMappingURL in JS files by @aminalaee in https://github.com/aminalaee/sqladmin/pull/231 -### New Contributors -* @ischaojie made their first contribution in https://github.com/aminalaee/sqladmin/pull/214 -* @cuamckuu made their first contribution in https://github.com/aminalaee/sqladmin/pull/222 - **Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.1.11...0.1.12 ## Version 0.1.11 - 2022-06-23 @@ -371,10 +391,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix form fields order when specifying columns by @okapies in https://github.com/aminalaee/sqladmin/pull/184 * Fix ModelConverter when `impl` is not callable by @aminalaee in https://github.com/aminalaee/sqladmin/pull/186 -### New Contributors -* @pgrimaud made their first contribution in https://github.com/aminalaee/sqladmin/pull/161 -* @okapies made their first contribution in https://github.com/aminalaee/sqladmin/pull/183 - **Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.1.9...0.1.10 ## Version 0.1.9 - 2022-05-27 @@ -391,10 +407,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Function signature typing, and renames by @dwreeves in https://github.com/aminalaee/sqladmin/pull/116 * Fix SQLModel UUID type by @aminalaee in https://github.com/aminalaee/sqladmin/pull/158 -### New Contributors -* @skarrok made their first contribution in https://github.com/aminalaee/sqladmin/pull/140 -* @colin99d made their first contribution in https://github.com/aminalaee/sqladmin/pull/150 - **Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.1.8...0.1.9 ## Version 0.1.8 - 2022-04-19 @@ -412,10 +424,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix get_model_attr with column labels by @aminalaee in https://github.com/aminalaee/sqladmin/pull/128 * Delay call to `self.get_converter` to use `form_overrides` by @lovetoburnswhen in https://github.com/aminalaee/sqladmin/pull/129 -### New Contributors -* @tr11 made their first contribution in https://github.com/aminalaee/sqladmin/pull/114 -* @lovetoburnswhen made their first contribution in https://github.com/aminalaee/sqladmin/pull/129 - **Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.1.7...0.1.8 ## Version 0.1.7 - 2022-03-22 @@ -433,10 +441,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Fix PostgreSQL UUID PrimaryKey by @aminalaee in https://github.com/aminalaee/sqladmin/pull/92 * Fix Source Code Link by @baurt in https://github.com/aminalaee/sqladmin/pull/95 -### New Contributors -* @baurt made their first contribution in https://github.com/aminalaee/sqladmin/pull/95 -* @dwreeves made their first contribution in https://github.com/aminalaee/sqladmin/pull/97 - **Full Changelog**: https://github.com/aminalaee/sqladmin/compare/0.1.6...0.1.7 ## Version 0.1.6 - 2022-03-09 diff --git a/docs/api_reference/model_view.md b/docs/api_reference/model_view.md index cdf22f23..a266566a 100644 --- a/docs/api_reference/model_view.md +++ b/docs/api_reference/model_view.md @@ -44,6 +44,10 @@ - form_include_pk - form_ajax_refs - form_converter + - form_edit_query + - form_rules + - form_create_rules + - form_edit_rules - column_type_formatters - list_query - count_query diff --git a/docs/configurations.md b/docs/configurations.md index 5c8c8020..96dacea5 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -200,6 +200,10 @@ The forms are based on `WTForms` package and include the following options: * `form_include_pk`: Control if primary key column should be included in create/edit forms. Default is `False`. * `form_ajax_refs`: Use Ajax with Select2 for loading relationship models async. This is use ful when the related model has a lot of records. * `form_converter`: Allow adding custom converters to support additional column types. +* `form_edit_query`: A method with the signature of `(request) -> stmt` which can customize the edit form data. +* `form_rules`: List of form rules to manage rendering and behaviour of form. +* `form_create_rules`: List of form rules to manage rendering and behaviour of form in create page. +* `form_edit_rules`: List of form rules to manage rendering and behaviour of form in edit page. !!! example @@ -216,6 +220,8 @@ The forms are based on `WTForms` package and include the following options: "order_by": ("id",), } } + form_create_rules = ["name", "password"] + form_edit_rules = ["name"] ``` ## Export options @@ -234,10 +240,10 @@ The export options can be set per model and includes the following options: The template files are built using Jinja2 and can be completely overridden in the configurations. The pages available are: -* `list_template`: Template to use for models list page. Default is `list.html`. -* `create_template`: Template to use for model creation page. Default is `create.html`. -* `details_template`: Template to use for model details page. Default is `details.html`. -* `edit_template`: Template to use for model edit page. Default is `edit.html`. +* `list_template`: Template to use for models list page. Default is `sqladmin/list.html`. +* `create_template`: Template to use for model creation page. Default is `sqladmin/create.html`. +* `details_template`: Template to use for model details page. Default is `sqladmin/details.html`. +* `edit_template`: Template to use for model edit page. Default is `sqladmin/edit.html`. !!! example diff --git a/docs/cookbook/optimize_relationship_loading.md b/docs/cookbook/optimize_relationship_loading.md index 181eb308..6a718a35 100644 --- a/docs/cookbook/optimize_relationship_loading.md +++ b/docs/cookbook/optimize_relationship_loading.md @@ -60,3 +60,21 @@ which should be available in the form. class ParentAdmin(ModelView, model=Parent): form_excluded_columns = [Parent.children] ``` + +### Using `form_edit_query` to customize the edit form data + +If you would like to fully customize the query to populate the edit object form, you may override +the `form_edit_query` function with your own SQLAlchemy query. In the following example, overriding +the default query will allow you to filter relationships to show only related children of the parent. + +```py +class ParentAdmin(ModelView, model=Parent): + def form_edit_query(self, request: Request) -> Select: + parent_id = request.path_params["pk"] + return ( + self._stmt_by_identifier(parent_id) + .join(Child) + .options(contains_eager(Parent.children)) + .filter(Child.parent_id == parent_id) + ) +``` diff --git a/docs/cookbook/using_wysiwyg.md b/docs/cookbook/using_wysiwyg.md index dab4ae7c..4f24a5a2 100644 --- a/docs/cookbook/using_wysiwyg.md +++ b/docs/cookbook/using_wysiwyg.md @@ -12,7 +12,7 @@ class Post(Base): - First create a `templates` directory in your project. - Then add a file `custom_edit.html` there with the following content: ```html title="custom_edit.html" -{% extends "edit.html" %} +{% extends "sqladmin/edit.html" %} {% block tail %} + {% endblock %} + + ``` + ## Customizing Jinja2 environment You can add custom environment options to use it on your custom templates. First set up a project: @@ -90,7 +105,7 @@ Usage in templates: ```python def value_is_filepath(value: Any) -> bool: return isinstance(value, str) and os.path.isfile(value) - + admin.templates.env.globals["value_is_filepath"] = value_is_filepath ``` diff --git a/docs/writing_custom_views.md b/docs/writing_custom_views.md index d5e01832..b85235b0 100644 --- a/docs/writing_custom_views.md +++ b/docs/writing_custom_views.md @@ -89,7 +89,7 @@ Next we update the `report.html` file in the `templates` directory with the foll !!! example ```html - {% extends "layout.html" %} + {% extends "sqladmin/layout.html" %} {% block content %}
diff --git a/mkdocs.yml b/mkdocs.yml index cc071cd0..427ea422 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -27,6 +27,7 @@ nav: - Using a request object: "cookbook/using_request_object.md" - Multiple databases: "cookbook/multiple_databases.md" - Using rich text editor: "cookbook/using_wysiwyg.md" + - Working with Passwords: "cookbook/working_with_passwords.md" - API Reference: - Application: "api_reference/application.md" - ModelView: "api_reference/model_view.md" diff --git a/sqladmin/__init__.py b/sqladmin/__init__.py index ac226526..de850b87 100644 --- a/sqladmin/__init__.py +++ b/sqladmin/__init__.py @@ -1,7 +1,7 @@ from sqladmin.application import Admin, action, expose from sqladmin.models import BaseView, ModelView -__version__ = "0.16.1" +__version__ = "0.18.0" __all__ = [ "Admin", diff --git a/sqladmin/_menu.py b/sqladmin/_menu.py index 579c76a5..7389de66 100644 --- a/sqladmin/_menu.py +++ b/sqladmin/_menu.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, List, Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING from starlette.datastructures import URL from starlette.requests import Request @@ -8,11 +10,11 @@ class ItemMenu: - def __init__(self, name: str, icon: Optional[str] = None) -> None: + def __init__(self, name: str, icon: str | None = None) -> None: self.name = name self.icon = icon - self.parent: Optional["ItemMenu"] = None - self.children: List["ItemMenu"] = [] + self.parent: "ItemMenu" | None = None + self.children: list["ItemMenu"] = [] def add_child(self, item: "ItemMenu") -> None: item.parent = self @@ -27,7 +29,7 @@ def is_accessible(self, request: Request) -> bool: def is_active(self, request: Request) -> bool: return False - def url(self, request: Request) -> Union[str, URL]: + def url(self, request: Request) -> str | URL: return "#" @property @@ -53,9 +55,9 @@ def type_(self) -> str: class ViewMenu(ItemMenu): def __init__( self, - view: Union["BaseView", "ModelView"], + view: "BaseView" | "ModelView", name: str, - icon: Optional[str] = None, + icon: str | None = None, ) -> None: super().__init__(name=name, icon=icon) self.view = view @@ -69,7 +71,7 @@ def is_accessible(self, request: Request) -> bool: def is_active(self, request: Request) -> bool: return self.view.identity == request.path_params.get("identity") - def url(self, request: Request) -> Union[str, URL]: + def url(self, request: Request) -> str | URL: if self.view.is_model: return request.url_for("admin:list", identity=self.view.identity) return request.url_for(f"admin:{self.view.identity}") @@ -85,7 +87,7 @@ def type_(self) -> str: class Menu: def __init__(self) -> None: - self.items: List[ItemMenu] = [] + self.items: list[ItemMenu] = [] def add(self, item: ItemMenu) -> None: # Only works for one-level menu diff --git a/sqladmin/_queries.py b/sqladmin/_queries.py index 3e3e1112..44ae2b9e 100644 --- a/sqladmin/_queries.py +++ b/sqladmin/_queries.py @@ -1,9 +1,11 @@ -from typing import TYPE_CHECKING, Any, Dict, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import anyio from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Session, joinedload +from sqlalchemy.orm import Session, selectinload from sqlalchemy.sql.expression import Select, and_, or_ from starlette.requests import Request @@ -24,7 +26,7 @@ class Query: def __init__(self, model_view: "ModelView") -> None: self.model_view = model_view - def _get_to_many_stmt(self, relation: MODEL_PROPERTY, values: List[Any]) -> Select: + def _get_to_many_stmt(self, relation: MODEL_PROPERTY, values: list[Any]) -> Select: target = relation.mapper.class_ target_pks = get_primary_keys(target) @@ -131,7 +133,7 @@ async def _set_attributes_async( setattr(obj, key, value) return obj - def _update_sync(self, pk: Any, data: Dict[str, Any], request: Request) -> Any: + def _update_sync(self, pk: Any, data: dict[str, Any], request: Request) -> Any: stmt = self.model_view._stmt_by_identifier(pk) with self.model_view.session_maker(expire_on_commit=False) as session: @@ -147,12 +149,12 @@ def _update_sync(self, pk: Any, data: Dict[str, Any], request: Request) -> Any: return obj async def _update_async( - self, pk: Any, data: Dict[str, Any], request: Request + self, pk: Any, data: dict[str, Any], request: Request ) -> Any: stmt = self.model_view._stmt_by_identifier(pk) for relation in self.model_view._form_relations: - stmt = stmt.options(joinedload(relation)) + stmt = stmt.options(selectinload(relation)) async with self.model_view.session_maker(expire_on_commit=False) as session: result = await session.execute(stmt) @@ -187,7 +189,7 @@ async def _delete_async(self, pk: str, request: Request) -> None: await session.commit() await self.model_view.after_model_delete(obj, request) - def _insert_sync(self, data: Dict[str, Any], request: Request) -> Any: + def _insert_sync(self, data: dict[str, Any], request: Request) -> Any: obj = self.model_view.model() with self.model_view.session_maker(expire_on_commit=False) as session: @@ -202,7 +204,7 @@ def _insert_sync(self, data: Dict[str, Any], request: Request) -> Any: ) return obj - async def _insert_async(self, data: Dict[str, Any], request: Request) -> Any: + async def _insert_async(self, data: dict[str, Any], request: Request) -> Any: obj = self.model_view.model() async with self.model_view.session_maker(expire_on_commit=False) as session: diff --git a/sqladmin/ajax.py b/sqladmin/ajax.py index 63f51175..28457810 100644 --- a/sqladmin/ajax.py +++ b/sqladmin/ajax.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Any, Dict, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from sqlalchemy import String, cast, inspect, or_, select @@ -52,13 +54,13 @@ def _process_fields(self) -> list: return remote_fields - def format(self, model: type) -> Dict[str, Any]: + def format(self, model: type) -> dict[str, Any]: if not model: return {} return {"id": str(get_object_identifier(model)), "text": str(model)} - async def get_list(self, term: str, limit: int = DEFAULT_PAGE_SIZE) -> List[Any]: + async def get_list(self, term: str, limit: int = DEFAULT_PAGE_SIZE) -> list[Any]: stmt = select(self.model) # no type casting to string if a ColumnAssociationProxyInstance is given diff --git a/sqladmin/application.py b/sqladmin/application.py index 7f582690..c0660e61 100644 --- a/sqladmin/application.py +++ b/sqladmin/application.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect import io import logging @@ -7,23 +9,18 @@ Any, Awaitable, Callable, - List, - Optional, Sequence, - Tuple, - Type, - Union, cast, no_type_check, ) -from urllib.parse import urljoin +from urllib.parse import parse_qsl, urljoin from jinja2 import ChoiceLoader, FileSystemLoader, PackageLoader from sqlalchemy.engine import Engine from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session, sessionmaker from starlette.applications import Starlette -from starlette.datastructures import URL, FormData, UploadFile +from starlette.datastructures import URL, FormData, MultiDict, UploadFile from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.requests import Request @@ -66,14 +63,14 @@ class BaseAdmin: def __init__( self, app: Starlette, - engine: Optional[ENGINE_TYPE] = None, - session_maker: Optional[sessionmaker] = None, + engine: ENGINE_TYPE | None = None, + session_maker: sessionmaker | None = None, base_url: str = "/admin", title: str = "Admin", - logo_url: Optional[str] = None, + logo_url: str | None = None, templates_dir: str = "templates", - middlewares: Optional[Sequence[Middleware]] = None, - authentication_backend: Optional[AuthenticationBackend] = None, + middlewares: Sequence[Middleware] | None = None, + authentication_backend: AuthenticationBackend | None = None, ) -> None: self.app = app self.engine = engine @@ -100,7 +97,7 @@ def __init__( self.admin = Starlette(middleware=middlewares) self.templates = self.init_templating_engine() - self._views: List[Union[BaseView, ModelView]] = [] + self._views: list[BaseView | ModelView] = [] self._menu = Menu() def init_templating_engine(self) -> Jinja2Templates: @@ -120,7 +117,7 @@ def init_templating_engine(self) -> Jinja2Templates: return templates @property - def views(self) -> List[Union[BaseView, ModelView]]: + def views(self) -> list[BaseView | ModelView]: """Get list of ModelView and BaseView instances lazily. Returns: @@ -136,7 +133,7 @@ def _find_model_view(self, identity: str) -> ModelView: raise HTTPException(status_code=404) - def add_view(self, view: Union[Type[ModelView], Type[BaseView]]) -> None: + def add_view(self, view: type[ModelView] | type[BaseView]) -> None: """Add ModelView or BaseView classes to Admin. This is a shortcut that will handle both `add_model_view` and `add_base_view`. """ @@ -149,10 +146,10 @@ def add_view(self, view: Union[Type[ModelView], Type[BaseView]]) -> None: def _find_decorated_funcs( self, - view: Type[Union[BaseView, ModelView]], - view_instance: Union[BaseView, ModelView], + view: type[BaseView | ModelView], + view_instance: BaseView | ModelView, handle_fn: Callable[ - [MethodType, Type[Union[BaseView, ModelView]], Union[BaseView, ModelView]], + [MethodType, type[BaseView | ModelView], BaseView | ModelView], None, ], ) -> None: @@ -164,8 +161,8 @@ def _find_decorated_funcs( def _handle_action_decorated_func( self, func: MethodType, - view: Type[Union[BaseView, ModelView]], - view_instance: Union[BaseView, ModelView], + view: type[BaseView | ModelView], + view_instance: BaseView | ModelView, ) -> None: if hasattr(func, "_action"): view_instance = cast(ModelView, view_instance) @@ -194,8 +191,8 @@ def _handle_action_decorated_func( def _handle_expose_decorated_func( self, func: MethodType, - view: Type[Union[BaseView, ModelView]], - view_instance: Union[BaseView, ModelView], + view: type[BaseView | ModelView], + view_instance: BaseView | ModelView, ) -> None: if hasattr(func, "_exposed"): self.admin.add_route( @@ -208,7 +205,7 @@ def _handle_expose_decorated_func( view.identity = getattr(func, "_identity") - def add_model_view(self, view: Type[ModelView]) -> None: + def add_model_view(self, view: type[ModelView]) -> None: """Add ModelView to the Admin. ???+ usage @@ -237,7 +234,7 @@ class UserAdmin(ModelView, model=User): self._views.append(view_instance) self._build_menu(view_instance) - def add_base_view(self, view: Type[BaseView]) -> None: + def add_base_view(self, view: type[BaseView]) -> None: """Add BaseView to the Admin. ???+ usage @@ -265,7 +262,7 @@ async def test_page(self, request: Request): self._views.append(view_instance) self._build_menu(view_instance) - def _build_menu(self, view: Union[ModelView, BaseView]) -> None: + def _build_menu(self, view: ModelView | BaseView) -> None: if view.category: menu = CategoryMenu(name=view.category) menu.add_child(ViewMenu(view=view, name=view.name, icon=view.icon)) @@ -338,15 +335,15 @@ class UserAdmin(ModelView, model=User): def __init__( self, app: Starlette, - engine: Optional[ENGINE_TYPE] = None, - session_maker: Optional[Union[sessionmaker, "async_sessionmaker"]] = None, + engine: ENGINE_TYPE | None = None, + session_maker: sessionmaker | "async_sessionmaker" | None = None, base_url: str = "/admin", title: str = "Admin", - logo_url: Optional[str] = None, - middlewares: Optional[Sequence[Middleware]] = None, + logo_url: str | None = None, + middlewares: Sequence[Middleware] | None = None, debug: bool = False, templates_dir: str = "templates", - authentication_backend: Optional[AuthenticationBackend] = None, + authentication_backend: AuthenticationBackend | None = None, ) -> None: """ Args: @@ -374,14 +371,14 @@ def __init__( async def http_exception( request: Request, exc: Exception - ) -> Union[Response, Awaitable[Response]]: + ) -> Response | Awaitable[Response]: assert isinstance(exc, HTTPException) context = { "status_code": exc.status_code, "message": exc.detail, } return await self.templates.TemplateResponse( - request, "error.html", context, status_code=exc.status_code + request, "sqladmin/error.html", context, status_code=exc.status_code ) routes = [ @@ -428,7 +425,7 @@ async def http_exception( async def index(self, request: Request) -> Response: """Index route which can be overridden to create dashboards.""" - return await self.templates.TemplateResponse(request, "index.html") + return await self.templates.TemplateResponse(request, "sqladmin/index.html") @login_required async def list(self, request: Request) -> Response: @@ -440,6 +437,14 @@ async def list(self, request: Request) -> Response: pagination = await model_view.list(request) pagination.add_pagination_urls(request.url) + if ( + pagination.page * pagination.page_size + > pagination.count + pagination.page_size + ): + raise HTTPException( + status_code=400, detail="Invalid page or pageSize parameter" + ) + context = {"model_view": model_view, "pagination": pagination} return await self.templates.TemplateResponse( request, model_view.list_template, context @@ -485,7 +490,11 @@ async def delete(self, request: Request) -> Response: await model_view.delete_model(request, pk) - return Response(content=str(request.url_for("admin:list", identity=identity))) + referer_url = URL(request.headers.get("referer", "")) + referer_params = MultiDict(parse_qsl(referer_url.query)) + url = URL(str(request.url_for("admin:list", identity=identity))) + url = url.include_query_params(**referer_params) + return Response(content=str(url)) @login_required async def create(self, request: Request) -> Response: @@ -497,6 +506,7 @@ async def create(self, request: Request) -> Response: model_view = self._find_model_view(identity) Form = await model_view.scaffold_form() + model_view._validate_form_class(model_view._form_create_rules, Form) form_data = await self._handle_form_data(request) form = Form(form_data) @@ -542,11 +552,12 @@ async def edit(self, request: Request) -> Response: identity = request.path_params["identity"] model_view = self._find_model_view(identity) - model = await model_view.get_object_for_edit(request.path_params["pk"]) + model = await model_view.get_object_for_edit(request) if not model: raise HTTPException(status_code=404) Form = await model_view.scaffold_form() + model_view._validate_form_class(model_view._form_edit_rules, Form) context = { "obj": model, "model_view": model_view, @@ -609,13 +620,13 @@ async def login(self, request: Request) -> Response: context = {} if request.method == "GET": - return await self.templates.TemplateResponse(request, "login.html") + return await self.templates.TemplateResponse(request, "sqladmin/login.html") ok = await self.authentication_backend.login(request) if not ok: context["error"] = "Invalid credentials." return await self.templates.TemplateResponse( - request, "login.html", context, status_code=400 + request, "sqladmin/login.html", context, status_code=400 ) return RedirectResponse(request.url_for("admin:index"), status_code=302) @@ -648,7 +659,7 @@ async def ajax_lookup(self, request: Request) -> Response: def get_save_redirect_url( self, request: Request, form: FormData, model_view: ModelView, obj: Any - ) -> Union[str, URL]: + ) -> str | URL: """ Get the redirect URL after a save action which is triggered from create/edit page. @@ -673,7 +684,7 @@ async def _handle_form_data(self, request: Request, obj: Any = None) -> FormData """ form = await request.form() - form_data: List[Tuple[str, Union[str, UploadFile]]] = [] + form_data: list[tuple[str, str | UploadFile]] = [] for key, value in form.multi_items(): if not isinstance(value, UploadFile): form_data.append((key, value)) @@ -714,8 +725,8 @@ def _denormalize_wtform_data(self, form_data: dict, obj: Any) -> dict: def expose( path: str, *, - methods: List[str] = ["GET"], - identity: Optional[str] = None, + methods: list[str] = ["GET"], + identity: str | None = None, include_in_schema: bool = True, ) -> Callable[..., Any]: """Expose View with information.""" @@ -734,8 +745,8 @@ def wrap(func): def action( name: str, - label: Optional[str] = None, - confirmation_message: Optional[str] = None, + label: str | None = None, + confirmation_message: str | None = None, *, include_in_schema: bool = True, add_in_detail: bool = True, diff --git a/sqladmin/authentication.py b/sqladmin/authentication.py index 50443bec..14723cb1 100644 --- a/sqladmin/authentication.py +++ b/sqladmin/authentication.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import functools import inspect -from typing import Any, Callable, Union +from typing import Any, Callable from starlette.middleware import Middleware from starlette.requests import Request @@ -33,7 +35,7 @@ async def logout(self, request: Request) -> bool: """ raise NotImplementedError() - async def authenticate(self, request: Request) -> Union[Response, bool]: + async def authenticate(self, request: Request) -> Response | bool: """Implement authenticate logic here. This method will be called for each incoming request to validate the authentication. diff --git a/sqladmin/fields.py b/sqladmin/fields.py index 3a2e8f29..2334039d 100644 --- a/sqladmin/fields.py +++ b/sqladmin/fields.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import json import operator -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Generator from wtforms import Form, ValidationError, fields, widgets @@ -43,7 +45,7 @@ class IntervalField(fields.StringField): A text field which stores a `datetime.timedelta` object. """ - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: if not valuelist: return @@ -57,19 +59,19 @@ def process_formdata(self, valuelist: List[str]) -> None: class SelectField(fields.SelectField): def __init__( self, - label: Optional[str] = None, - validators: Optional[list] = None, + label: str | None = None, + validators: list | None = None, coerce: type = str, - choices: Optional[Union[list, Callable]] = None, + choices: list | Callable | None = None, allow_blank: bool = False, - blank_text: Optional[str] = None, + blank_text: str | None = None, **kwargs: Any, ) -> None: super().__init__(label, validators, coerce, choices, **kwargs) self.allow_blank = allow_blank self.blank_text = blank_text or " " - def iter_choices(self) -> Generator[Tuple[str, str, bool, Dict], None, None]: + def iter_choices(self) -> Generator[tuple[str, str, bool, dict], None, None]: choices = self.choices or [] if self.allow_blank: @@ -86,7 +88,7 @@ def iter_choices(self) -> Generator[Tuple[str, str, bool, Dict], None, None]: {}, ) - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: if valuelist: if valuelist[0] == "__None": self.data = None @@ -112,7 +114,7 @@ def _value(self) -> str: else: return "{}" - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: if valuelist: value = valuelist[0] @@ -132,10 +134,10 @@ class QuerySelectField(fields.SelectFieldBase): def __init__( self, - data: Optional[list] = None, - label: Optional[str] = None, - validators: Optional[list] = None, - get_label: Optional[Union[Callable, str]] = None, + data: list | None = None, + label: str | None = None, + validators: list | None = None, + get_label: Callable | str | None = None, allow_blank: bool = False, blank_text: str = "", **kwargs: Any, @@ -153,11 +155,11 @@ def __init__( self.allow_blank = allow_blank self.blank_text = blank_text - self._data: Optional[tuple] - self._formdata: Optional[Union[str, List[str]]] + self._data: tuple | None + self._formdata: str | list[str] | None @property - def data(self) -> Optional[tuple]: + def data(self) -> tuple | None: if self._formdata is not None: for pk, _ in self._select_data: if pk == self._formdata: @@ -170,7 +172,7 @@ def data(self, data: tuple) -> None: self._data = data self._formdata = None - def iter_choices(self) -> Generator[Tuple[str, str, bool, Dict], None, None]: + def iter_choices(self) -> Generator[tuple[str, str, bool, dict], None, None]: if self.allow_blank: yield ("__None", self.blank_text, self.data is None, {}) @@ -186,7 +188,7 @@ def iter_choices(self) -> Generator[Tuple[str, str, bool, Dict], None, None]: for pk, label in self._select_data: yield (pk, self.get_label(label), str(pk) == primary_key, {}) - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: if valuelist: if self.allow_blank and valuelist[0] == "__None": self.data = None @@ -220,9 +222,9 @@ class QuerySelectMultipleField(QuerySelectField): def __init__( self, - data: Optional[list] = None, - label: Optional[str] = None, - validators: Optional[list] = None, + data: list | None = None, + label: str | None = None, + validators: list | None = None, default: Any = None, **kwargs: Any, ) -> None: @@ -238,11 +240,11 @@ def __init__( "allow_blank=True does not do anything for QuerySelectMultipleField." ) self._invalid_formdata = False - self._formdata: Optional[List[str]] = None - self._data: Optional[tuple] = None + self._formdata: list[str] | None = None + self._data: tuple | None = None @property - def data(self) -> Optional[tuple]: + def data(self) -> tuple | None: formdata = self._formdata if formdata is not None: data = [] @@ -262,7 +264,7 @@ def data(self, data: tuple) -> None: self._data = data self._formdata = None - def iter_choices(self) -> Generator[Tuple[str, Any, bool, Dict], None, None]: + def iter_choices(self) -> Generator[tuple[str, Any, bool, dict], None, None]: if self.data is not None: primary_keys = ( self.data @@ -272,7 +274,7 @@ def iter_choices(self) -> Generator[Tuple[str, Any, bool, Dict], None, None]: for pk, label in self._select_data: yield (pk, self.get_label(label), pk in primary_keys, {}) - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: self._formdata = list(set(valuelist)) def pre_validate(self, form: Form) -> None: @@ -297,8 +299,8 @@ class AjaxSelectField(fields.SelectFieldBase): def __init__( self, loader: QueryAjaxModelLoader, - label: Optional[str] = None, - validators: Optional[list] = None, + label: str | None = None, + validators: list | None = None, allow_blank: bool = False, **kwargs: Any, ) -> None: @@ -339,9 +341,9 @@ class AjaxSelectMultipleField(fields.SelectFieldBase): def __init__( self, loader: QueryAjaxModelLoader, - label: Optional[str] = None, - validators: Optional[list] = None, - default: Optional[list] = None, + label: str | None = None, + validators: list | None = None, + default: list | None = None, allow_blank: bool = False, **kwargs: Any, ) -> None: @@ -349,7 +351,7 @@ def __init__( self.loader = loader self.allow_blank = allow_blank default = default or [] - self._formdata: Set[Any] = set() + self._formdata: set[Any] = set() super().__init__(label, validators, default=default, **kwargs) @@ -382,7 +384,7 @@ def pre_validate(self, form: Form) -> None: def process_formdata(self, valuelist: list) -> None: self.data = valuelist - def process_data(self, value: Optional[list]) -> None: + def process_data(self, value: list | None) -> None: self.data = value or [] diff --git a/sqladmin/forms.py b/sqladmin/forms.py index dff5b370..1328e516 100644 --- a/sqladmin/forms.py +++ b/sqladmin/forms.py @@ -1,18 +1,15 @@ """ The converters are from Flask-Admin project. """ +from __future__ import annotations + import enum import inspect import sys from typing import ( Any, Callable, - Dict, - List, - Optional, Sequence, - Tuple, - Type, TypeVar, Union, no_type_check, @@ -85,7 +82,7 @@ def __call__( self, model: type, prop: MODEL_PROPERTY, - kwargs: Dict[str, Any], + kwargs: dict[str, Any], ) -> UnboundField: ... # pragma: no cover @@ -107,7 +104,7 @@ def _inner(func: T_CC) -> T_CC: class ModelConverterBase: - _converters: Dict[str, ConverterCallable] = {} + _converters: dict[str, ConverterCallable] = {} def __init__(self) -> None: super().__init__() @@ -128,12 +125,12 @@ async def _prepare_kwargs( self, prop: MODEL_PROPERTY, session_maker: sessionmaker, - field_args: Dict[str, Any], - field_widget_args: Dict[str, Any], + field_args: dict[str, Any], + field_widget_args: dict[str, Any], form_include_pk: bool, - label: Optional[str] = None, - loader: Optional[QueryAjaxModelLoader] = None, - ) -> Optional[Dict[str, Any]]: + label: str | None = None, + loader: QueryAjaxModelLoader | None = None, + ) -> dict[str, Any] | None: if not isinstance(prop, (RelationshipProperty, ColumnProperty)): return None @@ -169,7 +166,7 @@ def _prepare_column( if (column.primary_key or column.foreign_keys) and not form_include_pk: return None - default = getattr(column, "default", None) + default = getattr(column, "default", None) or kwargs.get("default") if default is not None: # Only actually change default if it has an attribute named @@ -205,7 +202,7 @@ async def _prepare_relationship( prop: RelationshipProperty, kwargs: dict, session_maker: sessionmaker, - loader: Optional[QueryAjaxModelLoader] = None, + loader: QueryAjaxModelLoader | None = None, ) -> dict: nullable = True for pair in prop.local_remote_pairs: @@ -225,7 +222,7 @@ async def _prepare_select_options( self, prop: RelationshipProperty, session_maker: sessionmaker, - ) -> List[Tuple[str, Any]]: + ) -> list[tuple[str, Any]]: target_model = prop.mapper.class_ stmt = select(target_model) @@ -283,13 +280,13 @@ async def convert( model: type, prop: MODEL_PROPERTY, session_maker: sessionmaker, - field_args: Dict[str, Any], - field_widget_args: Dict[str, Any], + field_args: dict[str, Any], + field_widget_args: dict[str, Any], form_include_pk: bool, - label: Optional[str] = None, - override: Optional[Type[Field]] = None, - form_ajax_refs: Dict[str, QueryAjaxModelLoader] = {}, - ) -> Optional[UnboundField]: + label: str | None = None, + override: type[Field] | None = None, + form_ajax_refs: dict[str, QueryAjaxModelLoader] = {}, + ) -> UnboundField: loader = form_ajax_refs.get(prop.key) kwargs = await self._prepare_kwargs( prop=prop, @@ -329,7 +326,7 @@ def _get_identifier_value(self, o: Any) -> str: class ModelConverter(ModelConverterBase): @staticmethod - def _string_common(prop: ColumnProperty) -> List[Validator]: + def _string_common(prop: ColumnProperty) -> list[Validator]: li = [] column: Column = prop.columns[0] if isinstance(column.type.length, int) and column.type.length: @@ -338,7 +335,7 @@ def _string_common(prop: ColumnProperty) -> List[Validator]: @converts("String", "CHAR") # includes Unicode def conv_string( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: extra_validators = self._string_common(prop) kwargs.setdefault("validators", []) @@ -347,7 +344,7 @@ def conv_string( @converts("Text", "LargeBinary", "Binary") # includes UnicodeText def conv_text( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) extra_validators = self._string_common(prop) @@ -356,7 +353,7 @@ def conv_text( @converts("Boolean", "dialects.mssql.base.BIT") def conv_boolean( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: if not prop.columns[0].nullable: kwargs.setdefault("render_kw", {}) @@ -370,25 +367,25 @@ def conv_boolean( @converts("Date") def conv_date( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: return DateField(**kwargs) @converts("Time") def conv_time( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: return TimeField(**kwargs) @converts("DateTime") def conv_datetime( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: return DateTimeField(**kwargs) @converts("Enum") def conv_enum( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: available_choices = [(e, e) for e in prop.columns[0].type.enums] accepted_values = [choice[0] for choice in available_choices] @@ -408,13 +405,13 @@ def conv_enum( @converts("Integer") # includes BigInteger and SmallInteger def conv_integer( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: return IntegerField(**kwargs) @converts("Numeric") # includes DECIMAL, Float/FLOAT, REAL, and DOUBLE def conv_decimal( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: # override default decimal places limit, use database defaults instead kwargs.setdefault("places", None) @@ -422,13 +419,13 @@ def conv_decimal( @converts("JSON", "JSONB") def conv_json( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: return JSONField(**kwargs) @converts("Interval") def conv_interval( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs["render_kw"]["placeholder"] = "Like: 1 day 1:25:33.652" return IntervalField(**kwargs) @@ -439,7 +436,7 @@ def conv_interval( "sqlalchemy_utils.types.ip_address.IPAddressType", ) def conv_ip_address( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append(validators.IPAddress(ipv4=True, ipv6=True)) @@ -450,7 +447,7 @@ def conv_ip_address( "sqlalchemy.dialects.postgresql.types.MACADDR", ) def conv_mac_address( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append(validators.MacAddress()) @@ -463,7 +460,7 @@ def conv_mac_address( "sqlalchemy_utils.types.uuid.UUIDType", ) def conv_uuid( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append(validators.UUID()) @@ -473,13 +470,13 @@ def conv_uuid( "sqlalchemy.dialects.postgresql.base.ARRAY", "sqlalchemy.sql.sqltypes.ARRAY" ) def conv_ARRAY( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: return Select2TagsField(**kwargs) @converts("sqlalchemy_utils.types.email.EmailType") def conv_email( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append(validators.Email()) @@ -487,7 +484,7 @@ def conv_email( @converts("sqlalchemy_utils.types.url.URLType") def conv_url( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append(validators.URL()) @@ -495,7 +492,7 @@ def conv_url( @converts("sqlalchemy_utils.types.currency.CurrencyType") def conv_currency( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append(CurrencyValidator()) @@ -503,7 +500,7 @@ def conv_currency( @converts("sqlalchemy_utils.types.timezone.TimezoneType") def conv_timezone( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append( @@ -513,7 +510,7 @@ def conv_timezone( @converts("sqlalchemy_utils.types.phone_number.PhoneNumberType") def conv_phone_number( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append(PhoneNumberValidator()) @@ -521,7 +518,7 @@ def conv_phone_number( @converts("sqlalchemy_utils.types.color.ColorType") def conv_color( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs.setdefault("validators", []) kwargs["validators"].append(ColorValidator()) @@ -530,7 +527,7 @@ def conv_color( @converts("sqlalchemy_utils.types.choice.ChoiceType") @no_type_check def convert_choice_type( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: available_choices = [] column = prop.columns[0] @@ -559,32 +556,32 @@ def convert_choice_type( @converts("fastapi_storages.integrations.sqlalchemy.FileType") def conv_file( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: return FileField(**kwargs) @converts("fastapi_storages.integrations.sqlalchemy.ImageType") def conv_image( - self, model: type, prop: ColumnProperty, kwargs: Dict[str, Any] + self, model: type, prop: ColumnProperty, kwargs: dict[str, Any] ) -> UnboundField: return FileField(**kwargs) @converts("ONETOONE") def conv_one_to_one( - self, model: type, prop: RelationshipProperty, kwargs: Dict[str, Any] + self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any] ) -> UnboundField: kwargs["allow_blank"] = True return QuerySelectField(**kwargs) @converts("MANYTOONE") def conv_many_to_one( - self, model: type, prop: RelationshipProperty, kwargs: Dict[str, Any] + self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any] ) -> UnboundField: return QuerySelectField(**kwargs) @converts("MANYTOMANY", "ONETOMANY") def conv_many_to_many( - self, model: type, prop: RelationshipProperty, kwargs: Dict[str, Any] + self, model: type, prop: RelationshipProperty, kwargs: dict[str, Any] ) -> UnboundField: return QuerySelectMultipleField(**kwargs) @@ -592,17 +589,17 @@ def conv_many_to_many( async def get_model_form( model: type, session_maker: sessionmaker, - only: Optional[Sequence[str]] = None, - exclude: Optional[Sequence[str]] = None, - column_labels: Optional[Dict[str, str]] = None, - form_args: Optional[Dict[str, Dict[str, Any]]] = None, - form_widget_args: Optional[Dict[str, Dict[str, Any]]] = None, - form_class: Type[Form] = Form, - form_overrides: Optional[Dict[str, Type[Field]]] = None, - form_ajax_refs: Optional[Dict[str, QueryAjaxModelLoader]] = None, + only: Sequence[str] | None = None, + exclude: Sequence[str] | None = None, + column_labels: dict[str, str] | None = None, + form_args: dict[str, dict[str, Any]] | None = None, + form_widget_args: dict[str, dict[str, Any]] | None = None, + form_class: type[Form] = Form, + form_overrides: dict[str, type[Field]] | None = None, + form_ajax_refs: dict[str, QueryAjaxModelLoader] | None = None, form_include_pk: bool = False, - form_converter: Type[ModelConverterBase] = ModelConverter, -) -> Type[Form]: + form_converter: type[ModelConverterBase] = ModelConverter, +) -> type[Form]: type_name = model.__name__ + "Form" converter = form_converter() mapper = sqlalchemy_inspect(model) diff --git a/sqladmin/helpers.py b/sqladmin/helpers.py index 0d8cc518..99cc08c5 100644 --- a/sqladmin/helpers.py +++ b/sqladmin/helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import csv import enum import os @@ -9,11 +11,7 @@ Any, AsyncGenerator, Callable, - Dict, Generator, - List, - Optional, - Tuple, TypeVar, ) @@ -136,11 +134,11 @@ class Writer(ABC): """https://docs.python.org/3/library/csv.html#writer-objects""" @abstractmethod - def writerow(self, row: List[str]) -> None: + def writerow(self, row: list[str]) -> None: pass # pragma: no cover @abstractmethod - def writerows(self, rows: List[List[str]]) -> None: + def writerows(self, rows: list[list[str]]) -> None: pass # pragma: no cover @property @@ -174,7 +172,7 @@ def stream_to_csv( return callback(writer) # type: ignore -def get_primary_keys(model: Any) -> Tuple[Column, ...]: +def get_primary_keys(model: Any) -> tuple[Column, ...]: return tuple(inspect(model).mapper.primary_key) @@ -191,7 +189,7 @@ def get_object_identifier(obj: Any) -> Any: return ";".join(str(v).replace("\\", "\\\\").replace(";", r"\;") for v in values) -def _object_identifier_parts(id_string: str, model: type) -> Tuple[str, ...]: +def _object_identifier_parts(id_string: str, model: type) -> tuple[str, ...]: pks = get_primary_keys(model) if len(pks) == 1: # Only one primary key so no special processing @@ -241,10 +239,13 @@ def get_direction(prop: MODEL_PROPERTY) -> str: def get_column_python_type(column: Column) -> type: try: - if hasattr(column.type, "impl"): - return column.type.impl.python_type return column.type.python_type except NotImplementedError: + if hasattr(column.type, "impl"): + try: + return column.type.impl.python_type + except NotImplementedError: + ... return str @@ -252,7 +253,7 @@ def is_relationship(prop: MODEL_PROPERTY) -> bool: return isinstance(prop, RelationshipProperty) -def parse_interval(value: str) -> Optional[timedelta]: +def parse_interval(value: str) -> timedelta | None: match = ( standard_duration_re.match(value) or iso8601_duration_re.match(value) @@ -262,7 +263,7 @@ def parse_interval(value: str) -> Optional[timedelta]: if not match: return None - kw: Dict[str, Any] = match.groupdict() + kw: dict[str, Any] = match.groupdict() sign = -1 if kw.pop("sign", "+") == "-" else 1 if kw.get("microseconds"): kw["microseconds"] = kw["microseconds"].ljust(6, "0") diff --git a/sqladmin/models.py b/sqladmin/models.py index de3e1642..42ab23a3 100644 --- a/sqladmin/models.py +++ b/sqladmin/models.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import time +import warnings from enum import Enum from typing import ( TYPE_CHECKING, @@ -20,14 +23,16 @@ import anyio from sqlalchemy import Column, String, asc, cast, desc, func, inspect, or_ from sqlalchemy.exc import NoInspectionAvailable -from sqlalchemy.orm import joinedload, sessionmaker +from sqlalchemy.orm import selectinload, sessionmaker from sqlalchemy.orm.exc import DetachedInstanceError from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.expression import Select, select from starlette.datastructures import URL +from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import StreamingResponse from wtforms import Field, Form +from wtforms.fields.core import UnboundField from sqladmin._queries import Query from sqladmin._types import MODEL_ATTR @@ -414,17 +419,17 @@ def formatter(model, attribute): """ # Templates - list_template: ClassVar[str] = "list.html" - """List view template. Default is `list.html`.""" + list_template: ClassVar[str] = "sqladmin/list.html" + """List view template. Default is `sqladmin/list.html`.""" - create_template: ClassVar[str] = "create.html" - """Create view template. Default is `create.html`.""" + create_template: ClassVar[str] = "sqladmin/create.html" + """Create view template. Default is `sqladmin/create.html`.""" - details_template: ClassVar[str] = "details.html" - """Details view template. Default is `details.html`.""" + details_template: ClassVar[str] = "sqladmin/details.html" + """Details view template. Default is `sqladmin/details.html`.""" - edit_template: ClassVar[str] = "edit.html" - """Edit view template. Default is `edit.html`.""" + edit_template: ClassVar[str] = "sqladmin/edit.html" + """Edit view template. Default is `sqladmin/edit.html`.""" # Export column_export_list: ClassVar[List[MODEL_ATTR]] = [] @@ -597,6 +602,28 @@ class UserAdmin(ModelAdmin, model=User): ``` """ + form_rules: ClassVar[list[str]] = [] + """List of rendering rules for model creation and edit form. + This property changes default form rendering behavior and to rearrange + order of rendered fields, add some text between fields, group them, etc. + If not set, will use default Flask-Admin form rendering logic. + + ???+ example + ```python + class UserAdmin(ModelAdmin, model=User): + form_rules = [ + "first_name", + "last_name", + ] + ``` + """ + + form_create_rules: ClassVar[list[str]] = [] + """Customized rules for the create form. Cannot be specified with `form_rules`.""" + + form_edit_rules: ClassVar[list[str]] = [] + """Customized rules for the edit form. Cannot be specified with `form_rules`.""" + # General options column_labels: ClassVar[Dict[MODEL_ATTR, str]] = {} """A mapping of column labels, used to map column names to new names. @@ -684,6 +711,8 @@ def __init__(self) -> None: model_admin=self, name=name, options=options ) + self._refresh_form_rules_cache() + self._custom_actions_in_list: Dict[str, str] = {} self._custom_actions_in_detail: Dict[str, str] = {} self._custom_actions_confirmation: Dict[str, str] = {} @@ -746,6 +775,17 @@ def _default_formatter(self, value: Any) -> Any: return value + def validate_page_number(self, number: Union[str, None], default: int) -> int: + if not number: + return default + + try: + return int(number) + except ValueError: + raise HTTPException( + status_code=400, detail="Invalid page or pageSize parameter" + ) + async def count(self, request: Request, stmt: Optional[Select] = None) -> int: if stmt is None: stmt = self.count_query(request) @@ -753,14 +793,14 @@ async def count(self, request: Request, stmt: Optional[Select] = None) -> int: return rows[0] async def list(self, request: Request) -> Pagination: - page = int(request.query_params.get("page", 1)) - page_size = int(request.query_params.get("pageSize", 0)) + page = self.validate_page_number(request.query_params.get("page"), 1) + page_size = self.validate_page_number(request.query_params.get("pageSize"), 0) page_size = min(page_size or self.page_size, max(self.page_size_options)) search = request.query_params.get("search", None) stmt = self.list_query(request) for relation in self._list_relations: - stmt = stmt.options(joinedload(relation)) + stmt = stmt.options(selectinload(relation)) stmt = self.sort_query(stmt, request) @@ -790,7 +830,7 @@ async def get_model_objects( stmt = self.list_query(request).limit(limit) for relation in self._list_relations: - stmt = stmt.options(joinedload(relation)) + stmt = stmt.options(selectinload(relation)) rows = await self._run_query(stmt) return rows @@ -803,16 +843,12 @@ async def get_object_for_details(self, value: Any) -> Any: stmt = self._stmt_by_identifier(value) for relation in self._details_relations: - stmt = stmt.options(joinedload(relation)) + stmt = stmt.options(selectinload(relation)) return await self._get_object_by_pk(stmt) - async def get_object_for_edit(self, value: Any) -> Any: - stmt = self._stmt_by_identifier(value) - - for relation in self._form_relations: - stmt = stmt.options(joinedload(relation)) - + async def get_object_for_edit(self, request: Request) -> Any: + stmt = self.form_edit_query(request) return await self._get_object_by_pk(stmt) async def get_object_for_delete(self, value: Any) -> Any: @@ -1045,6 +1081,25 @@ def list_query(self, request: Request) -> Select: return select(self.model) + def edit_form_query(self, request: Request) -> Select: + msg = ( + "Overriding 'edit_form_query' is deprecated. Use 'form_edit_query' instead." + ) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + return self.form_edit_query(request) + + def form_edit_query(self, request: Request) -> Select: + """ + The SQLAlchemy select expression used for the edit form page which can be + customized. By default it will select the object by primary key(s) without any + additional filters. + """ + + stmt = self._stmt_by_identifier(request.path_params["pk"]) + for relation in self._form_relations: + stmt = stmt.options(selectinload(relation)) + return stmt + def count_query(self, request: Request) -> Select: """ The SQLAlchemy select expression used for the count query @@ -1123,3 +1178,26 @@ async def generate(writer: Writer) -> AsyncGenerator[Any, None]: media_type="text/csv", headers={"Content-Disposition": f"attachment;filename={filename}"}, ) + + def _refresh_form_rules_cache(self) -> None: + if self.form_rules: + self._form_create_rules = self.form_rules + self._form_edit_rules = self.form_rules + else: + self._form_create_rules = self.form_create_rules + self._form_edit_rules = self.form_edit_rules + + def _validate_form_class(self, ruleset: List[Any], form_class: Type[Form]) -> None: + form_fields = [] + for name, obj in form_class.__dict__.items(): + if isinstance(obj, UnboundField): + form_fields.append(name) + + missing_fields = [] + if ruleset: + for field_name in form_fields: + if field_name not in ruleset: + missing_fields.append(field_name) + + for field_name in missing_fields: + delattr(form_class, field_name) diff --git a/sqladmin/pagination.py b/sqladmin/pagination.py index 36754ef8..cb6ba562 100644 --- a/sqladmin/pagination.py +++ b/sqladmin/pagination.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from dataclasses import dataclass, field -from typing import Any, List +from typing import Any from starlette.datastructures import URL @@ -12,11 +14,11 @@ class PageControl: @dataclass class Pagination: - rows: List[Any] + rows: list[Any] page: int page_size: int count: int - page_controls: List[PageControl] = field(default_factory=list) + page_controls: list[PageControl] = field(default_factory=list) max_page_controls: int = 7 @property diff --git a/sqladmin/templates/index.html b/sqladmin/templates/index.html deleted file mode 100644 index 3f0c8612..00000000 --- a/sqladmin/templates/index.html +++ /dev/null @@ -1,3 +0,0 @@ -{% extends "layout.html" %} -{% block content %} -{% endblock %} \ No newline at end of file diff --git a/sqladmin/templates/_macros.html b/sqladmin/templates/sqladmin/_macros.html similarity index 69% rename from sqladmin/templates/_macros.html rename to sqladmin/templates/sqladmin/_macros.html index acd29481..51e09b03 100644 --- a/sqladmin/templates/_macros.html +++ b/sqladmin/templates/sqladmin/_macros.html @@ -53,3 +53,36 @@ {% endfor %}
{% endmacro %} + +{% macro render_field(field, kwargs={}) %} +
+ {{ field.label(class_="form-label col-sm-2 col-form-label") }} +
+ {% if field.errors %} + {{ field(class_="form-control is-invalid") }} + {% else %} + {{ field() }} + {% endif %} + {% for error in field.errors %} +
{{ error }}
+ {% endfor %} + {% if field.description %} + {{ field.description }} + {% endif %} +
+
+{% endmacro %} + +{% macro render_form_fields(form, form_opts=None) %} +{% if form.hidden_tag is defined %} +{{ form.hidden_tag() }} +{% else %} +{% for f in form if f.widget.input_type == 'hidden' %} +{{ f }} +{% endfor %} +{% endif %} + +{% for f in form if f.widget.input_type != 'hidden' %} +{{ render_field(f, kwargs) }} +{% endfor %} +{% endmacro %} \ No newline at end of file diff --git a/sqladmin/templates/base.html b/sqladmin/templates/sqladmin/base.html similarity index 100% rename from sqladmin/templates/base.html rename to sqladmin/templates/sqladmin/base.html diff --git a/sqladmin/templates/create.html b/sqladmin/templates/sqladmin/create.html similarity index 63% rename from sqladmin/templates/create.html rename to sqladmin/templates/sqladmin/create.html index e05521d7..e5557979 100644 --- a/sqladmin/templates/create.html +++ b/sqladmin/templates/sqladmin/create.html @@ -1,4 +1,5 @@ -{% extends "layout.html" %} +{% extends "sqladmin/layout.html" %} +{% from 'sqladmin/_macros.html' import render_form_fields %} {% block content %}
@@ -14,24 +15,7 @@

New {{ model_view.name }}

{% endif %}
- {% for field in form %} -
- {{ field.label(class_="form-label col-sm-2 col-form-label") }} -
- {% if field.errors %} - {{ field(class_="form-control is-invalid") }} - {% else %} - {{ field() }} - {% endif %} - {% for error in field.errors %} -
{{ error }}
- {% endfor %} - {% if field.description %} - {{ field.description }} - {% endif %} -
-
- {% endfor %} + {{ render_form_fields(form, form_opts=form_opts) }}
diff --git a/sqladmin/templates/details.html b/sqladmin/templates/sqladmin/details.html similarity index 96% rename from sqladmin/templates/details.html rename to sqladmin/templates/sqladmin/details.html index 34ae2c36..db7a1c90 100644 --- a/sqladmin/templates/details.html +++ b/sqladmin/templates/sqladmin/details.html @@ -1,4 +1,4 @@ -{% extends "layout.html" %} +{% extends "sqladmin/layout.html" %} {% block content %}
@@ -87,14 +87,14 @@

{% if model_view.can_delete %} -{% include 'modals/delete.html' %} +{% include 'sqladmin/modals/delete.html' %} {% endif %} {% for custom_action in model_view._custom_actions_in_detail %} {% if custom_action in model_view._custom_actions_confirmation %} {% with confirmation_message = model_view._custom_actions_confirmation[custom_action], custom_action=custom_action, url=model_view._url_for_action(request, custom_action) + '?pks=' + (get_object_identifier(model) | string) %} -{% include 'modals/details_action_confirmation.html' %} +{% include 'sqladmin/modals/details_action_confirmation.html' %} {% endwith %} {% endif %} {% endfor %} diff --git a/sqladmin/templates/edit.html b/sqladmin/templates/sqladmin/edit.html similarity index 67% rename from sqladmin/templates/edit.html rename to sqladmin/templates/sqladmin/edit.html index ae51125c..c84507d5 100644 --- a/sqladmin/templates/edit.html +++ b/sqladmin/templates/sqladmin/edit.html @@ -1,4 +1,5 @@ -{% extends "layout.html" %} +{% extends "sqladmin/layout.html" %} +{% from 'sqladmin/_macros.html' import render_form_fields %} {% block content %}
@@ -14,24 +15,7 @@

Edit {{ model_view.name }}

{% endif %}
- {% for field in form %} -
- {{ field.label(class_="form-label col-sm-2 col-form-label") }} -
- {% if field.errors %} - {{ field(class_="form-control is-invalid") }} - {% else %} - {{ field() }} - {% endif %} - {% for error in field.errors %} -
{{ error }}
- {% endfor %} - {% if field.description %} - {{ field.description }} - {% endif %} -
-
- {% endfor %} + {{ render_form_fields(form, form_opts=form_opts) }}
@@ -43,11 +27,13 @@

Edit {{ model_view.name }}

+ {% if model_view.can_create %} {% if model_view.save_as %} {% else %} {% endif %} + {% endif %}
diff --git a/sqladmin/templates/error.html b/sqladmin/templates/sqladmin/error.html similarity index 87% rename from sqladmin/templates/error.html rename to sqladmin/templates/sqladmin/error.html index c3dbdb5a..27f71e54 100644 --- a/sqladmin/templates/error.html +++ b/sqladmin/templates/sqladmin/error.html @@ -1,4 +1,4 @@ -{% extends "layout.html" %} +{% extends "sqladmin/layout.html" %} {% block body %}
diff --git a/sqladmin/templates/sqladmin/index.html b/sqladmin/templates/sqladmin/index.html new file mode 100644 index 00000000..26104d47 --- /dev/null +++ b/sqladmin/templates/sqladmin/index.html @@ -0,0 +1,3 @@ +{% extends "sqladmin/layout.html" %} +{% block content %} +{% endblock %} \ No newline at end of file diff --git a/sqladmin/templates/layout.html b/sqladmin/templates/sqladmin/layout.html similarity index 95% rename from sqladmin/templates/layout.html rename to sqladmin/templates/sqladmin/layout.html index 5a414f98..4e9fc6a1 100644 --- a/sqladmin/templates/layout.html +++ b/sqladmin/templates/sqladmin/layout.html @@ -1,5 +1,5 @@ -{% extends "base.html" %} -{% from '_macros.html' import display_menu %} +{% extends "sqladmin/base.html" %} +{% from 'sqladmin/_macros.html' import display_menu %} {% block body %}
-{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/sqladmin/templates/login.html b/sqladmin/templates/sqladmin/login.html similarity index 97% rename from sqladmin/templates/login.html rename to sqladmin/templates/sqladmin/login.html index 1afc62bf..dfb3e5a6 100644 --- a/sqladmin/templates/login.html +++ b/sqladmin/templates/sqladmin/login.html @@ -1,4 +1,4 @@ -{% extends "base.html" %} +{% extends "sqladmin/base.html" %} {% block body %}
diff --git a/sqladmin/templates/modals/delete.html b/sqladmin/templates/sqladmin/modals/delete.html similarity index 100% rename from sqladmin/templates/modals/delete.html rename to sqladmin/templates/sqladmin/modals/delete.html diff --git a/sqladmin/templates/modals/details_action_confirmation.html b/sqladmin/templates/sqladmin/modals/details_action_confirmation.html similarity index 100% rename from sqladmin/templates/modals/details_action_confirmation.html rename to sqladmin/templates/sqladmin/modals/details_action_confirmation.html diff --git a/sqladmin/templates/modals/list_action_confirmation.html b/sqladmin/templates/sqladmin/modals/list_action_confirmation.html similarity index 100% rename from sqladmin/templates/modals/list_action_confirmation.html rename to sqladmin/templates/sqladmin/modals/list_action_confirmation.html diff --git a/sqladmin/templating.py b/sqladmin/templating.py index d48d546d..3e175350 100644 --- a/sqladmin/templating.py +++ b/sqladmin/templating.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, Mapping, Optional +from __future__ import annotations + +from typing import Any, Mapping import jinja2 from starlette.background import BackgroundTask @@ -13,11 +15,11 @@ def __init__( self, template: jinja2.Template, content: str, - context: Dict, + context: dict, status_code: int = 200, - headers: Optional[Mapping[str, str]] = None, - media_type: Optional[str] = None, - background: Optional[BackgroundTask] = None, + headers: Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, ): self.template = template self.context = context @@ -42,7 +44,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class Jinja2Templates: def __init__(self, directory: str) -> None: @jinja2.pass_context - def url_for(context: Dict, __name: str, **path_params: Any) -> URL: + def url_for(context: dict, __name: str, **path_params: Any) -> URL: request = context["request"] return request.url_for(__name, **path_params) @@ -54,7 +56,7 @@ async def TemplateResponse( self, request: Request, name: str, - context: Optional[Dict] = None, + context: dict | None = None, status_code: int = 200, ) -> _TemplateResponse: context = context or {} diff --git a/sqladmin/widgets.py b/sqladmin/widgets.py index ba313f1e..28f38b43 100644 --- a/sqladmin/widgets.py +++ b/sqladmin/widgets.py @@ -75,15 +75,23 @@ class FileInputWidget(widgets.FileInput): """ def __call__(self, field: Field, **kwargs: Any) -> str: - file_input = super().__call__(field, **kwargs) - checkbox_id = f"{field.id}_checkbox" - checkbox_label = Markup( - f'' - ) - checkbox_input = Markup( - f'' # noqa: E501 - ) - checkbox = Markup( - f'
{checkbox_input}{checkbox_label}
' - ) - return file_input + checkbox + if not field.flags.required: + checkbox_id = f"{field.id}_checkbox" + checkbox_label = Markup( + f'' + ) + checkbox_input = Markup( + f'' # noqa: E501 + ) + checkbox = Markup( + f'
{checkbox_input}{checkbox_label}
' + ) + else: + checkbox = Markup() + + if field.data: + current_value = Markup(f"

Currently: {field.data}

") + field.flags.required = False + return current_value + checkbox + super().__call__(field, **kwargs) + else: + return super().__call__(field, **kwargs) diff --git a/tests/templates/custom.html b/tests/templates/custom.html index 8b155f59..075b9467 100644 --- a/tests/templates/custom.html +++ b/tests/templates/custom.html @@ -1,4 +1,4 @@ -{% extends 'layout.html' %} +{% extends "sqladmin/layout.html" %} {% block content %}

Here I'm going to display some data.

{% endblock %} \ No newline at end of file diff --git a/tests/test_application.py b/tests/test_application.py index 0d9d6964..a7bf9710 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,3 +1,6 @@ +from typing import Generator + +import pytest from sqlalchemy import Column, Integer, String from sqlalchemy.orm import declarative_base from starlette.applications import Starlette @@ -27,6 +30,13 @@ class User(Base): name = Column(String(32), default="SQLAdmin") +@pytest.fixture(autouse=True) +def prepare_database() -> Generator[None, None, None]: + Base.metadata.create_all(engine) + yield + Base.metadata.drop_all(engine) + + def test_application_title() -> None: app = Starlette() Admin(app=app, engine=engine) @@ -153,3 +163,21 @@ class DataModelAdmin(ModelView, model=DataModel): assert admin._denormalize_wtform_data({"data_": "abcdef"}, datamodel) == { "data": "abcdef" } + + +def test_validate_page_and_page_size(): + app = Starlette() + admin = Admin(app=app, engine=engine) + + class UserAdmin(ModelView, model=User): + ... + + admin.add_view(UserAdmin) + + client = TestClient(app) + + response = client.get("/admin/user/list?page=10000") + assert response.status_code == 400 + + response = client.get("/admin/user/list?page=aaaa") + assert response.status_code == 400 diff --git a/tests/test_file_upload.py b/tests/test_file_upload.py index db379026..314192d8 100644 --- a/tests/test_file_upload.py +++ b/tests/test_file_upload.py @@ -1,21 +1,18 @@ -from typing import Any, AsyncGenerator +from typing import Any, Generator import pytest from fastapi_storages import FileSystemStorage, StorageFile -from fastapi_storages.integrations.sqlalchemy import FileType, ImageType -from httpx import AsyncClient +from fastapi_storages.integrations.sqlalchemy import FileType from sqlalchemy import Column, Integer, select -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import declarative_base, sessionmaker from starlette.applications import Starlette +from starlette.testclient import TestClient from sqladmin import Admin, ModelView -from tests.common import async_engine as engine - -pytestmark = pytest.mark.anyio +from tests.common import sync_engine as engine Base = declarative_base() # type: Any -session_maker = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) +session_maker = sessionmaker(bind=engine) app = Starlette() admin = Admin(app=app, engine=engine) @@ -25,24 +22,20 @@ class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) - file = Column(FileType(FileSystemStorage(".uploads"))) - image = Column(ImageType(FileSystemStorage(".uploads"))) + file = Column(FileType(FileSystemStorage(".uploads")), nullable=False) + optional_file = Column(FileType(FileSystemStorage(".uploads")), nullable=True) @pytest.fixture -async def prepare_database() -> AsyncGenerator[None, None]: - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) +def prepare_database() -> Generator[None, None, None]: + Base.metadata.create_all(engine) yield - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.drop_all) - - await engine.dispose() + Base.metadata.drop_all(engine) @pytest.fixture -async def client(prepare_database: Any) -> AsyncGenerator[AsyncClient, None]: - async with AsyncClient(app=app, base_url="http://testserver") as c: +def client(prepare_database: Any) -> Generator[TestClient, None, None]: + with TestClient(app=app, base_url="http://testserver") as c: yield c @@ -53,66 +46,95 @@ class UserAdmin(ModelView, model=User): admin.add_view(UserAdmin) -async def _query_user() -> Any: +def _query_user() -> User: stmt = select(User).limit(1) - async with session_maker() as s: - result = await s.execute(stmt) - return result.scalar_one() + with session_maker() as s: + return s.scalar(stmt) -async def test_create_form_fields(client: AsyncClient) -> None: - response = await client.get("/admin/user/create") +def test_create_form_fields(client: TestClient) -> None: + response = client.get("/admin/user/create") assert response.status_code == 200 assert ( - '' + '' in response.text ) - assert 'Clear' + '' # noqa: E501 in response.text ) -async def test_create_form_post(client: AsyncClient) -> None: - files = {"file": ("upload.txt", b"abc")} - response = await client.post("/admin/user/create", files=files) +def test_create_form_post(client: TestClient) -> None: + files = { + "file": ("file.txt", b"abc"), + "optional_file": ("optional_file.txt", b"cdb"), + } + client.post("/admin/user/create", files=files) - user = await _query_user() + user = _query_user() - assert response.status_code == 302 assert isinstance(user.file, StorageFile) is True - assert user.file.name == "upload.txt" - assert user.file.path == ".uploads/upload.txt" + assert user.file.name == "file.txt" + assert user.file.path == ".uploads/file.txt" assert user.file.open().read() == b"abc" + assert user.optional_file.name == "optional_file.txt" + assert user.optional_file.path == ".uploads/optional_file.txt" + assert user.optional_file.open().read() == b"cdb" + + +def test_create_form_update(client: TestClient) -> None: + files = { + "file": ("file.txt", b"abc"), + "optional_file": ("optional_file.txt", b"cdb"), + } + client.post("/admin/user/create", files=files) + + files = { + "file": ("new_file.txt", b"xyz"), + "optional_file": ("new_optional_file.txt", b"zyx"), + } + client.post("/admin/user/edit/1", files=files) + + user = _query_user() + assert user.file.name == "new_file.txt" + assert user.file.path == ".uploads/new_file.txt" + assert user.file.open().read() == b"xyz" + assert user.optional_file.name == "new_optional_file.txt" + assert user.optional_file.path == ".uploads/new_optional_file.txt" + assert user.optional_file.open().read() == b"zyx" + + files = {"file": ("file.txt", b"abc")} + client.post( + "/admin/user/edit/1", files=files, data={"optional_file_checkbox": "true"} + ) + user = _query_user() + assert user.file.name == "file.txt" + assert user.file.path == ".uploads/file.txt" + assert user.file.open().read() == b"abc" + assert user.optional_file is None -async def test_create_form_update(client: AsyncClient) -> None: - files = {"file": ("upload.txt", b"abc")} - response = await client.post("/admin/user/create", files=files) - - user = await _query_user() - - files = {"file": ("new_upload.txt", b"abc")} - response = await client.post("/admin/user/edit/1", files=files) - - user = await _query_user() - assert response.status_code == 302 - assert user.file.name == "new_upload.txt" - assert user.file.path == ".uploads/new_upload.txt" - - files = {"file": ("empty.txt", b"")} - response = await client.post("/admin/user/edit/1", files=files) - user = await _query_user() - assert user.file.name == "new_upload.txt" - assert user.file.path == ".uploads/new_upload.txt" +def test_get_form_update(client: TestClient) -> None: + files = { + "file": ("file.txt", b"abc"), + "optional_file": ("optional_file.txt", b"cdb"), + } + client.post("/admin/user/create", files=files) + response = client.get("/admin/user/edit/1") - files = {"file": ("new_upload.txt", b"abc")} - response = await client.post( - "/admin/user/edit/1", files=files, data={"file_checkbox": True} + assert response.text.count("Currently:") == 2 + assert 'Clear' + in response.text ) - user = await _query_user() - assert user.file is None + files = {"file": ("file.txt", b"abc")} + client.post("/admin/user/edit/1", files=files) + response = client.get("/admin/user/edit/1") + + assert response.text.count("Currently:") == 1 + assert response.text.count("checkbox") == 0 diff --git a/tests/test_forms/test_forms.py b/tests/test_forms/test_forms.py index 6c0795ec..a3faf48e 100644 --- a/tests/test_forms/test_forms.py +++ b/tests/test_forms/test_forms.py @@ -150,11 +150,12 @@ async def test_model_form_exclude() -> None: async def test_model_form_form_args() -> None: - form_args = {"name": {"label": "User Name"}} + form_args = {"name": {"label": "User Name"}, "number": {"default": 100}} Form = await get_model_form( model=User, session_maker=session_maker, form_args=form_args ) assert Form()._fields["name"].label.text == "User Name" + assert Form()._fields["number"].default == 100 async def test_model_form_column_label() -> None: diff --git a/tests/test_models.py b/tests/test_models.py index 81065dba..71a9f98f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,7 +5,7 @@ from markupsafe import Markup from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, select from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import declarative_base, relationship, sessionmaker +from sqlalchemy.orm import contains_eager, declarative_base, relationship, sessionmaker from sqlalchemy.sql.expression import Select from starlette.applications import Starlette from starlette.requests import Request @@ -52,6 +52,7 @@ class Address(Base): __tablename__ = "addresses" id = Column(Integer, primary_key=True) + name = Column(String) user_id = Column(Integer, ForeignKey("users.id")) user = relationship("User", back_populates="addresses") @@ -381,13 +382,47 @@ def list_query(self, request: Request) -> Select: assert len(await view.get_model_objects(request)) == 1 +async def test_form_edit_query() -> None: + session = session_maker() + batman = User(id=123, name="batman") + batcave = Address(user=batman, name="bat cave") + wayne_manor = Address(user=batman, name="wayne manor") + session.add(batman) + session.add(batcave) + session.add(wayne_manor) + session.commit() + + class UserAdmin(ModelView, model=User): + async_engine = False + session_maker = session_maker + + def form_edit_query(self, request: Request) -> Select: + return ( + select(self.model) + .join(Address) + .options(contains_eager(User.addresses)) + .filter(Address.name == "bat cave") + ) + + view = UserAdmin() + + class RequestObject(object): + pass + + request_object = RequestObject() + request_object.path_params = {"pk": 123} + user_obj = await view.get_object_for_edit(request_object) + + assert len(user_obj.addresses) == 1 + + def test_model_columns_all_keyword() -> None: class AddressAdmin(ModelView, model=Address): column_list = "__all__" column_details_list = "__all__" - assert AddressAdmin().get_list_columns() == ["user", "id", "user_id"] - assert AddressAdmin().get_details_columns() == ["user", "id", "user_id"] + assert AddressAdmin().get_list_columns() == ["user", "id", "name", "user_id"] + assert AddressAdmin().get_details_columns() == ["user", "id", "name", "user_id"] async def test_get_prop_value() -> None: diff --git a/tests/test_models_action.py b/tests/test_models_action.py index 1d12e3c9..fb7c2ff8 100644 --- a/tests/test_models_action.py +++ b/tests/test_models_action.py @@ -35,8 +35,14 @@ async def _action_stub(self, request: Request) -> Response: pks = request.query_params.get("pks", "") obj_strs: List[str] = [] + + class RequestObject(object): + pass + for pk in pks.split(","): - obj = await self.get_object_for_edit(pk) + request_object = RequestObject() + request_object.path_params = {"pk": pk} + obj = await self.get_object_for_edit(request_object) obj_strs.append(repr(obj)) diff --git a/tests/test_views/test_view_sync.py b/tests/test_views/test_view_sync.py index 46274507..b1dc7d69 100644 --- a/tests/test_views/test_view_sync.py +++ b/tests/test_views/test_view_sync.py @@ -150,6 +150,8 @@ class UserAdmin(ModelView, model=User): User.profile_formattable: lambda m, a: f"Formatted {m.profile_formattable}", } save_as = True + form_create_rules = ["name", "email", "addresses", "profile", "birthdate", "status"] + form_edit_rules = ["name", "email", "addresses", "profile", "birthdate"] class AddressAdmin(ModelView, model=Address): @@ -442,6 +444,7 @@ def test_create_endpoint_get_form(client: TestClient) -> None: '' in response.text ) + assert '' not in response.text + ) response = client.get("/admin/address/edit/1")