diff --git a/invenio_accounts/datastore.py b/invenio_accounts/datastore.py index 95e0b6e1..a235b8f7 100644 --- a/invenio_accounts/datastore.py +++ b/invenio_accounts/datastore.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2024 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -12,6 +13,7 @@ from flask import current_app from flask_security import SQLAlchemyUserDatastore, user_confirmed +from invenio_db import db from sqlalchemy.orm import joinedload from .models import Domain, Role, User @@ -108,12 +110,13 @@ def create_role(self, **kwargs): def find_role_by_id(self, role_id): """Fetches roles searching by id.""" - return self.role_model.query.filter_by(id=role_id).one_or_none() + return db.session.query(self.role_model).filter_by(id=role_id).one_or_none() def find_domain(self, domain): """Find a domain.""" return ( - Domain.query.filter_by(domain=domain) + db.session.query(Domain) + .filter_by(domain=domain) .options(joinedload(Domain.category_name)) .one_or_none() ) diff --git a/invenio_accounts/models.py b/invenio_accounts/models.py index 7a411567..48ce3d1b 100644 --- a/invenio_accounts/models.py +++ b/invenio_accounts/models.py @@ -3,6 +3,7 @@ # This file is part of Invenio. # Copyright (C) 2015-2024 CERN. # Copyright (C) 2022 KTH Royal Institute of Technology +# Copyright (C) 2024 Graz University of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -415,12 +416,12 @@ def query_by_expired(cls): """Query to select all expired sessions.""" lifetime = current_app.permanent_session_lifetime expired_moment = datetime.utcnow() - lifetime - return cls.query.filter(cls.created < expired_moment) + return db.session.query(cls).filter(cls.created < expired_moment) @classmethod def query_by_user(cls, user_id): """Query to select user sessions.""" - return cls.query.filter_by(user_id=user_id) + return db.session.query(cls).filter_by(user_id=user_id) @classmethod def is_current(cls, sid_s): @@ -446,7 +447,9 @@ class UserIdentity(db.Model, Timestamp): @classmethod def get_user(cls, method, external_id): """Get the user for a given identity.""" - identity = cls.query.filter_by(id=external_id, method=method).one_or_none() + identity = ( + db.session.query(cls).filter_by(id=external_id, method=method).one_or_none() + ) if identity is not None: return identity.user return None @@ -474,13 +477,13 @@ def create(cls, user, method, external_id): def delete_by_external_id(cls, method, external_id): """Unlink a user from an external id.""" with db.session.begin_nested(): - cls.query.filter_by(id=external_id, method=method).delete() + db.session.query(cls).filter_by(id=external_id, method=method).delete() @classmethod def delete_by_user(cls, method, user): """Unlink a user from an external id.""" with db.session.begin_nested(): - cls.query.filter_by(id_user=user.id, method=method).delete() + db.session.query(cls).filter_by(id_user=user.id, method=method).delete() class DomainOrg(db.Model): @@ -538,7 +541,7 @@ def create(cls, label): @classmethod def get(cls, label): """Get a domain category.""" - return cls.query.filter_by(label=label).one_or_none() + return db.session.query(cls).filter_by(label=label).one_or_none() class Domain(db.Model, Timestamp): diff --git a/invenio_accounts/sessions.py b/invenio_accounts/sessions.py index 2315be02..af154b18 100644 --- a/invenio_accounts/sessions.py +++ b/invenio_accounts/sessions.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2024 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -134,7 +135,7 @@ def delete_session(sid_s): # Find and remove the corresponding SessionActivity entry if request and "_impersonator_id" not in session: with db.session.begin_nested(): - SessionActivity.query.filter_by(sid_s=sid_s).delete() + db.session.query(SessionActivity).filter_by(sid_s=sid_s).delete() return 1 @@ -148,7 +149,7 @@ def delete_user_sessions(user): for s in user.active_sessions: _sessionstore.delete(s.sid_s) - SessionActivity.query.filter_by(user=user).delete() + db.session.query(SessionActivity).filter_by(user=user).delete() return True diff --git a/invenio_accounts/tasks.py b/invenio_accounts/tasks.py index 2133f2e0..8bce572c 100644 --- a/invenio_accounts/tasks.py +++ b/invenio_accounts/tasks.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2018 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -64,12 +65,12 @@ def delete_ips(): datetime.utcnow() - current_app.config["ACCOUNTS_RETENTION_PERIOD"] ) - LoginInformation.query.filter( + db.session.query(LoginInformation).filter( LoginInformation.last_login_ip.isnot(None), LoginInformation.last_login_at < expiration_date, ).update({LoginInformation.last_login_ip: None}) - LoginInformation.query.filter( + db.session.query(LoginInformation).filter( LoginInformation.current_login_ip.isnot(None), LoginInformation.current_login_at < expiration_date, ).update({LoginInformation.current_login_ip: None}) diff --git a/invenio_accounts/views/security.py b/invenio_accounts/views/security.py index f73e3dd6..8bbaa816 100644 --- a/invenio_accounts/views/security.py +++ b/invenio_accounts/views/security.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2017-2018 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -48,9 +49,9 @@ def revoke_session(): sid_s = form.data["sid_s"] if ( - SessionActivity.query.filter_by( - user_id=current_user.get_id(), sid_s=sid_s - ).count() + db.session.query(SessionActivity) + .filter_by(user_id=current_user.get_id(), sid_s=sid_s) + .count() == 1 ): delete_session(sid_s=sid_s) diff --git a/tests/test_invenio_accounts.py b/tests/test_invenio_accounts.py index aa458ea5..298fc04b 100644 --- a/tests/test_invenio_accounts.py +++ b/tests/test_invenio_accounts.py @@ -151,7 +151,9 @@ def test_datastore_usercreate(app): ds.commit() u2 = ds.find_user(email="info@inveniosoftware.org") assert u1 == u2 - assert 1 == User.query.filter_by(email="info@inveniosoftware.org").count() + assert ( + 1 == db.session.query(User).filter_by(email="info@inveniosoftware.org").count() + ) def test_datastore_rolecreate(app): @@ -162,7 +164,7 @@ def test_datastore_rolecreate(app): ds.commit() r2 = ds.find_role("superuser") assert r1 == r2 - assert 1 == Role.query.filter_by(name="superuser").count() + assert 1 == db.session.query(Role).filter_by(name="superuser").count() def test_datastore_update_role(app): @@ -173,7 +175,7 @@ def test_datastore_update_role(app): ds.commit() r2 = ds.find_role("superuser") assert r1 == r2 - assert 1 == Role.query.filter_by(name="superuser").count() + assert 1 == db.session.query(Role).filter_by(name="superuser").count() assert r2.is_managed is True r1 = ds.update_role( @@ -186,8 +188,8 @@ def test_datastore_update_role(app): assert r1 == r2 assert r2.description == "updated description" assert r2.is_managed is False - assert 1 == Role.query.filter_by(name="megauser").count() - assert 0 == Role.query.filter_by(name="superuser").count() + assert 1 == db.session.query(Role).filter_by(name="megauser").count() + assert 0 == db.session.query(Role).filter_by(name="superuser").count() def test_datastore_assignrole(app): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 9d0e4aee..9cea23e7 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2018 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -16,6 +17,7 @@ from flask_login import login_required from flask_mail import Message from flask_security import url_for_security +from invenio_db import db from invenio_accounts.models import SessionActivity, User from invenio_accounts.tasks import clean_session_table, delete_ips, send_security_email @@ -79,7 +81,7 @@ def test(): password=user1.password_plaintext, ), ) - assert len(SessionActivity.query.all()) == 1 + assert len(db.session.query(SessionActivity).all()) == 1 sleep(15) with task_app.test_client() as client: @@ -90,11 +92,11 @@ def test(): password=user2.password_plaintext, ), ) - assert len(SessionActivity.query.all()) == 2 + assert len(db.session.query(SessionActivity).all()) == 2 sleep(10) clean_session_table.s().apply() - assert len(SessionActivity.query.all()) == 1 + assert len(db.session.query(SessionActivity).all()) == 1 protected_url = url_for("test") @@ -103,7 +105,7 @@ def test(): sleep(15) clean_session_table.s().apply() - assert len(SessionActivity.query.all()) == 0 + assert len(db.session.query(SessionActivity).all()) == 0 res = client.get(protected_url) # check if the user is really logout @@ -146,14 +148,14 @@ def test_delete_ips(task_app): delete_ips() - user = User.query.filter(User.id == user1.id).one() + user = db.session.query(User).filter(User.id == user1.id).one() assert user.last_login_ip is None assert user.current_login_ip is None - user = User.query.filter(User.id == user2.id).one() + user = db.session.query(User).filter(User.id == user2.id).one() assert user.last_login_ip is not None assert user.current_login_ip is not None - user = User.query.filter(User.id == user3.id).one() + user = db.session.query(User).filter(User.id == user3.id).one() assert user.last_login_ip is None assert user.current_login_ip is not None diff --git a/tests/test_views.py b/tests/test_views.py index e5ad772a..2134d058 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2024 CERN. +# Copyright (C) 2024 Graz University of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -122,12 +123,12 @@ def test_view_list_sessions(app): assert res.status_code == 200 # check session for user 1 is not in the list - sessions_1 = SessionActivity.query.filter_by(user_id=user1.id).all() + sessions_1 = db.session.query(SessionActivity).filter_by(user_id=user1.id).all() assert len(sessions_1) == 1 assert sessions_1[0].sid_s not in res.data.decode("utf-8") # check session for user 2 is in the list - sessions_2 = SessionActivity.query.filter_by(user_id=user2.id).all() + sessions_2 = db.session.query(SessionActivity).filter_by(user_id=user2.id).all() assert len(sessions_2) == 1 assert sessions_2[0].sid_s in res.data.decode("utf-8") @@ -136,9 +137,9 @@ def test_view_list_sessions(app): res = client.post(url, data={"sid_s": sessions_1[0].sid_s}) assert res.status_code == 302 assert ( - SessionActivity.query.filter_by( - user_id=user1.id, sid_s=sessions_1[0].sid_s - ).count() + db.session.query(SessionActivity) + .filter_by(user_id=user1.id, sid_s=sessions_1[0].sid_s) + .count() == 1 ) @@ -147,9 +148,9 @@ def test_view_list_sessions(app): res = client.post(url, data={"sid_s": sessions_2[0].sid_s}) assert res.status_code == 302 assert ( - SessionActivity.query.filter_by( - user_id=user1.id, sid_s=sessions_2[0].sid_s - ).count() + db.session.query(SessionActivity) + .filter_by(user_id=user1.id, sid_s=sessions_2[0].sid_s) + .count() == 0 )