diff --git a/docs/api_reference/model_view.md b/docs/api_reference/model_view.md index cdf22f23..87cd0b4e 100644 --- a/docs/api_reference/model_view.md +++ b/docs/api_reference/model_view.md @@ -49,6 +49,7 @@ - count_query - search_query - sort_query + - edit_form_query - on_model_change - after_model_change - on_model_delete diff --git a/docs/configurations.md b/docs/configurations.md index cfceffaf..10394ff0 100644 --- a/docs/configurations.md +++ b/docs/configurations.md @@ -200,6 +200,7 @@ The forms are based on `WTForms` package and include the following options: * `form_include_pk`: Control if primary key column should be included in create/edit forms. Default is `False`. * `form_ajax_refs`: Use Ajax with Select2 for loading relationship models async. This is use ful when the related model has a lot of records. * `form_converter`: Allow adding custom converters to support additional column types. +* `edit_form_query`: A method with the signature of `(request) -> stmt` which can customize the edit form data. !!! example diff --git a/docs/cookbook/optimize_relationship_loading.md b/docs/cookbook/optimize_relationship_loading.md index 181eb308..f7eb8fbb 100644 --- a/docs/cookbook/optimize_relationship_loading.md +++ b/docs/cookbook/optimize_relationship_loading.md @@ -60,3 +60,22 @@ which should be available in the form. class ParentAdmin(ModelView, model=Parent): form_excluded_columns = [Parent.children] ``` + +### Using `edit_form_query` to customize the edit form data + +If you would like to fully customize the query to populate the edit object form, you may override +the `edit_form_query` function with your own SQLAlchemy query. In the following example, overriding +the default query will allow you to filter relationships to show only related children of the parent. + +```py +class ParentAdmin(ModelView, model=Parent): + def edit_form_query(self, request: Request) -> Select: + parent_id = request.path_params["pk"] + return ( + super() + .edit_form_query(request) + .join(Child) + .options(contains_eager(Parent.children)) + .filter(Child.parent_id == parent_id) + ) +``` diff --git a/sqladmin/application.py b/sqladmin/application.py index d7b1d555..e6da646d 100644 --- a/sqladmin/application.py +++ b/sqladmin/application.py @@ -542,7 +542,7 @@ async def edit(self, request: Request) -> Response: identity = request.path_params["identity"] model_view = self._find_model_view(identity) - model = await model_view.get_object_for_edit(request.path_params["pk"]) + model = await model_view.get_object_for_edit(request) if not model: raise HTTPException(status_code=404) diff --git a/sqladmin/forms.py b/sqladmin/forms.py index dff5b370..c6c62f0a 100644 --- a/sqladmin/forms.py +++ b/sqladmin/forms.py @@ -169,7 +169,7 @@ def _prepare_column( if (column.primary_key or column.foreign_keys) and not form_include_pk: return None - default = getattr(column, "default", None) + default = getattr(column, "default", None) or kwargs.get("default") if default is not None: # Only actually change default if it has an attribute named diff --git a/tests/test_forms/test_forms.py b/tests/test_forms/test_forms.py index 6c0795ec..a3faf48e 100644 --- a/tests/test_forms/test_forms.py +++ b/tests/test_forms/test_forms.py @@ -150,11 +150,12 @@ async def test_model_form_exclude() -> None: async def test_model_form_form_args() -> None: - form_args = {"name": {"label": "User Name"}} + form_args = {"name": {"label": "User Name"}, "number": {"default": 100}} Form = await get_model_form( model=User, session_maker=session_maker, form_args=form_args ) assert Form()._fields["name"].label.text == "User Name" + assert Form()._fields["number"].default == 100 async def test_model_form_column_label() -> None: diff --git a/tests/test_models.py b/tests/test_models.py index 81065dba..c9e64318 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,7 +5,7 @@ from markupsafe import Markup from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, select from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import declarative_base, relationship, sessionmaker +from sqlalchemy.orm import contains_eager, declarative_base, relationship, sessionmaker from sqlalchemy.sql.expression import Select from starlette.applications import Starlette from starlette.requests import Request @@ -52,6 +52,7 @@ class Address(Base): __tablename__ = "addresses" id = Column(Integer, primary_key=True) + name = Column(String) user_id = Column(Integer, ForeignKey("users.id")) user = relationship("User", back_populates="addresses") @@ -381,13 +382,47 @@ def list_query(self, request: Request) -> Select: assert len(await view.get_model_objects(request)) == 1 +async def test_edit_form_query() -> None: + session = session_maker() + batman = User(id=123, name="batman") + batcave = Address(user=batman, name="bat cave") + wayne_manor = Address(user=batman, name="wayne manor") + session.add(batman) + session.add(batcave) + session.add(wayne_manor) + session.commit() + + class UserAdmin(ModelView, model=User): + async_engine = False + session_maker = session_maker + + def edit_form_query(self, request: Request) -> Select: + return ( + select(self.model) + .join(Address) + .options(contains_eager(User.addresses)) + .filter(Address.name == "bat cave") + ) + + view = UserAdmin() + + class RequestObject(object): + pass + + request_object = RequestObject() + request_object.path_params = {"pk": 123} + user_obj = await view.get_object_for_edit(request_object) + + assert len(user_obj.addresses) == 1 + + def test_model_columns_all_keyword() -> None: class AddressAdmin(ModelView, model=Address): column_list = "__all__" column_details_list = "__all__" - assert AddressAdmin().get_list_columns() == ["user", "id", "user_id"] - assert AddressAdmin().get_details_columns() == ["user", "id", "user_id"] + assert AddressAdmin().get_list_columns() == ["user", "id", "name", "user_id"] + assert AddressAdmin().get_details_columns() == ["user", "id", "name", "user_id"] async def test_get_prop_value() -> None: diff --git a/tests/test_models_action.py b/tests/test_models_action.py index 1d12e3c9..fb7c2ff8 100644 --- a/tests/test_models_action.py +++ b/tests/test_models_action.py @@ -35,8 +35,14 @@ async def _action_stub(self, request: Request) -> Response: pks = request.query_params.get("pks", "") obj_strs: List[str] = [] + + class RequestObject(object): + pass + for pk in pks.split(","): - obj = await self.get_object_for_edit(pk) + request_object = RequestObject() + request_object.path_params = {"pk": pk} + obj = await self.get_object_for_edit(request_object) obj_strs.append(repr(obj))