Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
global: apply new SQLAlchemy rules
Browse files Browse the repository at this point in the history
* change from model query to db.session.query(Model)

* this change is necessary to make the tests green otherwise the model
  would emit a commit event so that the changes are stored persistent
  into the database. the session.rollback() or also the
  transaction.rollback() does not work anymore
utnapischtim committed Oct 1, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 2a234ab commit 8da0d89
Showing 3 changed files with 25 additions and 19 deletions.
31 changes: 17 additions & 14 deletions invenio_banners/records/models.py
Original file line number Diff line number Diff line change
@@ -71,15 +71,15 @@ def create(cls, data):
def update(cls, data, id):
"""Update an existing banner."""
with db.session.begin_nested():
cls.query.filter_by(id=id).update(data)
db.session.query(cls).filter_by(id=id).update(data)

db.session.commit()

@classmethod
def get(cls, id):
"""Get banner by its id."""
try:
return cls.query.filter_by(id=id).one()
return db.session.query(cls).filter_by(id=id).one()
except NoResultFound:
raise BannerNotExistsError(id)

@@ -96,7 +96,8 @@ def get_active(cls, url_path):
now = datetime.utcnow()

query = (
cls.query.filter(cls.active.is_(True))
db.session.query(cls)
.filter(cls.active.is_(True))
.filter(cls.start_datetime <= now)
.filter((cls.end_datetime.is_(None)) | (now <= cls.end_datetime))
)
@@ -114,16 +115,17 @@ def get_active(cls, url_path):
@classmethod
def search(cls, search_params, filters):
"""Filter banners accordingly to query params."""
banners = (
BannerModel.query.filter(or_(False, *filters))
.order_by(
search_params["sort_direction"](text(",".join(search_params["sort"])))
)
.paginate(
page=search_params["page"],
per_page=search_params["size"],
error_out=False,
)
if filters == []:
filtered = db.session.query(BannerModel).filter()
else:
filtered = db.session.query(BannerModel).filter(or_(*filters))

banners = filtered.order_by(
search_params["sort_direction"](text(",".join(search_params["sort"])))
).paginate(
page=search_params["page"],
per_page=search_params["size"],
error_out=False,
)

return banners
@@ -134,7 +136,8 @@ def disable_expired(cls):
now = datetime.utcnow()

query = (
cls.query.filter(cls.active.is_(True))
db.session.query(cls)
.filter(cls.active.is_(True))
.filter(cls.end_datetime.isnot(None))
.filter(cls.end_datetime < now)
)
3 changes: 2 additions & 1 deletion tests/resources/test_resources.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
from datetime import date, datetime

import pytest
from invenio_db import db
from invenio_records_resources.services.errors import PermissionDeniedError

from invenio_banners.records import BannerModel
@@ -193,7 +194,7 @@ def test_delete_banner(client, admin, headers):
_delete_banner(client, banner.id, headers, 204)

# check that it's not present in db
assert BannerModel.query.filter_by(id=banner.id).one_or_none() is None
assert db.session.query(BannerModel).filter_by(id=banner.id).one_or_none() is None


def test_delete_is_forbidden(client, user, headers):
10 changes: 6 additions & 4 deletions tests/services/test_services.py
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@
from datetime import datetime, timedelta

import pytest
from invenio_db import db
from invenio_records_resources.services.errors import PermissionDeniedError

from invenio_banners.proxies import current_banners_service as service
@@ -131,7 +132,7 @@ def test_delete_banner(app, superuser_identity):
service.delete(superuser_identity, banner.id)

# check that it's not present in db
assert BannerModel.query.filter_by(id=banner.id).one_or_none() is None
assert db.session.query(BannerModel).filter_by(id=banner.id).one_or_none() is None


def test_delete_is_forbidden(app, simple_user_identity):
@@ -211,11 +212,12 @@ def test_disable_expired_banners(app, superuser_identity):
BannerModel.create(banners["expired"])
BannerModel.create(banners["active"])

assert BannerModel.query.filter(BannerModel.active.is_(True)).count() == 2

assert (
db.session.query(BannerModel).filter(BannerModel.active.is_(True)).count() == 2
)
service.disable_expired(superuser_identity)

_banners = BannerModel.query.filter(BannerModel.active.is_(True)).all()
_banners = db.session.query(BannerModel).filter(BannerModel.active.is_(True)).all()

assert len(_banners) == 1
assert _banners[0].message == "active"

0 comments on commit 8da0d89

Please sign in to comment.