Skip to content

Commit

Permalink
datastore: add update_role func
Browse files Browse the repository at this point in the history
* closes inveniosoftware/invenio-app-rdm#2186
* updated cli to pass ids on create role

Co-authored-by: jrcastro2 <[email protected]>
  • Loading branch information
TLGINO and jrcastro2 committed Jun 5, 2023
1 parent 64659a8 commit f554095
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#
# This file is part of Invenio.
# Copyright (C) 2016-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Change AccountsRole primary key to string."""

import sqlalchemy as sa
from alembic import op
from sqlalchemy import inspect

_dependant_modules_with_revisions = [("invenio_access", "04480be1593e"), ("invenio_communities", "fbe746957cfc")]


def _get_dependant_revisions():
"""Computes the revisions that this recipe depends on, based on the installed modules.
There are some revisions in some modules that need this recipe to be executed after some other revisions, due to
the relation that some tables have with the primary key of the Role table declared in this module. This "hack" is
to avoid direct dependency of this module with other ones that are not really required.
"""
dependant_revisions = []
for module_name, revision_id in _dependant_modules_with_revisions:
try:
__import__(module_name)
dependant_revisions.append(revision_id)
except ImportError as err:
pass
return tuple(dependant_revisions)


# revision identifiers, used by Alembic.
revision = "f2522cdd5fcd"
down_revision = "eb9743315a9d"
branch_labels = ()
depends_on = _get_dependant_revisions()


def upgrade():
"""Upgrade database."""
# Drop primary key and all foreign keys
op.execute("ALTER TABLE accounts_role DROP CONSTRAINT pk_accounts_role CASCADE")

op.alter_column(
"accounts_userrole",
"role_id",
existing_type=sa.Integer,
type_=sa.String(80),
postgresql_using="role_id::integer",
)
# Change primary key type
# server_default=None will remove the autoincrement
op.alter_column(
"accounts_role",
"id",
existing_type=sa.Integer,
type_=sa.String(80),
server_default=None,
)
op.create_primary_key("pk_accounts_role", "accounts_role", ["id"])
# Add new column `is_managed`
op.add_column(
"accounts_role",
sa.Column(
"is_managed", sa.Boolean(name="is_managed"), default=True, nullable=True
),
)
op.execute("UPDATE accounts_role SET is_managed = true")
op.alter_column("accounts_role", "is_managed", nullable=False)

# Re-create the foreign key constraint
op.create_foreign_key(
"fk_accounts_userrole_role_id",
"accounts_userrole",
"accounts_role",
["role_id"],
["id"],
)


def column_exists(table_name, column_name):
"""Checks if a column exists in a table."""
bind = op.get_context().bind
insp = inspect(bind)
columns = insp.get_columns(table_name)
return any(c["name"] == column_name for c in columns)


def downgrade():
"""Downgrade database."""
# We check for the existence of the column because if it's not there it means that this downgrade was already
# executed by invenio-access
if column_exists("accounts_role", "is_managed"):
# Drop new column `is_managed`
op.drop_column("accounts_role", "is_managed")
op.execute("ALTER TABLE accounts_role DROP CONSTRAINT pk_accounts_role CASCADE")

op.alter_column(
"accounts_userrole",
"role_id",
existing_type=sa.String(80),
type_=sa.Integer,
postgresql_using="role_id::integer",
)
# Change primary key type
# op.drop_constraint("pk_accounts_role", "accounts_role", type_="primary")
op.alter_column(
"accounts_role",
"id",
existing_type=sa.String(80),
type_=sa.Integer,
postgresql_using="id::integer",
)
op.create_primary_key("pk_accounts_role", "accounts_role", ["id"])
op.alter_column(
"accounts_role",
"id",
existing_type=sa.String(80),
type_=sa.Integer,
autoincrement=True,
existing_autoincrement=True,
nullable=False,
)

# Re-create the foreign key constraint
op.create_foreign_key(
"fk_accounts_userrole_role_id",
"accounts_userrole",
"accounts_role",
["role_id"],
["id"],
)
73 changes: 58 additions & 15 deletions invenio_accounts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,68 @@

