Skip to content

Commit

Permalink
Add IDP tests
Browse files Browse the repository at this point in the history
  • Loading branch information
giosava94 committed Feb 15, 2024
1 parent bce937b commit 383ddfc
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 5 deletions.
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from fed_mng.models import (
SLA,
IdentityProvider,
Provider,
Region,
ResourceUsage,
Expand All @@ -16,6 +17,7 @@
UserGroupManager,
)
from tests.item_data import (
identity_provider_dict,
provider_dict,
region_dict,
request_dict,
Expand Down Expand Up @@ -140,3 +142,13 @@ def db_sla(db_session: Session, db_negotiation: SLANegotiation) -> SLA:
db_session.commit()
db_session.refresh(db_sla)
return db_sla


@pytest.fixture(scope="function")
def db_identity_provider(db_session: Session) -> IdentityProvider:
data = identity_provider_dict()
db_identity_provider = IdentityProvider(**data)
db_session.add(db_identity_provider)
db_session.commit()
db_session.refresh(db_identity_provider)
return db_identity_provider
6 changes: 6 additions & 0 deletions tests/item_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import datetime
from random import randint

from pydantic import AnyHttpUrl

from tests.utils import (
random_email,
random_lower_string,
Expand Down Expand Up @@ -63,3 +65,7 @@ def location_dict() -> dict[str, str]:
def sla_dict() -> dict[str, str]:
start_date, end_date = random_start_end_dates()
return {"start_date": start_date, "end_date": end_date}


def identity_provider_dict() -> dict[str, AnyHttpUrl]:
return {"endpoint": random_url()}
160 changes: 155 additions & 5 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
from fed_reg.provider.enum import ProviderStatus
from pydantic import AnyHttpUrl
from pytest_cases import case, get_case_tags, parametrize, parametrize_with_cases
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session
Expand All @@ -12,6 +13,7 @@
from fed_mng.models import (
SLA,
Admin,
IdentityProvider,
Location,
Provider,
Region,
Expand All @@ -23,6 +25,7 @@
TotBlockStorageQuota,
TotComputeQuota,
TotNetworkQuota,
Trusts,
User,
UserBlockStorageQuota,
UserComputeQuota,
Expand All @@ -32,6 +35,7 @@
from tests.item_data import (
block_storage_dict,
compute_dict,
identity_provider_dict,
location_dict,
network_dict,
provider_dict,
Expand Down Expand Up @@ -66,6 +70,14 @@ def case_status(self) -> dict[str, Any]:
return {**sla_dict(), "status": random_sla_status()}


class CaseIdentityProviderData:
def case_short(self) -> dict[str, AnyHttpUrl]:
return identity_provider_dict()

def case_desc(self) -> dict[str, Any]:
return {**identity_provider_dict(), "description": random_lower_string()}


class CaseResourceUsageData:
def case_short(self) -> dict[str, datetime]:
return request_dict()
Expand Down Expand Up @@ -231,6 +243,13 @@ def test_resource_usage_request(
assert resource_usage_request.moderator_id is None
assert resource_usage_request.moderator is None

assert resource_usage_request.tot_block_storage_quota is None
assert resource_usage_request.tot_compute_quota is None
assert resource_usage_request.tot_network_quota is None
assert resource_usage_request.user_block_storage_quota is None
assert resource_usage_request.user_compute_quota is None
assert resource_usage_request.user_network_quota is None


def test_resource_usage_request_without_issuer(db_session: Session) -> None:
data = request_dict()
Expand Down Expand Up @@ -402,7 +421,142 @@ def test_provider_with_site_admins(
assert db_provider.id == db_site_admin.providers[0].id


# TODO: Test provider with versions?
# TODO: Test Provider federation requests
# TODO: Test provider with versions and mentioning request


@parametrize_with_cases("data", cases=CaseIdentityProviderData)
def test_identity_provider(db_session: Session, data: dict[str, Any]) -> None:
identity_provider = IdentityProvider(**data)

db_session.add(identity_provider)
db_session.commit()
db_session.refresh(identity_provider)

assert identity_provider.id is not None
assert identity_provider.endpoint == data.get("endpoint")
assert identity_provider.description == data.get("description")
assert len(identity_provider.authorized_providers) == 0


def test_identity_provider_with_provider(
db_session: Session, db_identity_provider: IdentityProvider, db_provider: Provider
) -> None:
assert len(db_identity_provider.authorized_providers) == 0
assert len(db_provider.trusted_identity_providers) == 0

# Create the trust association object and link it
trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string())
trust.provider = db_provider
db_identity_provider.authorized_providers.append(trust)

db_session.add(db_identity_provider)
db_session.commit()
db_session.refresh(db_identity_provider)

assert len(db_identity_provider.authorized_providers) == 1
assert db_identity_provider.authorized_providers[0].provider_id == db_provider.id
assert (
db_identity_provider.authorized_providers[0].identity_provider_id
== db_identity_provider.id
)
assert len(db_provider.trusted_identity_providers) == 1
assert db_provider.trusted_identity_providers[0].provider_id == db_provider.id
assert (
db_provider.trusted_identity_providers[0].identity_provider_id
== db_identity_provider.id
)


def test_identity_provider_with_multi_providers(
db_session: Session, db_identity_provider: IdentityProvider, db_provider: Provider
) -> None:
assert len(db_identity_provider.authorized_providers) == 0
assert len(db_provider.trusted_identity_providers) == 0

# Create the trust association object and link it
trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string())
trust.provider = db_provider
db_identity_provider.authorized_providers.append(trust)

db_provider2 = Provider(**provider_dict())
trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string())
trust.provider = db_provider2
db_identity_provider.authorized_providers.append(trust)

db_session.add(db_identity_provider)
db_session.commit()
db_session.refresh(db_identity_provider)

assert len(db_identity_provider.authorized_providers) == 2
assert db_identity_provider.authorized_providers[0].provider_id == db_provider.id
assert db_identity_provider.authorized_providers[1].provider_id == db_provider2.id
assert (
db_identity_provider.authorized_providers[0].identity_provider_id
== db_identity_provider.id
)
assert (
db_identity_provider.authorized_providers[1].identity_provider_id
== db_identity_provider.id
)

assert len(db_provider.trusted_identity_providers) == 1
assert db_provider.trusted_identity_providers[0].provider_id == db_provider.id
assert (
db_provider.trusted_identity_providers[0].identity_provider_id
== db_identity_provider.id
)
assert db_provider2.trusted_identity_providers[0].provider_id == db_provider2.id
assert (
db_provider2.trusted_identity_providers[0].identity_provider_id
== db_identity_provider.id
)


def test_provider_with_multi_identity_providers(
db_session: Session, db_identity_provider: IdentityProvider, db_provider: Provider
) -> None:
assert len(db_identity_provider.authorized_providers) == 0
assert len(db_provider.trusted_identity_providers) == 0

# Create the trust association object and link it
trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string())
trust.provider = db_provider
db_identity_provider.authorized_providers.append(trust)

db_identity_provider2 = IdentityProvider(**identity_provider_dict())
trust = Trusts(idp_name=random_lower_string(), protocol=random_lower_string())
trust.provider = db_provider
db_identity_provider2.authorized_providers.append(trust)

db_session.add(db_provider)
db_session.commit()
db_session.refresh(db_provider)

assert len(db_identity_provider.authorized_providers) == 1
assert db_identity_provider.authorized_providers[0].provider_id == db_provider.id
assert (
db_identity_provider.authorized_providers[0].identity_provider_id
== db_identity_provider.id
)
assert len(db_identity_provider2.authorized_providers) == 1
assert db_identity_provider2.authorized_providers[0].provider_id == db_provider.id
assert (
db_identity_provider2.authorized_providers[0].identity_provider_id
== db_identity_provider2.id
)

assert len(db_provider.trusted_identity_providers) == 2
assert db_provider.trusted_identity_providers[0].provider_id == db_provider.id
assert db_provider.trusted_identity_providers[1].provider_id == db_provider.id
assert (
db_provider.trusted_identity_providers[0].identity_provider_id
== db_identity_provider.id
)
assert (
db_provider.trusted_identity_providers[1].identity_provider_id
== db_identity_provider2.id
)


@parametrize_with_cases("data", cases=CaseRegionData)
Expand Down Expand Up @@ -499,10 +653,6 @@ def test_invalid_location(data: dict[str, Any]) -> None:
assert location.__getattribute__(k) is None


# TODO: Test idp
# TODO: Test provider with idps


def test_negotiation(
db_session: Session,
db_provider: Provider,
Expand Down

0 comments on commit 383ddfc

Please sign in to comment.