Skip to content

Commit

Permalink
Add edit_form_query method (#745)
Browse files Browse the repository at this point in the history
Co-authored-by: Amin Alaee <[email protected]>
  • Loading branch information
lukeclimen and aminalaee authored May 6, 2024
1 parent 9ed5414 commit c28aa36
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/api_reference/model_view.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
- count_query
- search_query
- sort_query
- edit_form_query
- on_model_change
- after_model_change
- on_model_delete
Expand Down
1 change: 1 addition & 0 deletions docs/configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ The forms are based on `WTForms` package and include the following options:
* `form_include_pk`: Control if primary key column should be included in create/edit forms. Default is `False`.
* `form_ajax_refs`: Use Ajax with Select2 for loading relationship models async. This is use ful when the related model has a lot of records.
* `form_converter`: Allow adding custom converters to support additional column types.
* `edit_form_query`: A method with the signature of `(request) -> stmt` which can customize the edit form data.

!!! example

Expand Down
19 changes: 19 additions & 0 deletions docs/cookbook/optimize_relationship_loading.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,22 @@ which should be available in the form.
class ParentAdmin(ModelView, model=Parent):
form_excluded_columns = [Parent.children]
```

### Using `edit_form_query` to customize the edit form data

If you would like to fully customize the query to populate the edit object form, you may override
the `edit_form_query` function with your own SQLAlchemy query. In the following example, overriding
the default query will allow you to filter relationships to show only related children of the parent.

```py
class ParentAdmin(ModelView, model=Parent):
def edit_form_query(self, request: Request) -> Select:
parent_id = request.path_params["pk"]
return (
super()
.edit_form_query(request)
.join(Child)
.options(contains_eager(Parent.children))
.filter(Child.parent_id == parent_id)
)
```
2 changes: 1 addition & 1 deletion sqladmin/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ async def edit(self, request: Request) -> Response:
identity = request.path_params["identity"]
model_view = self._find_model_view(identity)

model = await model_view.get_object_for_edit(request.path_params["pk"])
model = await model_view.get_object_for_edit(request)
if not model:
raise HTTPException(status_code=404)

Expand Down
20 changes: 14 additions & 6 deletions sqladmin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,12 +807,8 @@ async def get_object_for_details(self, value: Any) -> Any:

return await self._get_object_by_pk(stmt)

async def get_object_for_edit(self, value: Any) -> Any:
stmt = self._stmt_by_identifier(value)

for relation in self._form_relations:
stmt = stmt.options(joinedload(relation))

async def get_object_for_edit(self, request: Request) -> Any:
stmt = self.edit_form_query(request)
return await self._get_object_by_pk(stmt)

async def get_object_for_delete(self, value: Any) -> Any:
Expand Down Expand Up @@ -1045,6 +1041,18 @@ def list_query(self, request: Request) -> Select:

return select(self.model)

def edit_form_query(self, request: Request) -> Select:
"""
The SQLAlchemy select expression used for the edit form page which can be
customized. By default it will select the object by primary key(s) without any
additional filters.
"""

stmt = self._stmt_by_identifier(request.path_params["pk"])
for relation in self._form_relations:
stmt = stmt.options(joinedload(relation))
return stmt

def count_query(self, request: Request) -> Select:
"""
The SQLAlchemy select expression used for the count query
Expand Down
41 changes: 38 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from markupsafe import Markup
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, select
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy.orm import contains_eager, declarative_base, relationship, sessionmaker
from sqlalchemy.sql.expression import Select
from starlette.applications import Starlette
from starlette.requests import Request
Expand Down Expand Up @@ -52,6 +52,7 @@ class Address(Base):
__tablename__ = "addresses"

id = Column(Integer, primary_key=True)
name = Column(String)
user_id = Column(Integer, ForeignKey("users.id"))

user = relationship("User", back_populates="addresses")
Expand Down Expand Up @@ -381,13 +382,47 @@ def list_query(self, request: Request) -> Select:
assert len(await view.get_model_objects(request)) == 1


async def test_edit_form_query() -> None:
session = session_maker()
batman = User(id=123, name="batman")
batcave = Address(user=batman, name="bat cave")
wayne_manor = Address(user=batman, name="wayne manor")
session.add(batman)
session.add(batcave)
session.add(wayne_manor)
session.commit()

class UserAdmin(ModelView, model=User):
async_engine = False
session_maker = session_maker

def edit_form_query(self, request: Request) -> Select:
return (
select(self.model)
.join(Address)
.options(contains_eager(User.addresses))
.filter(Address.name == "bat cave")
)

view = UserAdmin()

class RequestObject(object):
pass

request_object = RequestObject()
request_object.path_params = {"pk": 123}
user_obj = await view.get_object_for_edit(request_object)

assert len(user_obj.addresses) == 1


def test_model_columns_all_keyword() -> None:
class AddressAdmin(ModelView, model=Address):
column_list = "__all__"
column_details_list = "__all__"

assert AddressAdmin().get_list_columns() == ["user", "id", "user_id"]
assert AddressAdmin().get_details_columns() == ["user", "id", "user_id"]
assert AddressAdmin().get_list_columns() == ["user", "id", "name", "user_id"]
assert AddressAdmin().get_details_columns() == ["user", "id", "name", "user_id"]


async def test_get_prop_value() -> None:
Expand Down
8 changes: 7 additions & 1 deletion tests/test_models_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ async def _action_stub(self, request: Request) -> Response:
pks = request.query_params.get("pks", "")

obj_strs: List[str] = []

class RequestObject(object):
pass

for pk in pks.split(","):
obj = await self.get_object_for_edit(pk)
request_object = RequestObject()
request_object.path_params = {"pk": pk}
obj = await self.get_object_for_edit(request_object)

obj_strs.append(repr(obj))

Expand Down

0 comments on commit c28aa36

Please sign in to comment.