diff --git a/invenio_banners/records/models.py b/invenio_banners/records/models.py index 5a4e807..b977512 100644 --- a/invenio_banners/records/models.py +++ b/invenio_banners/records/models.py @@ -71,7 +71,7 @@ def create(cls, data): def update(cls, data, id): """Update an existing banner.""" with db.session.begin_nested(): - cls.query.filter_by(id=id).update(data) + db.session.query(cls).filter_by(id=id).update(data) db.session.commit() @@ -79,7 +79,7 @@ def update(cls, data, id): def get(cls, id): """Get banner by its id.""" try: - return cls.query.filter_by(id=id).one() + return db.session.query(cls).filter_by(id=id).one() except NoResultFound: raise BannerNotExistsError(id) @@ -96,7 +96,8 @@ def get_active(cls, url_path): now = datetime.utcnow() query = ( - cls.query.filter(cls.active.is_(True)) + db.session.query(cls) + .filter(cls.active.is_(True)) .filter(cls.start_datetime <= now) .filter((cls.end_datetime.is_(None)) | (now <= cls.end_datetime)) ) @@ -114,16 +115,17 @@ def get_active(cls, url_path): @classmethod def search(cls, search_params, filters): """Filter banners accordingly to query params.""" - banners = ( - BannerModel.query.filter(or_(False, *filters)) - .order_by( - search_params["sort_direction"](text(",".join(search_params["sort"]))) - ) - .paginate( - page=search_params["page"], - per_page=search_params["size"], - error_out=False, - ) + if filters == []: + filtered = db.session.query(BannerModel).filter() + else: + filtered = db.session.query(BannerModel).filter(or_(*filters)) + + banners = filtered.order_by( + search_params["sort_direction"](text(",".join(search_params["sort"]))) + ).paginate( + page=search_params["page"], + per_page=search_params["size"], + error_out=False, ) return banners @@ -134,7 +136,8 @@ def disable_expired(cls): now = datetime.utcnow() query = ( - cls.query.filter(cls.active.is_(True)) + db.session.query(cls) + .filter(cls.active.is_(True)) .filter(cls.end_datetime.isnot(None)) .filter(cls.end_datetime < now) ) diff --git a/tests/resources/test_resources.py b/tests/resources/test_resources.py index c675343..d963d5d 100644 --- a/tests/resources/test_resources.py +++ b/tests/resources/test_resources.py @@ -10,6 +10,7 @@ from datetime import date, datetime import pytest +from invenio_db import db from invenio_records_resources.services.errors import PermissionDeniedError from invenio_banners.records import BannerModel @@ -193,7 +194,7 @@ def test_delete_banner(client, admin, headers): _delete_banner(client, banner.id, headers, 204) # check that it's not present in db - assert BannerModel.query.filter_by(id=banner.id).one_or_none() is None + assert db.session.query(BannerModel).filter_by(id=banner.id).one_or_none() is None def test_delete_is_forbidden(client, user, headers): diff --git a/tests/services/test_services.py b/tests/services/test_services.py index bb6cc1b..6a5bb06 100644 --- a/tests/services/test_services.py +++ b/tests/services/test_services.py @@ -11,6 +11,7 @@ from datetime import datetime, timedelta import pytest +from invenio_db import db from invenio_records_resources.services.errors import PermissionDeniedError from invenio_banners.proxies import current_banners_service as service @@ -131,7 +132,7 @@ def test_delete_banner(app, superuser_identity): service.delete(superuser_identity, banner.id) # check that it's not present in db - assert BannerModel.query.filter_by(id=banner.id).one_or_none() is None + assert db.session.query(BannerModel).filter_by(id=banner.id).one_or_none() is None def test_delete_is_forbidden(app, simple_user_identity): @@ -211,11 +212,12 @@ def test_disable_expired_banners(app, superuser_identity): BannerModel.create(banners["expired"]) BannerModel.create(banners["active"]) - assert BannerModel.query.filter(BannerModel.active.is_(True)).count() == 2 - + assert ( + db.session.query(BannerModel).filter(BannerModel.active.is_(True)).count() == 2 + ) service.disable_expired(superuser_identity) - _banners = BannerModel.query.filter(BannerModel.active.is_(True)).all() + _banners = db.session.query(BannerModel).filter(BannerModel.active.is_(True)).all() assert len(_banners) == 1 assert _banners[0].message == "active"