diff --git a/pytest_invenio/fixtures.py b/pytest_invenio/fixtures.py index 47e4603..bcfc8fa 100644 --- a/pytest_invenio/fixtures.py +++ b/pytest_invenio/fixtures.py @@ -15,6 +15,7 @@ import tempfile from datetime import datetime from warnings import warn +from invenio_db.utils import drop_alembic_version_table import importlib_metadata import pkg_resources @@ -875,3 +876,56 @@ def test_with_user(service, myuser): """ return UserFixtureBase + + +@pytest.fixture() +def test_alembic(): + """Test alembic recipes. + + This test is created with the purpose of emulating an instance upgrade. When using alembic to upgrade it should be + done per version, meaning that every time a new version is installed alembic has to run, therefore this test + implements that behaviour by upgrading per revision_id instead of upgrade from scratch to the latest. + """ + def _test_alembic(app, db, module_name, downgrade_target="base"): + """Test alembic recipes for a concrete module.""" + def _sort_revision_ids(scripts_list): + """Sorts the scripts based on the previous and next revisions and returns a list of sorted revision ids.""" + revision_ids = [] + for script in scripts_list: + if script.down_revision not in revision_ids: + revision_ids.insert(0, script.revision) + elif script.down_revision in revision_ids: + down_revision_index = revision_ids.index(script.down_revision) + revision_ids.insert(down_revision_index + 1, script.revision) + elif script.nextrev in revision_ids: + next_revision_index = revision_ids.index(script.nextrev) + revision_ids.insert(next_revision_index, script.revision) + else: + revision_ids.append(script.revision) + return revision_ids + + ext = app.extensions["invenio-db"] + + with app.app_context(): + if db.engine.name == "sqlite": + raise pytest.skip("Upgrades are not supported on SQLite.") + + assert not ext.alembic.compare_metadata() + db.drop_all() + drop_alembic_version_table() + module_scripts = [] + for script in ext.alembic.log(): + if module_name in script.branch_labels: + module_scripts.append(script) + revision_ids = _sort_revision_ids(module_scripts) + for revision_id in revision_ids: + ext.alembic.upgrade(target=revision_id) + + assert not ext.alembic.compare_metadata() + ext.alembic.downgrade(target=downgrade_target) + for revision_id in revision_ids: + ext.alembic.upgrade(target=revision_id) + + assert not ext.alembic.compare_metadata() + + return _test_alembic diff --git a/pytest_invenio/plugin.py b/pytest_invenio/plugin.py index 9f15f3f..724af23 100644 --- a/pytest_invenio/plugin.py +++ b/pytest_invenio/plugin.py @@ -50,6 +50,7 @@ script_info, search, search_clear, + test_alembic, )