Skip to content

Commit

Permalink
bug: add handler for handling changed domain
Browse files Browse the repository at this point in the history
* remove setting of role.id as the added flush() will add the ID from DB
* domains: extend find_domain to accept ID
* domains: only flush if data not persisted
  • Loading branch information
carlinmack authored Nov 14, 2024
1 parent 24f9ac2 commit 92b8f6b
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions invenio_accounts/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from flask import current_app
from flask_security import SQLAlchemyUserDatastore, user_confirmed
from invenio_db import db
from sqlalchemy import inspect
from sqlalchemy.orm import joinedload

from .models import Domain, Role, User
Expand Down Expand Up @@ -76,10 +77,16 @@ def commit(self):
def mark_changed(self, sid, uid=None, rid=None, model=None):
"""Save a user to the changed history."""
if model:
# add the ID to the model from the DB if needed
if not inspect(model).persistent:
self.db.session.flush()

if isinstance(model, User):
current_db_change_history.add_updated_user(sid, model.id)
elif isinstance(model, Role):
current_db_change_history.add_updated_role(sid, model.id)
elif isinstance(model, Domain):
current_db_change_history.add_updated_domain(sid, model.id)
elif uid:
# Deprecated - use model param instead (still used in e.g.
# UserFixture pytest-invenio)
Expand All @@ -91,32 +98,33 @@ def mark_changed(self, sid, uid=None, rid=None, model=None):
def update_role(self, role):
"""Updates roles."""
role = self.db.session.merge(role)
# This works because role defines it's own id - for users
# the same doesn't work because id is assigned on commit which
# hasn't happened yet.
self.mark_changed(id(self.db.session), model=role)
return role

def create_role(self, **kwargs):
"""Creates and returns a new role from the given parameters."""
role = super().create_role(**kwargs)
# This works because role defines it's own id - for users
# the same doesn't work because id is assigned on commit which
# hasn't happened yet.
if role.id is None:
role.id = role.name
self.mark_changed(id(self.db.session), model=role)
return role

def find_role_by_id(self, role_id):
"""Fetches roles searching by id."""
"""Fetches roles searching by ID."""
return db.session.query(self.role_model).filter_by(id=role_id).one_or_none()

def find_domain(self, domain):
"""Find a domain."""
def find_domain(self, domain_or_id):
"""Find a domain by value or ID."""
if isinstance(domain_or_id, str):
if domain_or_id.isdigit():
clause = Domain.id == int(domain_or_id)
else:
clause = Domain.domain == domain_or_id
elif isinstance(domain_or_id, int):
clause = Domain.id == domain_or_id
else:
raise ValueError("Expected string or int, received:", type(domain_or_id))
return (
db.session.query(Domain)
.filter_by(domain=domain)
.filter(clause)
.options(joinedload(Domain.category_name))
.one_or_none()
)
Expand Down

0 comments on commit 92b8f6b

Please sign in to comment.