"""API objects for Invenio Accounts."""

from collections import defaultdict

class Session:
"""Session object for DB Users change history."""

def __init__(self):
"""Constructor."""
self.updated_users = set()
self.updated_roles = set()
self.deleted_users = set()
self.deleted_roles = set()
self.indexed = False
self.invalidated_cache = False

def index(self):
"""Sets index to True."""
self.indexed = True

def invalidate_cache(self):
"""Sets invalidate_cache to True."""
self.invalidated_cache = True


class DBUsersChangeHistory:
"""DB Users change history storage."""

def __init__(self):
"""constructor."""
# the keys are going to be the sessions, the values are going to be
# the sets of dirty/deleted models
self.updated_users = defaultdict(lambda: list())
self.updated_roles = defaultdict(lambda: list())
self.deleted_users = defaultdict(lambda: list())
self.deleted_roles = defaultdict(lambda: list())

def _clear_dirty_sets(self, session):
"""Clear the dirty sets for the given session."""
"""Constructor."""
self.sessions = {}

def _get_session(self, session_id):
"""Returns or creates a session for a concrete session id."""
return self.sessions.setdefault(session_id, Session())

def add_updated_user(self, session_id, user_id):
"""Adds a user to the updated users list."""
session = self._get_session(session_id)
session.updated_users.add(user_id)

def add_updated_role(self, session_id, role_id):
"""Adds a role to the updated roles list."""
session = self._get_session(session_id)
session.updated_roles.add(role_id)

def add_deleted_user(self, session_id, user_id):
"""Adds a user to the deleted users list."""
session = self._get_session(session_id)
session.deleted_users.add(user_id)

def add_deleted_role(self, session_id, role_id):
"""Adds a role to the deleted roles list."""
session = self._get_session(session_id)
session.deleted_roles.add(role_id)

def clear_session(self, session_id):
"""Removes session object if it was indexed and the cache was invalidated."""
if (
self.sessions[session_id].indexed
and self.sessions[session_id].invalidated_cache
):
self.sessions.pop(session_id, None)

def clear_dirty_sets(self, session):
"""Removes session object."""
sid = id(session)
self.updated_users.pop(sid, None)
self.updated_roles.pop(sid, None)
self.deleted_users.pop(sid, None)
self.deleted_roles.pop(sid, None)
self.sessions.pop(sid, None)
2 changes: 1 addition & 1 deletion invenio_accounts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def users_create(email, password, active, confirm, profile):
@commit
def roles_create(**kwargs):
"""Create a role."""
_datastore.create_role(**kwargs)
_datastore.create_role(id=kwargs["name"], **kwargs)
click.secho('Role "%(name)s" created successfully.' % kwargs, fg="green")


Expand Down
21 changes: 19 additions & 2 deletions invenio_accounts/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from flask_security import SQLAlchemyUserDatastore

from .models import Role
from .proxies import current_db_change_history
from .sessions import delete_user_sessions
from .signals import datastore_post_commit, datastore_pre_commit
Expand Down Expand Up @@ -38,6 +39,22 @@ def commit(self):
def mark_changed(self, sid, uid=None, rid=None):
"""Save a user to the changed history."""
if uid:
current_db_change_history.updated_users[sid].append(uid)
current_db_change_history.add_updated_user(sid, uid)
elif rid:
current_db_change_history.updated_roles[sid].append(uid)
current_db_change_history.add_updated_role(sid, rid)

def update_role(self, role):
"""Merge roles."""
role = self.db.session.merge(role)
self.mark_changed(id(self.db.session), rid=role.id)
return role

def create_role(self, **kwargs):
"""Creates and returns a new role from the given parameters."""
role = super().create_role(**kwargs)
self.mark_changed(id(self.db.session), rid=role.id)
return role

def find_role_by_id(self, role_id):
"""Merge roles."""
return self.role_model.query.filter_by(id=role_id).one_or_none()
8 changes: 6 additions & 2 deletions invenio_accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

"""Database models for accounts."""

import uuid
from datetime import datetime

