diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9017d79..ff96fec 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -15,55 +15,14 @@ on: branches: master schedule: # * is a special character in YAML so you have to quote this string - - cron: '0 3 * * 6' + - cron: "0 3 * * 6" workflow_dispatch: inputs: reason: - description: 'Reason' + description: "Reason" required: false - default: 'Manual trigger' + default: "Manual trigger" jobs: - Tests: - runs-on: ubuntu-20.04 - strategy: - matrix: - python-version: [3.8, 3.9] - requirements-level: [pypi] - db-service: [postgresql14] - search-service: [opensearch2] - - env: - DB: ${{ matrix.db-service }} - EXTRAS: tests,${{ matrix.search-service }} - steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - name: Generate dependencies - run: | - python -m pip install --upgrade pip setuptools py wheel requirements-builder - requirements-builder -e "$EXTRAS" --level=${{ matrix.requirements-level }} setup.py > .${{ matrix.requirements-level }}-${{ matrix.python-version }}-requirements.txt - - - name: Cache pip - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('.${{ matrix.requirements-level }}-${{ matrix.python-version }}-requirements.txt') }} - - - name: Install dependencies - run: | - pip install -r .${{ matrix.requirements-level }}-${{ matrix.python-version }}-requirements.txt - pip install ".[$EXTRAS]" - pip freeze - docker --version - docker-compose --version - - - name: Run tests - run: | - ./run-tests.sh + Python: + uses: inveniosoftware/workflows/.github/workflows/tests-python.yml@master diff --git a/invenio_banners/records/models.py b/invenio_banners/records/models.py index 0d05802..c9238d1 100644 --- a/invenio_banners/records/models.py +++ b/invenio_banners/records/models.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2020-2023 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio-Banners is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -13,7 +14,6 @@ from flask import current_app from invenio_db import db from sqlalchemy import or_ -from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql import text from sqlalchemy_utils.models import Timestamp @@ -24,7 +24,6 @@ class BannerModel(db.Model, Timestamp): """Defines a message to show to users.""" __tablename__ = "banners" - __versioned__ = {"versioning": False} id = db.Column(db.Integer, primary_key=True) @@ -51,6 +50,7 @@ def create(cls, data): """Create a new banner.""" _categories = [t[0] for t in current_app.config["BANNERS_CATEGORIES"]] assert data.get("category") in _categories + with db.session.begin_nested(): obj = cls( message=data.get("message"), @@ -62,24 +62,24 @@ def create(cls, data): ) db.session.add(obj) - db.session.commit() return obj @classmethod def update(cls, data, id): """Update an existing banner.""" with db.session.begin_nested(): - cls.query.filter_by(id=id).update(data) - - db.session.commit() + # NOTE: + # with db.session.get(cls, id) the model itself would be + # returned and this classmethod would be called + db.session.query(cls).filter_by(id=id).update(data) @classmethod def get(cls, id): """Get banner by its id.""" - try: - return cls.query.filter_by(id=id).one() - except NoResultFound: - raise BannerNotExistsError(id) + if banner := db.session.get(cls, id): + return banner + + raise BannerNotExistsError(id) @classmethod def delete(cls, banner): @@ -87,15 +87,14 @@ def delete(cls, banner): with db.session.begin_nested(): db.session.delete(banner) - db.session.commit() - @classmethod def get_active(cls, url_path): """Return active banners.""" now = datetime.utcnow() query = ( - cls.query.filter(cls.active.is_(True)) + db.session.query(cls) + .filter(cls.active.is_(True)) .filter(cls.start_datetime <= now) .filter((cls.end_datetime.is_(None)) | (now <= cls.end_datetime)) ) @@ -113,16 +112,17 @@ def get_active(cls, url_path): @classmethod def search(cls, search_params, filters): """Filter banners accordingly to query params.""" - banners = ( - BannerModel.query.filter(or_(*filters)) - .order_by( - search_params["sort_direction"](text(",".join(search_params["sort"]))) - ) - .paginate( - page=search_params["page"], - per_page=search_params["size"], - error_out=False, - ) + if filters == []: + filtered = db.session.query(BannerModel).filter() + else: + filtered = db.session.query(BannerModel).filter(or_(*filters)) + + banners = filtered.order_by( + search_params["sort_direction"](text(",".join(search_params["sort"]))) + ).paginate( + page=search_params["page"], + per_page=search_params["size"], + error_out=False, ) return banners @@ -133,12 +133,11 @@ def disable_expired(cls): now = datetime.utcnow() query = ( - cls.query.filter(cls.active.is_(True)) + db.session.query(cls) + .filter(cls.active.is_(True)) .filter(cls.end_datetime.isnot(None)) .filter(cls.end_datetime < now) ) for old in query.all(): old.active = False - - db.session.commit() diff --git a/invenio_banners/resources/errors.py b/invenio_banners/resources/errors.py index b412fae..d231293 100644 --- a/invenio_banners/resources/errors.py +++ b/invenio_banners/resources/errors.py @@ -1,19 +1,19 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2023 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio-Banners is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. """Errors.""" -from flask_resources import HTTPJSONException, create_error_handler - -from ..services.errors import BannerNotExistsError import marshmallow as ma from flask_resources import HTTPJSONException, create_error_handler from invenio_records_resources.errors import validation_error_to_list_errors +from ..services.errors import BannerNotExistsError + class HTTPJSONValidationException(HTTPJSONException): """HTTP exception serializing to JSON and reflecting Marshmallow errors.""" @@ -25,7 +25,7 @@ def __init__(self, exception): super().__init__(code=400, errors=validation_error_to_list_errors(exception)) -class ErrorHandlersMixin(): +class ErrorHandlersMixin: """Mixin to define error handlers.""" error_handlers = { diff --git a/invenio_banners/services/results.py b/invenio_banners/services/results.py index 929ce93..f04d2d7 100644 --- a/invenio_banners/services/results.py +++ b/invenio_banners/services/results.py @@ -1,14 +1,22 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2022-2023 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio-Banners is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. """Service results.""" -from flask_sqlalchemy import Pagination + from invenio_records_resources.services.records.results import RecordItem, RecordList +try: + # flask_sqlalchemy<3.0.0 + from flask_sqlalchemy import Pagination +except ImportError: + # flask_sqlalchemy>=3.0.0 + from flask_sqlalchemy.pagination import Pagination + class BannerItem(RecordItem): """Single banner result.""" diff --git a/invenio_banners/services/service.py b/invenio_banners/services/service.py index ecfa336..3ed1993 100644 --- a/invenio_banners/services/service.py +++ b/invenio_banners/services/service.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2022-2023 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio-Banners is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -10,6 +11,7 @@ import distutils.util import arrow +from invenio_db.uow import unit_of_work from invenio_records_resources.services import RecordService from invenio_records_resources.services.base import LinksTemplate from invenio_records_resources.services.base.utils import map_search_params @@ -81,7 +83,8 @@ def search(self, identity, params): links_item_tpl=self.links_item_tpl, ) - def create(self, identity, data, raise_errors=True): + @unit_of_work() + def create(self, identity, data, raise_errors=True, uow=None): """Create a banner.""" self.require_permission(identity, "create") @@ -99,17 +102,18 @@ def create(self, identity, data, raise_errors=True): self, identity, banner, links_tpl=self.links_item_tpl, errors=errors ) - def delete(self, identity, id): + @unit_of_work() + def delete(self, identity, id, uow=None): """Delete a banner from database.""" self.require_permission(identity, "delete") banner = self.record_cls.get(id) - self.record_cls.delete(banner) return self.result_item(self, identity, banner, links_tpl=self.links_item_tpl) - def update(self, identity, id, data): + @unit_of_work() + def update(self, identity, id, data, uow=None): """Update a banner.""" self.require_permission(identity, "update") @@ -131,7 +135,8 @@ def update(self, identity, id, data): links_tpl=self.links_item_tpl, ) - def disable_expired(self, identity): + @unit_of_work() + def disable_expired(self, identity, uow=None): """Disable expired banners.""" self.require_permission(identity, "disable") self.record_cls.disable_expired() diff --git a/tests/resources/test_resources.py b/tests/resources/test_resources.py index ef05ae7..e0f76e5 100644 --- a/tests/resources/test_resources.py +++ b/tests/resources/test_resources.py @@ -1,78 +1,88 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2022 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio-Banners is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. """Banner resource tests.""" -from datetime import date, datetime +from datetime import date, datetime, timedelta import pytest +from invenio_db import db from invenio_records_resources.services.errors import PermissionDeniedError from invenio_banners.records import BannerModel banners = { + "banner0": { + "message": "banner0", + "url_path": "/banner0", + "category": "info", + "active": True, + "start_datetime": date(2022, 7, 20).strftime("%Y-%m-%d %H:%M:%S"), + "end_datetime": (datetime.utcnow() - timedelta(days=20)).strftime( + "%Y-%m-%d %H:%M:%S" + ), + }, "banner1": { "message": "banner1", "url_path": "/banner1", "category": "info", "active": True, - "start_datetime": date(2022, 7, 20), - "end_datetime": date(2023, 1, 29), + "start_datetime": date(2022, 7, 20).strftime("%Y-%m-%d %H:%M:%S"), + "end_datetime": (datetime.utcnow() + timedelta(days=20)).strftime( + "%Y-%m-%d %H:%M:%S" + ), }, "banner2": { "message": "banner2", "url_path": "/banner2", "category": "other", "active": False, - "start_datetime": date(2022, 12, 15), - "end_datetime": date(2023, 1, 5), + "start_datetime": date(2022, 12, 15).strftime("%Y-%m-%d %H:%M:%S"), + "end_datetime": (datetime.utcnow() + timedelta(days=10)).strftime( + "%Y-%m-%d %H:%M:%S" + ), }, "banner3": { "message": "banner3", "url_path": "/banner3", "category": "warning", "active": True, - "start_datetime": date(2023, 1, 20), - "end_datetime": date(2023, 2, 25), + "start_datetime": date(2023, 1, 20).strftime("%Y-%m-%d %H:%M:%S"), + "end_datetime": (datetime.utcnow() + timedelta(days=30)).strftime( + "%Y-%m-%d %H:%M:%S" + ), }, } def _create_banner(client, data, headers, status_code=None): """Send POST request.""" - result = client.post( - "/banners/", - headers=headers, - json=data, - ) + result = client.post("/banners/", headers=headers, json=data) assert result.status_code == status_code return result def _update_banner(client, id, data, headers, status_code=None): """Send PUT request.""" - result = client.put( - "/banners/{0}".format(id), - headers=headers, - json=data, - ) + result = client.put(f"/banners/{id}", headers=headers, json=data) assert result.status_code == status_code return result def _delete_banner(client, id, headers, status_code=None): """Send DELETE request.""" - result = client.delete("/banners/{0}".format(id), headers=headers) + result = client.delete(f"/banners/{id}", headers=headers) assert result.status_code == status_code return result def _get_banner(client, id, status_code=None): """Send GET request.""" - result = client.get("/banners/{0}".format(id)) + result = client.get(f"/banners/{id}") assert result.status_code == status_code return result @@ -108,15 +118,15 @@ def test_create_banner(client, admin, headers): def test_disable_expired_after_create_action(client, admin, headers): """Disable expired banners after a create a banner action.""" # create banner first - banner1 = BannerModel.create(banners["banner1"]) - assert banner1.active is True + banner0 = BannerModel.create(banners["banner0"]) + assert banner0.active is True banner_data = banners["banner2"] admin.login(client) _create_banner(client, banner_data, headers, 201).json - expired_banner = BannerModel.get(banner1.id) + expired_banner = BannerModel.get(banner0.id) assert expired_banner.active is False @@ -145,9 +155,9 @@ def test_update_banner(client, admin, headers): def test_disable_expired_after_update_action(client, admin, headers): """Disable expired banners after an update a banner action.""" # create banner first - banner1 = BannerModel.create(banners["banner1"]) + banner0 = BannerModel.create(banners["banner0"]) banner2 = BannerModel.create(banners["banner2"]) - assert banner1.active is True + assert banner0.active is True admin.login(client) @@ -160,7 +170,7 @@ def test_disable_expired_after_update_action(client, admin, headers): _update_banner(client, banner2.id, new_data, headers, 200).json - expired_banner = BannerModel.get(banner1.id) + expired_banner = BannerModel.get(banner0.id) assert expired_banner.active is False @@ -200,7 +210,7 @@ def test_delete_banner(client, admin, headers): _delete_banner(client, banner.id, headers, 204) # check that it's not present in db - assert BannerModel.query.filter_by(id=banner.id).one_or_none() is None + assert db.session.query(BannerModel).filter_by(id=banner.id).one_or_none() is None def test_delete_is_forbidden(client, user, headers): diff --git a/tests/services/test_services.py b/tests/services/test_services.py index ceb5e25..6cc9214 100644 --- a/tests/services/test_services.py +++ b/tests/services/test_services.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2022 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio-Banners is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -10,6 +11,7 @@ from datetime import datetime, timedelta import pytest +from invenio_db import db from invenio_records_resources.services.errors import PermissionDeniedError from invenio_banners.proxies import current_banners_service as service @@ -21,7 +23,10 @@ "message": "active", "url_path": "/active", "category": "info", - "end_datetime": datetime.utcnow() + timedelta(days=1), + "start_datetime": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"), + "end_datetime": (datetime.utcnow() + timedelta(days=1)).strftime( + "%Y-%m-%d %H:%M:%S" + ), "active": True, }, "inactive": { @@ -127,7 +132,7 @@ def test_delete_banner(app, superuser_identity): service.delete(superuser_identity, banner.id) # check that it's not present in db - assert BannerModel.query.filter_by(id=banner.id).one_or_none() is None + assert db.session.query(BannerModel).filter_by(id=banner.id).one_or_none() is None def test_delete_is_forbidden(app, simple_user_identity): @@ -150,7 +155,7 @@ def test_delete_non_existing_banner(app, superuser_identity): def test_read_banner(app, simple_user_identity): """Read a banner by id.""" # create banner first - banner = BannerModel.create(banners["active"]) + banner = BannerModel.create(banners["other"]) banner_result = service.read(simple_user_identity, banner.id) @@ -207,11 +212,12 @@ def test_disable_expired_banners(app, superuser_identity): BannerModel.create(banners["expired"]) BannerModel.create(banners["active"]) - assert BannerModel.query.filter(BannerModel.active.is_(True)).count() == 2 - + assert ( + db.session.query(BannerModel).filter(BannerModel.active.is_(True)).count() == 2 + ) service.disable_expired(superuser_identity) - _banners = BannerModel.query.filter(BannerModel.active.is_(True)).all() + _banners = db.session.query(BannerModel).filter(BannerModel.active.is_(True)).all() assert len(_banners) == 1 assert _banners[0].message == "active"