diff --git a/invenio_config/env.py b/invenio_config/env.py index 3cf1f9e..c431b94 100644 --- a/invenio_config/env.py +++ b/invenio_config/env.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2018 CERN. +# Copyright (C) 2024 KTH Royal Institute of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -45,3 +46,102 @@ def init_app(self, app): # Set value app.config[varname] = value + + +def _get_env_var(prefix, keys): + """Retrieve environment variables with a given prefix.""" + return {k: os.environ.get(f"{prefix}_{k.upper()}") for k in keys} + + +def build_db_uri(): + """ + Build database URI from environment variables or use default. + + Priority order: + 1. INVENIO_SQLALCHEMY_DATABASE_URI + 2. SQLALCHEMY_DATABASE_URI + 3. INVENIO_DB_* specific environment variables + 4. Default URI + + Note: For option 3, to assert that the INVENIO_DB_* settings take effect, + you need to set SQLALCHEMY_DATABASE_URI="" in your environment. + """ + default_uri = "postgresql+psycopg2://invenio-app-rdm:invenio-app-rdm@localhost/invenio-app-rdm" + + uri = os.environ.get("INVENIO_SQLALCHEMY_DATABASE_URI") or os.environ.get( + "SQLALCHEMY_DATABASE_URI" + ) + if uri: + return uri + + db_params = _get_env_var( + "INVENIO_DB", ["user", "password", "host", "port", "name", "protocol"] + ) + if all(db_params.values()): + uri = f"{db_params['protocol']}://{db_params['user']}:{db_params['password']}@{db_params['host']}:{db_params['port']}/{db_params['name']}" + return uri + + return default_uri + + +def build_broker_url(): + """ + Build broker URL from environment variables or use default. + + Priority order: + 1. INVENIO_BROKER_URL + 2. BROKER_URL + 3. INVENIO_RABBITMQ_* specific environment variables + 4. Default URL + Note: see: https://docs.celeryq.dev/en/stable/userguide/configuration.html#new-lowercase-settings + """ + default_url = "amqp://guest:guest@localhost:5672/" + + borker_url = os.environ.get("INVENIO_BROKER_URL") or os.environ.get("BROKER_URL") + if borker_url: + return borker_url + + broker_params = _get_env_var( + "INVENIO_RABBITMQ", ["user", "password", "host", "port", "protocol"] + ) + if all(broker_params.values()): + vhost = f"{os.environ.get("INVENIO_RABBITMQ_VHOST").lstrip("/")}" + amq_url = f"{broker_params['protocol']}://{broker_params['user']}:{broker_params['password']}@{broker_params['host']}:{broker_params['port']}/{vhost}" + return amq_url + return default_url + + +def build_redis_url(db=None): + """ + Build Redis URL from environment variables or use default. + + Priority order: + 1. INVENIO_CACHE_REDIS_URL + 2. CACHE_REDIS_URL + 3. INVENIO_CACHE_REDIS_* specific environment variables + 4. Default URL + """ + db = db if db is not None else 0 + default_url = f"redis://localhost:6379/{db}" + + redis_url = os.environ.get("INVENIO_CACHE_REDIS_URL") or os.environ.get( + "CACHE_REDIS_URL" + ) + if redis_url and redis_url.startswith(("redis://", "rediss://", "unix://")): + return redis_url + + redis_params = _get_env_var( + "INVENIO_CACHE_REDIS", ["host", "port", "password", "protocol"] + ) + + if redis_params["host"] and redis_params["port"]: + protocol = redis_params.get("protocol", "redis") + password = ( + f":{redis_params['password']}@" if redis_params.get("password") else "" + ) + redis_url = ( + f"{protocol}://{password}{redis_params['host']}:{redis_params['port']}/{db}" + ) + return redis_url + + return default_url diff --git a/tests/test_invenio_config.py b/tests/test_invenio_config.py index 4461648..08c5f00 100644 --- a/tests/test_invenio_config.py +++ b/tests/test_invenio_config.py @@ -2,6 +2,7 @@ # # This file is part of Invenio. # Copyright (C) 2015-2018 CERN. +# Copyright (C) 2024 KTH Royal Institute of Technology. # # Invenio is free software; you can redistribute it and/or modify it # under the terms of the MIT License; see LICENSE file for more details. @@ -16,6 +17,7 @@ import warnings from os.path import join +import pytest from flask import Flask from mock import patch from pkg_resources import EntryPoint @@ -29,6 +31,7 @@ create_config_loader, ) from invenio_config.default import ALLOWED_HTML_ATTRS, ALLOWED_HTML_TAGS +from invenio_config.env import build_broker_url, build_db_uri, build_redis_url class ConfigEP(EntryPoint): @@ -231,3 +234,189 @@ class Config(object): assert app.config["ENV"] == "env" finally: shutil.rmtree(tmppath) + + +def set_env_vars(monkeypatch, env_vars): + """Helper function to set environment variables.""" + for key in env_vars: + monkeypatch.delenv(key, raising=False) + for key, value in env_vars.items(): + monkeypatch.setenv(key, value) + + +@pytest.mark.parametrize( + "env_vars, expected_uri", + [ + ( + { + "INVENIO_DB_USER": "testuser", + "INVENIO_DB_PASSWORD": "testpassword", + "INVENIO_DB_HOST": "testhost", + "INVENIO_DB_PORT": "5432", + "INVENIO_DB_NAME": "testdb", + "INVENIO_DB_PROTOCOL": "postgresql+psycopg2", + }, + "postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb", + ), + ( + { + "INVENIO_SQLALCHEMY_DATABASE_URI": "postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb" + }, + "postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb", + ), + ( + { + "SQLALCHEMY_DATABASE_URI": "postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb" + }, + "postgresql+psycopg2://testuser:testpassword@testhost:5432/testdb", + ), + ( + {}, + "postgresql+psycopg2://invenio-app-rdm:invenio-app-rdm@localhost/invenio-app-rdm", + ), + ], +) +def test_build_db_uri(monkeypatch, env_vars, expected_uri): + """Test building database URI.""" + set_env_vars(monkeypatch, env_vars) + assert build_db_uri() == expected_uri + + +@pytest.mark.parametrize( + "env_vars, expected_url", + [ + ( + { + "INVENIO_RABBITMQ_USER": "testuser", + "INVENIO_RABBITMQ_PASSWORD": "testpassword", + "INVENIO_RABBITMQ_HOST": "testhost", + "INVENIO_RABBITMQ_PORT": "5672", + "INVENIO_RABBITMQ_PROTOCOL": "amqp", + "INVENIO_RABBITMQ_VHOST": "/testvhost", + }, + "amqp://testuser:testpassword@testhost:5672/testvhost", + ), + ( + { + "INVENIO_RABBITMQ_USER": "testuser", + "INVENIO_RABBITMQ_PASSWORD": "testpassword", + "INVENIO_RABBITMQ_HOST": "testhost", + "INVENIO_RABBITMQ_PORT": "5672", + "INVENIO_RABBITMQ_PROTOCOL": "amqp", + "INVENIO_RABBITMQ_VHOST": "testvhost", + }, + "amqp://testuser:testpassword@testhost:5672/testvhost", + ), + ( + { + "INVENIO_RABBITMQ_USER": "testuser", + "INVENIO_RABBITMQ_PASSWORD": "testpassword", + "INVENIO_RABBITMQ_HOST": "testhost", + "INVENIO_RABBITMQ_PORT": "5672", + "INVENIO_RABBITMQ_PROTOCOL": "amqp", + "INVENIO_RABBITMQ_VHOST": "", + }, + "amqp://testuser:testpassword@testhost:5672/", + ), + ( + { + "INVENIO_RABBITMQ_USER": "testuser", + "INVENIO_RABBITMQ_PASSWORD": "testpassword", + "INVENIO_RABBITMQ_HOST": "testhost", + "INVENIO_RABBITMQ_PORT": "5672", + "INVENIO_RABBITMQ_PROTOCOL": "amqp", + }, + "amqp://testuser:testpassword@testhost:5672/", + ), + ], +) +def test_build_broker_url_with_vhost(monkeypatch, env_vars, expected_url): + """Test building broker URL with vhost.""" + set_env_vars(monkeypatch, env_vars) + assert build_broker_url() == expected_url + + +@pytest.mark.parametrize( + "env_vars, expected_url", + [ + ( + { + "INVENIO_RABBITMQ_USER": "testuser", + "INVENIO_RABBITMQ_PASSWORD": "testpassword", + "INVENIO_RABBITMQ_HOST": "testhost", + "INVENIO_RABBITMQ_PORT": "5672", + "INVENIO_RABBITMQ_PROTOCOL": "amqp", + "INVENIO_RABBITMQ_VHOST": "/testvhost", + }, + "amqp://testuser:testpassword@testhost:5672/testvhost", + ), + ( + { + "INVENIO_RABBITMQ_USER": "testuser", + "INVENIO_RABBITMQ_PASSWORD": "testpassword", + "INVENIO_RABBITMQ_HOST": "testhost", + "INVENIO_RABBITMQ_PORT": "5672", + "INVENIO_RABBITMQ_PROTOCOL": "amqp", + "INVENIO_RABBITMQ_VHOST": "testvhost", + }, + "amqp://testuser:testpassword@testhost:5672/testvhost", + ), + ( + {"INVENIO_BROKER_URL": "amqp://guest:guest@localhost:5672/"}, + "amqp://guest:guest@localhost:5672/", + ), + ( + {}, + "amqp://guest:guest@localhost:5672/", + ), + ], +) +def test_build_broker_url_with_vhost(monkeypatch, env_vars, expected_url): + """Test building broker URL with vhost.""" + set_env_vars(monkeypatch, env_vars) + assert build_broker_url() == expected_url + + +@pytest.mark.parametrize( + "env_vars, db, expected_url", + [ + ( + { + "INVENIO_CACHE_REDIS_HOST": "testhost", + "INVENIO_CACHE_REDIS_PORT": "6379", + "INVENIO_CACHE_REDIS_PASSWORD": "testpassword", + "INVENIO_CACHE_REDIS_PROTOCOL": "redis", + }, + 2, + "redis://:testpassword@testhost:6379/2", + ), + ( + { + "INVENIO_CACHE_REDIS_HOST": "testhost", + "INVENIO_CACHE_REDIS_PORT": "6379", + "INVENIO_CACHE_REDIS_PROTOCOL": "redis", + }, + 1, + "redis://testhost:6379/1", + ), + ( + {"BROKER_URL": "redis://localhost:6379/0"}, + None, + "redis://localhost:6379/0", + ), + ( + {"INVENIO_CACHE_REDIS_URL": "redis://localhost:6379/3"}, + 3, + "redis://localhost:6379/3", + ), + ( + {}, + 4, + "redis://localhost:6379/4", + ), + ], +) +def test_build_redis_url(monkeypatch, env_vars, db, expected_url): + """Test building Redis URL.""" + set_env_vars(monkeypatch, env_vars) + assert build_redis_url(db=db) == expected_url