diff --git a/pantos/servicenode/database/__init__.py b/pantos/servicenode/database/__init__.py index 88bbc5a..f98af58 100644 --- a/pantos/servicenode/database/__init__.py +++ b/pantos/servicenode/database/__init__.py @@ -105,23 +105,25 @@ def initialize_package(is_flask_app: bool = False) -> None: global _session_maker _session_maker = sqlalchemy.orm.sessionmaker(bind=_sql_engine) # Initialize the tables - with _session_maker.begin() as session: - assert isinstance(session, Session) # type hint - # Blockchain table - statement = sqlalchemy.select(sqlalchemy.func.max(Blockchain_.id)) - max_blockchain_id = session.execute(statement).scalar_one_or_none() - for blockchain in sorted(Blockchain): - if (max_blockchain_id is None - or max_blockchain_id < blockchain.value): - session.add( - Blockchain_(id=blockchain.value, name=blockchain.name)) - # Transfer status table - statement = sqlalchemy.select(sqlalchemy.func.max(TransferStatus_.id)) - max_transfer_status_id = session.execute( - statement).scalar_one_or_none() - for transfer_status in sorted(TransferStatus): - if (max_transfer_status_id is None - or max_transfer_status_id < transfer_status.value): - session.add( - TransferStatus_(id=transfer_status.value, - name=transfer_status.name)) + if is_flask_app: + with _session_maker.begin() as session: + assert isinstance(session, Session) + # Blockchain table + statement = sqlalchemy.select(sqlalchemy.func.max(Blockchain_.id)) + max_blockchain_id = session.execute(statement).scalar_one_or_none() + for blockchain in sorted(Blockchain): + if (max_blockchain_id is None + or max_blockchain_id < blockchain.value): + session.add( + Blockchain_(id=blockchain.value, name=blockchain.name)) + # Transfer status table + statement = sqlalchemy.select( + sqlalchemy.func.max(TransferStatus_.id)) + max_transfer_status_id = session.execute( + statement).scalar_one_or_none() + for transfer_status in sorted(TransferStatus): + if (max_transfer_status_id is None + or max_transfer_status_id < transfer_status.value): + session.add( + TransferStatus_(id=transfer_status.value, + name=transfer_status.name)) diff --git a/tests/database/test_initialize_package.py b/tests/database/test_initialize_package.py index b18f307..cfb06a8 100644 --- a/tests/database/test_initialize_package.py +++ b/tests/database/test_initialize_package.py @@ -15,11 +15,12 @@ TransferStatus as TransferStatus_ +@pytest.mark.parametrize('is_flask_app', [True, False]) @unittest.mock.patch('pantos.servicenode.database.Blockchain', Blockchain) @unittest.mock.patch('pantos.servicenode.database.config') @unittest.mock.patch('pantos.servicenode.database.sqlalchemy.create_engine') def test_initialize_package_blockchain_correct(mocked_create_engine, - mocked_config, + mocked_config, is_flask_app, embedded_db_engine, db_clean_session): mocked_create_engine.return_value = embedded_db_engine @@ -36,22 +37,26 @@ def test_initialize_package_blockchain_correct(mocked_create_engine, mocked_config.__getitem__.side_effect = mocked_config_dict.__getitem__ blockchains = [blockchain for blockchain in sorted(Blockchain)] - initialize_package() + initialize_package(is_flask_app) blockchains_in_db = db_clean_session.execute( sqlalchemy.select(Blockchain_)).fetchall() - assert len(blockchains_in_db) == len(blockchains) - for (blockchain_in_db, blockchain) in zip(blockchains_in_db, blockchains): - assert blockchain.value == blockchain_in_db[0].id - assert blockchain.name == blockchain_in_db[0].name + assert len(blockchains_in_db) == (len(blockchains) if is_flask_app else 0) + if is_flask_app: + for (blockchain_in_db, blockchain) in zip(blockchains_in_db, + blockchains): + assert blockchain.value == blockchain_in_db[0].id + assert blockchain.name == blockchain_in_db[0].name +@pytest.mark.parametrize('is_flask_app', [True, False]) @unittest.mock.patch('pantos.servicenode.database.TransferStatus', TransferStatus) @unittest.mock.patch('pantos.servicenode.database.config') @unittest.mock.patch('pantos.servicenode.database.sqlalchemy.create_engine') def test_initialize_package_transfer_status_correct(mocked_create_engine, mocked_config, + is_flask_app, embedded_db_engine, db_clean_session): mocked_create_engine.return_value = embedded_db_engine @@ -70,35 +75,18 @@ def test_initialize_package_transfer_status_correct(mocked_create_engine, transfer_status for transfer_status in sorted(TransferStatus) ] - initialize_package() + initialize_package(is_flask_app) transfer_statuses_in_db = db_clean_session.execute( sqlalchemy.select(TransferStatus_)).fetchall() - assert len(transfer_statuses_in_db) == len(transfer_statuses) - for (transfer_status_in_db, - transfer_status) in zip(transfer_statuses_in_db, transfer_statuses): - assert transfer_status.value == transfer_status_in_db[0].id - assert transfer_status.name == transfer_status_in_db[0].name - - -@unittest.mock.patch('pantos.servicenode.database.sqlalchemy.func.max', - side_effect=Exception) -@unittest.mock.patch('pantos.servicenode.database.config') -@unittest.mock.patch('pantos.servicenode.database.sqlalchemy.create_engine') -def test_initialize_package_raises_error(mocked_create_engine, mocked_config, - mock_sorted): - mocked_config_dict = { - 'database': { - 'url': '', - 'echo': '', - 'pool_size': '', - 'max_overflow': '' - } - } - mocked_config.__getitem__.side_effect = mocked_config_dict.__getitem__ - - with pytest.raises(Exception): - initialize_package() + assert len(transfer_statuses_in_db) == (len(transfer_statuses) + if is_flask_app else 0) + if is_flask_app: + for (transfer_status_in_db, + transfer_status) in zip(transfer_statuses_in_db, + transfer_statuses): + assert transfer_status.value == transfer_status_in_db[0].id + assert transfer_status.name == transfer_status_in_db[0].name @unittest.mock.patch('pantos.servicenode.database._session_maker', 'session')