diff --git a/tests/conftest.py b/tests/conftest.py index 956adec..a4b5de7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from fed_mng.models import ( SLA, + IdentityProvider, Provider, Region, ResourceUsage, @@ -16,6 +17,7 @@ UserGroupManager, ) from tests.item_data import ( + identity_provider_dict, provider_dict, region_dict, request_dict, @@ -140,3 +142,13 @@ def db_sla(db_session: Session, db_negotiation: SLANegotiation) -> SLA: db_session.commit() db_session.refresh(db_sla) return db_sla + + +@pytest.fixture(scope="function") +def db_identity_provider(db_session: Session) -> IdentityProvider: + data = identity_provider_dict() + db_identity_provider = IdentityProvider(**data) + db_session.add(db_identity_provider) + db_session.commit() + db_session.refresh(db_identity_provider) + return db_identity_provider diff --git a/tests/item_data.py b/tests/item_data.py index c1866b3..fbf7633 100644 --- a/tests/item_data.py +++ b/tests/item_data.py @@ -1,6 +1,8 @@ from datetime import datetime from random import randint +from pydantic import AnyHttpUrl + from tests.utils import ( random_email, random_lower_string, @@ -63,3 +65,7 @@ def location_dict() -> dict[str, str]: def sla_dict() -> dict[str, str]: start_date, end_date = random_start_end_dates() return {"start_date": start_date, "end_date": end_date} + + +def identity_provider_dict() -> dict[str, AnyHttpUrl]: + return {"endpoint": random_url()} diff --git a/tests/test_models.py b/tests/test_models.py index d4a7031..5787fa6 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,6 +4,7 @@ import pytest from fed_reg.provider.enum import ProviderStatus +from pydantic import AnyHttpUrl from pytest_cases import case, get_case_tags, parametrize, parametrize_with_cases from sqlalchemy.exc import IntegrityError from sqlmodel import Session @@ -12,6 +13,7 @@ from fed_mng.models import ( SLA, Admin, + IdentityProvider, Location, Provider, Region, @@ -23,6 +25,7 @@ TotBlockStorageQuota, TotComputeQuota, TotNetworkQuota, + Trusts, User, UserBlockStorageQuota, UserComputeQuota, @@ -32,6 +35,7 @@ from tests.item_data import ( block_storage_dict, compute_dict, + identity_provider_dict, location_dict, network_dict, provider_dict, @@ -66,6 +70,14 @@ def case_status(self) -> dict[str, Any]: return {**sla_dict(), "status": random_sla_status()} +class CaseIdentityProviderData: + def case_short(self) -> dict[str, AnyHttpUrl]: + return identity_provider_dict() + + def case_desc(self) -> dict[str, Any]: + return {**identity_provider_dict(), "description": random_lower_string()} + + class CaseResourceUsageData: def case_short(self) -> dict[str, datetime]: return request_dict() @@ -231,6 +243,13 @@ def test_resource_usage_request( assert resource_usage_request.moderator_id is None assert resource_usage_request.moderator is None + assert resource_usage_request.tot_block_storage_quota is None + assert resource_usage_request.tot_compute_quota is None + assert resource_usage_request.tot_network_quota is None + assert resource_usage_request.user_block_storage_quota is None + assert resource_usage_request.user_compute_quota is None + assert resource_usage_request.user_network_quota is None + def test_resource_usage_request_without_issuer(db_session: Session) -> None: data = request_dict() @@ -402,7 +421,142 @@ def test_provider_with_site_admins( assert db_provider.id == db_site_admin.providers[0].id -# TODO: Test provider with versions? +# TODO: Test Provider federation requests +# TODO: Test provider with versions and mentioning request + + +@parametrize_with_cases("data", cases=CaseIdentityProviderData) +def test_identity_provider(db_session: Session, data: dict[str, Any]) -> None: + identity_provider = IdentityProvider(**data) + + db_session.add(identity_provider) + db_session.commit() + db_session.refresh(identity_provider) + + assert identity_provider.id is not None + assert identity_provider.endpoint == data.get("endpoint") + assert identity_provider.description == data.get("description") + assert len(identity_provider.authorized_providers) == 0 + + +def test_identity_provider_with_provider( + db_session: Session, db_identity_provider: IdentityProvider, db_provider: Provider +) -> None: + assert len(db_identity_provider.authorized_providers) == 0 + assert len(db_provider.trusted_identity_providers) == 0 + + # Create the trust association object and link it + trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string()) + trust.provider = db_provider + db_identity_provider.authorized_providers.append(trust) + + db_session.add(db_identity_provider) + db_session.commit() + db_session.refresh(db_identity_provider) + + assert len(db_identity_provider.authorized_providers) == 1 + assert db_identity_provider.authorized_providers[0].provider_id == db_provider.id + assert ( + db_identity_provider.authorized_providers[0].identity_provider_id + == db_identity_provider.id + ) + assert len(db_provider.trusted_identity_providers) == 1 + assert db_provider.trusted_identity_providers[0].provider_id == db_provider.id + assert ( + db_provider.trusted_identity_providers[0].identity_provider_id + == db_identity_provider.id + ) + + +def test_identity_provider_with_multi_providers( + db_session: Session, db_identity_provider: IdentityProvider, db_provider: Provider +) -> None: + assert len(db_identity_provider.authorized_providers) == 0 + assert len(db_provider.trusted_identity_providers) == 0 + + # Create the trust association object and link it + trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string()) + trust.provider = db_provider + db_identity_provider.authorized_providers.append(trust) + + db_provider2 = Provider(**provider_dict()) + trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string()) + trust.provider = db_provider2 + db_identity_provider.authorized_providers.append(trust) + + db_session.add(db_identity_provider) + db_session.commit() + db_session.refresh(db_identity_provider) + + assert len(db_identity_provider.authorized_providers) == 2 + assert db_identity_provider.authorized_providers[0].provider_id == db_provider.id + assert db_identity_provider.authorized_providers[1].provider_id == db_provider2.id + assert ( + db_identity_provider.authorized_providers[0].identity_provider_id + == db_identity_provider.id + ) + assert ( + db_identity_provider.authorized_providers[1].identity_provider_id + == db_identity_provider.id + ) + + assert len(db_provider.trusted_identity_providers) == 1 + assert db_provider.trusted_identity_providers[0].provider_id == db_provider.id + assert ( + db_provider.trusted_identity_providers[0].identity_provider_id + == db_identity_provider.id + ) + assert db_provider2.trusted_identity_providers[0].provider_id == db_provider2.id + assert ( + db_provider2.trusted_identity_providers[0].identity_provider_id + == db_identity_provider.id + ) + + +def test_provider_with_multi_identity_providers( + db_session: Session, db_identity_provider: IdentityProvider, db_provider: Provider +) -> None: + assert len(db_identity_provider.authorized_providers) == 0 + assert len(db_provider.trusted_identity_providers) == 0 + + # Create the trust association object and link it + trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string()) + trust.provider = db_provider + db_identity_provider.authorized_providers.append(trust) + + db_identity_provider2 = IdentityProvider(**identity_provider_dict()) + trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string()) + trust.provider = db_provider + db_identity_provider2.authorized_providers.append(trust) + + db_session.add(db_provider) + db_session.commit() + db_session.refresh(db_provider) + + assert len(db_identity_provider.authorized_providers) == 1 + assert db_identity_provider.authorized_providers[0].provider_id == db_provider.id + assert ( + db_identity_provider.authorized_providers[0].identity_provider_id + == db_identity_provider.id + ) + assert len(db_identity_provider2.authorized_providers) == 1 + assert db_identity_provider2.authorized_providers[0].provider_id == db_provider.id + assert ( + db_identity_provider2.authorized_providers[0].identity_provider_id + == db_identity_provider2.id + ) + + assert len(db_provider.trusted_identity_providers) == 2 + assert db_provider.trusted_identity_providers[0].provider_id == db_provider.id + assert db_provider.trusted_identity_providers[1].provider_id == db_provider.id + assert ( + db_provider.trusted_identity_providers[0].identity_provider_id + == db_identity_provider.id + ) + assert ( + db_provider.trusted_identity_providers[1].identity_provider_id + == db_identity_provider2.id + ) @parametrize_with_cases("data", cases=CaseRegionData) @@ -499,10 +653,6 @@ def test_invalid_location(data: dict[str, Any]) -> None: assert location.__getattribute__(k) is None -# TODO: Test idp -# TODO: Test provider with idps - - def test_negotiation( db_session: Session, db_provider: Provider,