from flask import current_app, session
Expand Down Expand Up @@ -52,7 +53,7 @@
),
db.Column(
"role_id",
db.Integer(),
db.String(80),
db.ForeignKey("accounts_role.id", name="fk_accounts_userrole_role_id"),
),
)
Expand All @@ -64,14 +65,17 @@ class Role(db.Model, Timestamp, RoleMixin):

__tablename__ = "accounts_role"

id = db.Column(db.Integer(), primary_key=True)
id = db.Column(db.String(80), primary_key=True, default=str(uuid.uuid4()))

name = db.Column(db.String(80), unique=True)
"""Role name."""

description = db.Column(db.String(255))
"""Role description."""

is_managed = db.Column(db.Boolean(name="is_managed"), default=True, nullable=False)
"""True when the role is managed by Invenio, and not externally provided."""

# Enables SQLAlchemy version counter
version_id = db.Column(db.Integer, nullable=False)
"""Used by SQLAlchemy for optimistic concurrency control."""
Expand Down
8 changes: 5 additions & 3 deletions invenio_accounts/profiles/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def validate_visibility(value):

def validate_locale(value):
"""Check if the value is a valid locale."""
locales = current_app.extensions["invenio-i18n"].get_locales()
locales = [locale.language for locale in locales]
locales = current_app.extensions.get("invenio-i18n")
if locales:
locales = locales.get_locales()
locales = [locale.language for locale in locales]

if value not in locales:
if locales is not None and value not in locales:
raise ValidationError(message=str(_("Value must be a valid locale.")))
current_app.config["BABEL_DEFAULT_LOCALE"] = value

Expand Down
26 changes: 23 additions & 3 deletions run-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,34 @@ set -o nounset
function cleanup() {
eval "$(docker-services-cli down --env)"
}
trap cleanup EXIT
# Check for arguments
# Note: "-k" would clash with "pytest"
keep_services=0
pytest_args=()
for arg in $@; do
# from the CLI args, filter out some known values and forward the rest to "pytest"
# note: we don't use "getopts" here b/c of some limitations (e.g. long options),
# which means that we can't combine short options (e.g. "./run-tests -Kk pattern")
case ${arg} in
-K|--keep-services)
keep_services=1
;;
*)
pytest_args+=( ${arg} )
;;
esac
done

if [[ ${keep_services} -eq 0 ]]; then
trap cleanup EXIT
fi

python -m check_manifest
python -m setup extract_messages --output-file /dev/null
python -m sphinx.cmd.build -qnN docs docs/_build/html
eval "$(docker-services-cli up --db ${DB:-postgresql} --cache ${CACHE:-redis} --env)"
python -m pytest
tests_exit_code=$?
# Note: expansion of pytest_args looks like below to not cause an unbound
# variable error when 1) "nounset" and 2) the array is empty.
python -m pytest ${pytest_args[@]+"${pytest_args[@]}"}
python -m sphinx.cmd.build -qnN -b doctest docs docs/_build/doctest
exit "$tests_exit_code"
23 changes: 22 additions & 1 deletion tests/test_invenio_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_datastore_usercreate(app):


def test_datastore_rolecreate(app):
"""Test create user."""
"""Test create role."""
ds = app.extensions["invenio-accounts"].datastore

with app.app_context():
Expand All @@ -136,6 +136,27 @@ def test_datastore_rolecreate(app):
assert 1 == Role.query.filter_by(name="superuser").count()


def test_datastore_update_role(app):
"""Test update role."""
ds = app.extensions["invenio-accounts"].datastore

with app.app_context():
r1 = ds.create_role(id="1", name="superuser", description="1234")
ds.commit()
r2 = ds.find_role("superuser")
assert r1 == r2
assert 1 == Role.query.filter_by(name="superuser").count()

r1 = ds.update_role(
Role(id="1", name="megauser", description="updated description")
)
ds.commit()
r2 = ds.find_role("megauser")
assert r1 == r2
assert r2.description == "updated description"
assert 1 == Role.query.filter_by(name="megauser").count()


def test_datastore_assignrole(app):
"""Create and assign user to role."""
ds = app.extensions["invenio-accounts"].datastore
Expand Down

0 comments on commit f554095

Please sign in to comment.