From c28aa361253d94ea5ee96975bf5d5e24e14f389e Mon Sep 17 00:00:00 2001 From: Luke Climenhage <56767690+lukeclimen@users.noreply.github.com> Date: Mon, 6 May 2024 07:25:14 -0400 Subject: [PATCH 1/2] Add `edit_form_query` method (#745) Co-authored-by: Amin Alaee --- docs/api_reference/model_view.md | 1 + docs/configurations.md | 1 + .../cookbook/optimize_relationship_loading.md | 19 +++++++++ sqladmin/application.py | 2 +- sqladmin/models.py | 20 ++++++--- tests/test_models.py | 41 +++++++++++++++++-- tests/test_models_action.py | 8 +++- 7 files changed, 81 insertions(+), 11 deletions(-) diff --git a/docs/api_reference/model_view.md b/docs/api_reference/model_view.md index cdf22f23..87cd0b4e 100644 --- a/docs/api_reference/model_view.md +++ b/docs/api_reference/model_view.md @@ -49,6 +49,7 @@ - count_query - search_query - sort_query + - edit_form_query - on_model_change - after_model_change - on_model_delete diff --git a/docs/configurations.md b/docs/configurations.md index cfceffaf..10394ff0 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -200,6 +200,7 @@ The forms are based on `WTForms` package and include the following options: * `form_include_pk`: Control if primary key column should be included in create/edit forms. Default is `False`. * `form_ajax_refs`: Use Ajax with Select2 for loading relationship models async. This is use ful when the related model has a lot of records. * `form_converter`: Allow adding custom converters to support additional column types. +* `edit_form_query`: A method with the signature of `(request) -> stmt` which can customize the edit form data. !!! example diff --git a/docs/cookbook/optimize_relationship_loading.md b/docs/cookbook/optimize_relationship_loading.md index 181eb308..f7eb8fbb 100644 --- a/docs/cookbook/optimize_relationship_loading.md +++ b/docs/cookbook/optimize_relationship_loading.md @@ -60,3 +60,22 @@ which should be available in the form. class ParentAdmin(ModelView, model=Parent): form_excluded_columns = [Parent.children] ``` + +### Using `edit_form_query` to customize the edit form data + +If you would like to fully customize the query to populate the edit object form, you may override +the `edit_form_query` function with your own SQLAlchemy query. In the following example, overriding +the default query will allow you to filter relationships to show only related children of the parent. + +```py +class ParentAdmin(ModelView, model=Parent): + def edit_form_query(self, request: Request) -> Select: + parent_id = request.path_params["pk"] + return ( + super() + .edit_form_query(request) + .join(Child) + .options(contains_eager(Parent.children)) + .filter(Child.parent_id == parent_id) + ) +``` diff --git a/sqladmin/application.py b/sqladmin/application.py index d7b1d555..e6da646d 100644 --- a/sqladmin/application.py +++ b/sqladmin/application.py @@ -542,7 +542,7 @@ async def edit(self, request: Request) -> Response: identity = request.path_params["identity"] model_view = self._find_model_view(identity) - model = await model_view.get_object_for_edit(request.path_params["pk"]) + model = await model_view.get_object_for_edit(request) if not model: raise HTTPException(status_code=404) diff --git a/sqladmin/models.py b/sqladmin/models.py index 61ee2a51..cbe88df9 100644 --- a/sqladmin/models.py +++ b/sqladmin/models.py @@ -807,12 +807,8 @@ async def get_object_for_details(self, value: Any) -> Any: return await self._get_object_by_pk(stmt) - async def get_object_for_edit(self, value: Any) -> Any: - stmt = self._stmt_by_identifier(value) - - for relation in self._form_relations: - stmt = stmt.options(joinedload(relation)) - + async def get_object_for_edit(self, request: Request) -> Any: + stmt = self.edit_form_query(request) return await self._get_object_by_pk(stmt) async def get_object_for_delete(self, value: Any) -> Any: @@ -1045,6 +1041,18 @@ def list_query(self, request: Request) -> Select: return select(self.model) + def edit_form_query(self, request: Request) -> Select: + """ + The SQLAlchemy select expression used for the edit form page which can be + customized. By default it will select the object by primary key(s) without any + additional filters. + """ + + stmt = self._stmt_by_identifier(request.path_params["pk"]) + for relation in self._form_relations: + stmt = stmt.options(joinedload(relation)) + return stmt + def count_query(self, request: Request) -> Select: """ The SQLAlchemy select expression used for the count query diff --git a/tests/test_models.py b/tests/test_models.py index 81065dba..c9e64318 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,7 +5,7 @@ from markupsafe import Markup from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, select from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import declarative_base, relationship, sessionmaker +from sqlalchemy.orm import contains_eager, declarative_base, relationship, sessionmaker from sqlalchemy.sql.expression import Select from starlette.applications import Starlette from starlette.requests import Request @@ -52,6 +52,7 @@ class Address(Base): __tablename__ = "addresses" id = Column(Integer, primary_key=True) + name = Column(String) user_id = Column(Integer, ForeignKey("users.id")) user = relationship("User", back_populates="addresses") @@ -381,13 +382,47 @@ def list_query(self, request: Request) -> Select: assert len(await view.get_model_objects(request)) == 1 +async def test_edit_form_query() -> None: + session = session_maker() + batman = User(id=123, name="batman") + batcave = Address(user=batman, name="bat cave") + wayne_manor = Address(user=batman, name="wayne manor") + session.add(batman) + session.add(batcave) + session.add(wayne_manor) + session.commit() + + class UserAdmin(ModelView, model=User): + async_engine = False + session_maker = session_maker + + def edit_form_query(self, request: Request) -> Select: + return ( + select(self.model) + .join(Address) + .options(contains_eager(User.addresses)) + .filter(Address.name == "bat cave") + ) + + view = UserAdmin() + + class RequestObject(object): + pass + + request_object = RequestObject() + request_object.path_params = {"pk": 123} + user_obj = await view.get_object_for_edit(request_object) + + assert len(user_obj.addresses) == 1 + + def test_model_columns_all_keyword() -> None: class AddressAdmin(ModelView, model=Address): column_list = "__all__" column_details_list = "__all__" - assert AddressAdmin().get_list_columns() == ["user", "id", "user_id"] - assert AddressAdmin().get_details_columns() == ["user", "id", "user_id"] + assert AddressAdmin().get_list_columns() == ["user", "id", "name", "user_id"] + assert AddressAdmin().get_details_columns() == ["user", "id", "name", "user_id"] async def test_get_prop_value() -> None: diff --git a/tests/test_models_action.py b/tests/test_models_action.py index 1d12e3c9..fb7c2ff8 100644 --- a/tests/test_models_action.py +++ b/tests/test_models_action.py @@ -35,8 +35,14 @@ async def _action_stub(self, request: Request) -> Response: pks = request.query_params.get("pks", "") obj_strs: List[str] = [] + + class RequestObject(object): + pass + for pk in pks.split(","): - obj = await self.get_object_for_edit(pk) + request_object = RequestObject() + request_object.path_params = {"pk": pk} + obj = await self.get_object_for_edit(request_object) obj_strs.append(repr(obj)) From 0dc1e4db69883c84ebc86d4e14b9daf89550874e Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Mon, 6 May 2024 15:08:15 +0200 Subject: [PATCH 2/2] Fix `form_args` default (#756) --- sqladmin/forms.py | 2 +- tests/test_forms/test_forms.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sqladmin/forms.py b/sqladmin/forms.py index dff5b370..c6c62f0a 100644 --- a/sqladmin/forms.py +++ b/sqladmin/forms.py @@ -169,7 +169,7 @@ def _prepare_column( if (column.primary_key or column.foreign_keys) and not form_include_pk: return None - default = getattr(column, "default", None) + default = getattr(column, "default", None) or kwargs.get("default") if default is not None: # Only actually change default if it has an attribute named diff --git a/tests/test_forms/test_forms.py b/tests/test_forms/test_forms.py index 6c0795ec..a3faf48e 100644 --- a/tests/test_forms/test_forms.py +++ b/tests/test_forms/test_forms.py @@ -150,11 +150,12 @@ async def test_model_form_exclude() -> None: async def test_model_form_form_args() -> None: - form_args = {"name": {"label": "User Name"}} + form_args = {"name": {"label": "User Name"}, "number": {"default": 100}} Form = await get_model_form( model=User, session_maker=session_maker, form_args=form_args ) assert Form()._fields["name"].label.text == "User Name" + assert Form()._fields["number"].default == 100 async def test_model_form_column_label() -> None: