Skip to content

Commit

Permalink
Allow sort by related model field (#654)
Browse files Browse the repository at this point in the history
  • Loading branch information
aminalaee authored Oct 23, 2023
1 parent 5d0a2ec commit 8f9d07d
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
11 changes: 9 additions & 2 deletions sqladmin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,10 +1064,17 @@ def sort_query(self, stmt: Select, request: Request) -> Select:
sort_fields = self._get_default_sort()

for sort_field, is_desc in sort_fields:
model = self.model

parts = sort_field.split(".")
for part in parts[:-1]:
model = getattr(model, part).mapper.class_
stmt = stmt.join(model)

if is_desc:
stmt = stmt.order_by(desc(sort_field))
stmt = stmt.order_by(desc(getattr(model, parts[-1])))
else:
stmt = stmt.order_by(asc(sort_field))
stmt = stmt.order_by(asc(getattr(model, parts[-1])))

return stmt

Expand Down
46 changes: 32 additions & 14 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
from markupsafe import Markup
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, select
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from sqlalchemy.sql.expression import Select
Expand Down Expand Up @@ -51,7 +51,7 @@ def name_with_id(self) -> str:
class Address(Base):
__tablename__ = "addresses"

pk = Column(Integer, primary_key=True)
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))

user = relationship("User", back_populates="addresses")
Expand Down Expand Up @@ -121,9 +121,9 @@ class UserAdmin(ModelView, model=User):

def test_column_list_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
column_list = ["pk", "user_id"]
column_list = ["id", "user_id"]

assert AddressAdmin().get_list_columns() == ["pk", "user_id"]
assert AddressAdmin().get_list_columns() == ["id", "user_id"]


def test_column_list_both_include_and_exclude() -> None:
Expand Down Expand Up @@ -242,9 +242,9 @@ class UserAdmin(ModelView, model=User):

def test_form_columns_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
form_columns = ["pk", "user_id"]
form_columns = ["id", "user_id"]

assert AddressAdmin().get_form_columns() == ["pk", "user_id"]
assert AddressAdmin().get_form_columns() == ["id", "user_id"]


def test_form_columns_both_include_and_exclude() -> None:
Expand Down Expand Up @@ -299,9 +299,9 @@ class UserAdmin(ModelView, model=User):

def test_export_columns_by_str_name() -> None:
class AddressAdmin(ModelView, model=Address):
column_export_list = ["pk", "user_id"]
column_export_list = ["id", "user_id"]

assert AddressAdmin().get_export_columns() == ["pk", "user_id"]
assert AddressAdmin().get_export_columns() == ["id", "user_id"]


def test_export_columns_both_include_and_exclude() -> None:
Expand Down Expand Up @@ -386,8 +386,8 @@ class AddressAdmin(ModelView, model=Address):
column_list = "__all__"
column_details_list = "__all__"

assert AddressAdmin().get_list_columns() == ["user", "pk", "user_id"]
assert AddressAdmin().get_details_columns() == ["user", "pk", "user_id"]
assert AddressAdmin().get_list_columns() == ["user", "id", "user_id"]
assert AddressAdmin().get_details_columns() == ["user", "id", "user_id"]


async def test_get_prop_value() -> None:
Expand All @@ -397,13 +397,10 @@ class ProfileAdmin(ModelView, model=Profile):
with session_maker() as session:
user = User(name="admin")
address = Address(user=user)
profile = Profile(
is_active=True, role=Role.ADMIN, status=Status.ACTIVE, user=user
)
profile = Profile(role=Role.ADMIN, status=Status.ACTIVE, user=user)
session.add_all([user, address, profile])
session.commit()

assert await ProfileAdmin().get_prop_value(profile, "is_active") is True
assert await ProfileAdmin().get_prop_value(profile, "role") == "ADMIN"
assert await ProfileAdmin().get_prop_value(profile, "status") == "ACTIVE"
assert await ProfileAdmin().get_prop_value(profile, "user.name") == "admin"
Expand All @@ -418,3 +415,24 @@ class UserAdmin(ModelView, model=User):
assert UserAdmin().get_list_columns() == ["id", "name", "name_with_id"]
assert UserAdmin().get_details_columns() == ["addresses", "profile", "id", "name"]
assert await UserAdmin().get_prop_value(user, "name_with_id") == "batman - 1"


def test_sort_query() -> None:
class AddressAdmin(ModelView, model=Address):
...

query = select(Address)

request = Request({"type": "http", "query_string": "sortBy=id&sort=asc"})
stmt = AddressAdmin().sort_query(query, request)
assert "ORDER BY addresses.id ASC" in str(stmt)

request = Request({"type": "http", "query_string": b"sortBy=user.name&sort=desc"})
stmt = AddressAdmin().sort_query(query, request)
assert "ORDER BY users.name DESC" in str(stmt)

request = Request(
{"type": "http", "query_string": b"sortBy=user.profile.role&sort=desc"}
)
stmt = AddressAdmin().sort_query(query, request)
assert "ORDER BY profiles.role DESC" in str(stmt)

0 comments on commit 8f9d07d

Please sign in to comment.