diff --git a/src/api/apps.py b/src/api/apps.py index 2e0c8cf..6c81518 100644 --- a/src/api/apps.py +++ b/src/api/apps.py @@ -16,9 +16,9 @@ class AppsController: service: MockAppService = Depends(deps.get_app_service) @router.get("/") - def get_all(self) -> List[AppRegistration]: - return self.service.get_all() + async def get_all(self) -> List[AppRegistration]: + return await self.service.get_all() @router.get("/{app_id}") - def get_by(self, app_id: str) -> AppRegistration: - return self.service.get_by(app_id) + async def get_by(self, app_id: str) -> AppRegistration: + return await self.service.get_by(app_id) diff --git a/src/api/deps.py b/src/api/deps.py index a092e57..7a1542d 100644 --- a/src/api/deps.py +++ b/src/api/deps.py @@ -1,5 +1,5 @@ from azure.identity import EnvironmentCredential -from msgraph.core import GraphClient +from msgraph import GraphServiceClient from core.config import settings from services.app_service import AppService @@ -11,6 +11,6 @@ def get_app_service() -> AppService: return AzureAppService(get_client()) if settings.azure_enabled else MockAppService() -def get_client() -> GraphClient: +def get_client() -> GraphServiceClient: creds = EnvironmentCredential() - return GraphClient(credential=creds) + return GraphServiceClient(credentials=creds) diff --git a/src/requirements.txt b/src/requirements.txt index 93eec66..8debac4 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -1,10 +1,10 @@ -azure-identity~=1.9.0 +azure-identity~=1.12.0 python-dateutil~=2.8.2 fastapi~=0.75.2 fastapi-health~=0.4.0 fastapi-utils~=0.2.1 json-logging~=1.3.0 -msgraph-core~=0.2.2 +msgraph-sdk~=1.4.0 pydantic~=1.10.0 starlette-exporter~=0.12.0 uvicorn~=0.17.1 diff --git a/src/services/app_service.py b/src/services/app_service.py index 1346b92..c024e53 100644 --- a/src/services/app_service.py +++ b/src/services/app_service.py @@ -6,11 +6,11 @@ class AppService: @abstractmethod - def get_all(self) -> List[AppRegistration]: + async def get_all(self) -> List[AppRegistration]: pass @abstractmethod - def get_by(self, app_id: str) -> AppRegistration: + async def get_by(self, app_id: str) -> AppRegistration: pass diff --git a/src/services/azure_app_service.py b/src/services/azure_app_service.py index 037356f..3409388 100644 --- a/src/services/azure_app_service.py +++ b/src/services/azure_app_service.py @@ -1,9 +1,10 @@ from datetime import datetime -from typing import List, Dict, Optional +from typing import List, Optional -import dateutil.parser -from fastapi import HTTPException -from msgraph.core import GraphClient +from msgraph import GraphServiceClient +from msgraph.generated.models.application import Application +from msgraph.generated.models.key_credential import KeyCredential +from msgraph.generated.models.password_credential import PasswordCredential from prometheus_client import Gauge from models import AppRegistration @@ -13,7 +14,7 @@ APP_EXPIRY = Gauge( "azure_app_earliest_expiry", "Returns earliest credential expiry in unix time (seconds)", - ["app_id","app_name"] + ["app_id", "app_name"] ) APP_CREDS_EXPIRY = Gauge( @@ -22,39 +23,43 @@ ["app_id", "app_name", "credential_name"] ) + class AzureAppService(AppService): - def __init__(self, client: GraphClient): + def __init__(self, client: GraphServiceClient): self.client = client - def get_all(self) -> List[AppRegistration]: - result = self.client.get("/applications") - if not result.ok: - raise HTTPException(status_code=result.status_code) - value = result.json()['value'] - apps = [AzureAppService._map_app(a) for a in value] + async def get_all(self) -> List[AppRegistration]: + result = await self.client.applications.get() + apps = [] + while result is not None: + apps += [AzureAppService._map_app(a) for a in result.value] + if result.odata_next_link is None: + break + result = await self.client.applications.with_url(result.odata_next_link).get() + self.observe(apps) return apps - def get_by(self, app_id: str) -> AppRegistration: - result = self.client.get(f"/applications?$filter=appId eq '{app_id}'") - if not result.ok: - raise HTTPException(status_code=result.status_code) - return AzureAppService._map_app(result.json()['value'][0]) + async def get_by(self, app_id: str) -> AppRegistration: + result = await self.client.applications.by_application_id(app_id).get() + if result is not None: + return AzureAppService._map_app(result) + else: + raise "Application with app id %s not found." % app_id @staticmethod - def _map_app(dct: Dict) -> AppRegistration: - app_id = dct['appId'] - name = dct['displayName'] - creds = [AzureAppService._map_cred(c) for c in dct['passwordCredentials']+dct['keyCredentials']] + def _map_app(app: Application) -> AppRegistration: + app_id = app.app_id + name = app.display_name + creds = [AzureAppService._map_cred(c) for c in app.password_credentials + app.key_credentials] return AppRegistration(id=app_id, name=name, credentials=creds) @staticmethod - def _map_cred(dct: Dict) -> Credential: - # https://stackoverflow.com/a/71778150/2592915 + def _map_cred(cred: KeyCredential | PasswordCredential) -> Credential: return Credential( - name=dct['displayName'], - created=dateutil.parser.isoparse(dct['startDateTime']), - expires=dateutil.parser.isoparse(dct['endDateTime']) + name=cred.display_name, + created=cred.start_date_time, + expires=cred.end_date_time ) @staticmethod @@ -65,4 +70,4 @@ def observe(apps: List[AppRegistration]): if expiry: APP_EXPIRY.labels(app_id=app.id, app_name=app.name).set(int(expiry.timestamp())) for cred in app.credentials: - APP_CREDS_EXPIRY.labels(app_id=app.id, app_name=app.name, credential_name=cred.name).set(int(cred.expires.timestamp())) + APP_CREDS_EXPIRY.labels(app_id=app.id, app_name=app.name, credential_name=cred.name).set(int(cred.expires.timestamp())) diff --git a/src/services/mock_app_service.py b/src/services/mock_app_service.py index 5689f43..49daa9c 100644 --- a/src/services/mock_app_service.py +++ b/src/services/mock_app_service.py @@ -14,8 +14,8 @@ def __init__(self, apps: Dict[str, AppRegistration] = None): else: self.apps[MockAppService.SOME_APP.id] = MockAppService.SOME_APP - def get_all(self) -> List[AppRegistration]: + async def get_all(self) -> List[AppRegistration]: return list(self.apps.values()) - def get_by(self, app_id: str) -> AppRegistration: + async def get_by(self, app_id: str) -> AppRegistration: return self.apps[app_id]