From 6ee154b8e99ff1864f859df119a1553683f0c30b Mon Sep 17 00:00:00 2001 From: Maurane GLAUDE Date: Tue, 7 Jan 2025 14:10:14 +0100 Subject: [PATCH] fix(variant_study_service): correct comments, add middleware for current user identity, update execute_or_add_commands function, update tests --- antarest/login/service.py | 4 +- antarest/login/utils.py | 42 ++ antarest/main.py | 5 +- antarest/study/business/utils.py | 19 +- antarest/study/service.py | 12 +- .../variantstudy/variant_study_service.py | 2 +- tests/core/test_tasks.py | 7 + .../test_integration_token_end_to_end.py | 11 +- .../storage/business/test_arealink_manager.py | 381 +++++++++--------- .../test_variant_study_service.py | 15 +- 10 files changed, 279 insertions(+), 219 deletions(-) create mode 100644 antarest/login/utils.py diff --git a/antarest/login/service.py b/antarest/login/service.py index d3103a6dfc..fef02af68f 100644 --- a/antarest/login/service.py +++ b/antarest/login/service.py @@ -406,14 +406,14 @@ def get_bot(self, id: int, params: RequestParameters) -> Bot: def get_bot_info(self, id: int, params: RequestParameters) -> Optional[BotIdentityDTO]: """ - Get user informations + Get user information Permission: SADMIN, GADMIN (own group), USER (own user) Args: id: bot id params: request parameters - Returns: bot informations and roles + Returns: bot information and roles """ bot = self.get_bot(id, params) diff --git a/antarest/login/utils.py b/antarest/login/utils.py new file mode 100644 index 0000000000..13b35b8df7 --- /dev/null +++ b/antarest/login/utils.py @@ -0,0 +1,42 @@ +import typing as t +from contextvars import ContextVar +from optparse import Option +from typing import Optional + +from starlette.requests import Request + +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.responses import Response +from starlette.types import ASGIApp + +from antarest.core.jwt import JWTUser +from antarest.core.serialization import from_json +from antarest.fastapi_jwt_auth import AuthJWT + +_current_user: ContextVar[t.Optional[AuthJWT]] = ContextVar("_current_user", default=None) + +class CurrentUserMiddleware(BaseHTTPMiddleware): + def __init__(self, app: t.Optional[ASGIApp]) -> None: + super().__init__(app) + + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + global _current_user + auth_jwt = AuthJWT(Request(request.scope)) + _current_user.set(auth_jwt) + + response = await call_next(request) + return response + + +def get_current_user() -> Optional[JWTUser]: + auth_jwt = _current_user.get() + if auth_jwt: + json_data = from_json(auth_jwt.get_jwt_subject()) + current_user = JWTUser.model_validate(json_data) + else: + current_user = None + return current_user diff --git a/antarest/main.py b/antarest/main.py index 51440a8560..6c292dbec4 100644 --- a/antarest/main.py +++ b/antarest/main.py @@ -41,6 +41,7 @@ from antarest.core.swagger import customize_openapi from antarest.core.tasks.model import cancel_orphan_tasks from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware +from antarest.login.utils import CurrentUserMiddleware from antarest.core.utils.utils import get_local_path from antarest.core.utils.web import tags_metadata from antarest.fastapi_jwt_auth import AuthJWT @@ -139,7 +140,7 @@ async def set_default_executor(app: FastAPI) -> AsyncGenerator[None, None]: # Database engine = init_db_engine(config_file, config, auto_upgrade_db) application.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=SESSION_ARGS) - # Since Starlette Version 0.24.0, the middlewares are lazily build inside this function + # Since Starlette Version 0.24.0, the middlewares are lazily built inside this function # But we need to instantiate this middleware as it's needed for the study service. # So we manually instantiate it here. DBSessionMiddleware(None, custom_engine=engine, session_args=cast(Dict[str, bool], SESSION_ARGS)) @@ -274,6 +275,8 @@ def handle_all_exception(request: Request, exc: Exception) -> Any: config=RATE_LIMIT_CONFIG, ) + application.add_middleware(CurrentUserMiddleware) + init_admin_user(engine=engine, session_args=SESSION_ARGS, admin_password=config.security.admin_pwd) services = create_services(config, app_ctxt) diff --git a/antarest/study/business/utils.py b/antarest/study/business/utils.py index 76671f2eec..b2d4446c6b 100644 --- a/antarest/study/business/utils.py +++ b/antarest/study/business/utils.py @@ -15,9 +15,10 @@ from antares.study.version import StudyVersion from antarest.core.exceptions import CommandApplicationError -from antarest.core.jwt import DEFAULT_ADMIN_USER +from antarest.core.jwt import JWTUser from antarest.core.requests import RequestParameters from antarest.core.serialization import AntaresBaseModel +from antarest.login.utils import get_current_user from antarest.study.business.all_optional_meta import camel_case_model from antarest.study.model import RawStudy, Study from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy @@ -32,12 +33,16 @@ def execute_or_add_commands( - study: Study, - file_study: FileStudy, - commands: t.Sequence[ICommand], - storage_service: StudyStorageService, - listener: t.Optional[ICommandListener] = None, + study: Study, + file_study: FileStudy, + commands: t.Sequence[ICommand], + storage_service: StudyStorageService, + listener: t.Optional[ICommandListener] = None, + jwt_user: JWTUser = None, ) -> None: + # get current user if not in session, otherwise get session user + current_user = get_current_user() or jwt_user + if isinstance(study, RawStudy): executed_commands: t.MutableSequence[ICommand] = [] for command in commands: @@ -68,7 +73,7 @@ def execute_or_add_commands( storage_service.variant_study_service.append_commands( study.id, transform_command_to_dto(commands, force_aggregate=True), - RequestParameters(user=DEFAULT_ADMIN_USER), + RequestParameters(user=current_user), ) diff --git a/antarest/study/service.py b/antarest/study/service.py index d394056f95..a0dbbb4b62 100644 --- a/antarest/study/service.py +++ b/antarest/study/service.py @@ -229,11 +229,13 @@ def __init__( repository: StudyMetadataRepository, storage_service: StudyStorageService, event_bus: IEventBus, + jwt_user: JWTUser, ): self._study_id = _study_id self.repository = repository self.storage_service = storage_service self.event_bus = event_bus + self.jwt_user = jwt_user def _generate_timeseries(self, notifier: ITaskNotifier) -> None: """Run the task (lock the database).""" @@ -245,7 +247,14 @@ def _generate_timeseries(self, notifier: ITaskNotifier) -> None: command = GenerateThermalClusterTimeSeries( command_context=command_context, study_version=file_study.config.version ) - execute_or_add_commands(study, file_study, [command], self.storage_service, listener) + execute_or_add_commands( + study, + file_study, + [command], + self.storage_service, + listener, + self.jwt_user + ) if isinstance(study, VariantStudy): # In this case we only added the command to the list. @@ -2581,6 +2590,7 @@ def generate_timeseries(self, study: Study, params: RequestParameters) -> str: repository=self.repository, storage_service=self.storage_service, event_bus=self.event_bus, + jwt_user=params.user ) return self.task_service.add_task( diff --git a/antarest/study/storage/variantstudy/variant_study_service.py b/antarest/study/storage/variantstudy/variant_study_service.py index 82fed1d98e..683ba1a5db 100644 --- a/antarest/study/storage/variantstudy/variant_study_service.py +++ b/antarest/study/storage/variantstudy/variant_study_service.py @@ -141,7 +141,7 @@ def get_command(self, study_id: str, command_id: str, params: RequestParameters) def get_commands(self, study_id: str, params: RequestParameters) -> t.List[CommandDTOAPI]: """ - Get command lists + Get commands list Args: study_id: study id params: request parameters diff --git a/tests/core/test_tasks.py b/tests/core/test_tasks.py index 5a8b490047..7f728b5736 100644 --- a/tests/core/test_tasks.py +++ b/tests/core/test_tasks.py @@ -476,6 +476,12 @@ def test_ts_generation_task( raw_study_path = tmp_path / "study" regular_user = User(id=99, name="regular") + jwt_user = Mock( + spec=JWTUser, + id=regular_user.id, + type="user", + impersonator=regular_user.id, + ) db.session.add(regular_user) db.session.commit() @@ -563,6 +569,7 @@ def test_ts_generation_task( repository=study_service.repository, storage_service=study_service.storage_service, event_bus=study_service.event_bus, + jwt_user=jwt_user ) task_id = study_service.task_service.add_task( diff --git a/tests/integration/test_integration_token_end_to_end.py b/tests/integration/test_integration_token_end_to_end.py index b6184622b1..e57b0d4039 100644 --- a/tests/integration/test_integration_token_end_to_end.py +++ b/tests/integration/test_integration_token_end_to_end.py @@ -10,7 +10,6 @@ # # This file is part of the Antares project. -import datetime import io import typing as t from unittest.mock import ANY @@ -178,15 +177,7 @@ def test_nominal_case_of_an_api_user(client: TestClient, admin_access_token: str commands_res = client.get(f"/v1/studies/{variant_id}/commands", headers=bot_headers) for command in commands_res.json(): - # FIXME: Some commands, such as those that modify study configurations, are run by admin user - # Thus the `user_name` for such type of command will be the admin's name - # Here we detect those commands by their `action` and their `target` values - if command["action"] == "update_playlist" or ( - command["action"] == "update_config" and "settings/generaldata" in command["args"]["target"] - ): - assert command["user_name"] == "admin" - else: - assert command["user_name"] == "admin_bot" + assert command["user_name"] == "admin_bot" assert command["updated_at"] # generate variant before running a simulation diff --git a/tests/storage/business/test_arealink_manager.py b/tests/storage/business/test_arealink_manager.py index ce372dccba..1b12c9a5e3 100644 --- a/tests/storage/business/test_arealink_manager.py +++ b/tests/storage/business/test_arealink_manager.py @@ -13,7 +13,7 @@ import json import uuid from pathlib import Path -from unittest.mock import Mock +from unittest.mock import Mock, patch from zipfile import ZipFile import pytest @@ -122,205 +122,206 @@ def test_area_crud(empty_study: FileStudy, matrix_service: SimpleMatrixService): ) assert len(empty_study.config.areas.keys()) == 0 - area_manager.create_area(study, AreaCreationDTO(name="test", type=AreaType.AREA)) - assert len(empty_study.config.areas.keys()) == 1 - assert json.loads((empty_study.config.study_path / "patch.json").read_text())["areas"]["test"]["country"] is None - - area_manager.update_area_ui(study, "test", UpdateAreaUi(x=100, y=200, color_rgb=(255, 0, 100))) - assert empty_study.tree.get(["input", "areas", "test", "ui", "ui"]) == { - "x": 100, - "y": 200, - "color_r": 255, - "color_g": 0, - "color_b": 100, - "layers": 0, - } - - area_manager.create_area(study, AreaCreationDTO(name="test2", type=AreaType.AREA)) - - link_manager.create_link( - study, - LinkDTO( - area1="test", - area2="test2", - ), - ) - assert empty_study.config.areas["test"].links.get("test2") is not None + with patch("antarest.study.business.utils.get_current_user", return_value=DEFAULT_ADMIN_USER): + area_manager.create_area(study, AreaCreationDTO(name="test", type=AreaType.AREA)) + assert len(empty_study.config.areas.keys()) == 1 + assert json.loads((empty_study.config.study_path / "patch.json").read_text())["areas"]["test"]["country"] is None + + area_manager.update_area_ui(study, "test", UpdateAreaUi(x=100, y=200, color_rgb=(255, 0, 100))) + assert empty_study.tree.get(["input", "areas", "test", "ui", "ui"]) == { + "x": 100, + "y": 200, + "color_r": 255, + "color_g": 0, + "color_b": 100, + "layers": 0, + } - link_manager.delete_link(study, "test", "test2") - assert empty_study.config.areas["test"].links.get("test2") is None - area_manager.delete_area(study, "test") - area_manager.delete_area(study, "test2") - assert len(empty_study.config.areas.keys()) == 0 + area_manager.create_area(study, AreaCreationDTO(name="test2", type=AreaType.AREA)) - # Check `AreaManager` behaviour with a variant study - variant_id = str(uuid.uuid4()) - # noinspection PyArgumentList - study = VariantStudy( - id=variant_id, - path=str(empty_study.config.study_path), - additional_data=StudyAdditionalData(), - version="820", - ) - variant_study_service.get_raw.return_value = empty_study - area_manager.create_area( - study, - AreaCreationDTO(name="test", type=AreaType.AREA, metadata=PatchArea(country="FR")), - ) - variant_study_service.append_commands.assert_called_with( - variant_id, - [CommandDTO(action=CommandName.CREATE_AREA.value, args={"area_name": "test"}, study_version=study_version)], - RequestParameters(DEFAULT_ADMIN_USER), - ) - assert (empty_study.config.study_path / "patch.json").exists() - assert json.loads((empty_study.config.study_path / "patch.json").read_text())["areas"]["test"]["country"] == "FR" - - area_manager.update_area_ui(study, "test", UpdateAreaUi(x=100, y=200, color_rgb=(255, 0, 100))) - variant_study_service.append_commands.assert_called_with( - variant_id, - [ - CommandDTO( - id=None, - action=CommandName.UPDATE_CONFIG.value, - args=[ - { - "target": "input/areas/test/ui/ui/x", - "data": 100, - }, - { - "target": "input/areas/test/ui/ui/y", - "data": 200, - }, - { - "target": "input/areas/test/ui/ui/color_r", - "data": 255, - }, - { - "target": "input/areas/test/ui/ui/color_g", - "data": 0, - }, - { - "target": "input/areas/test/ui/ui/color_b", - "data": 100, - }, - { - "target": "input/areas/test/ui/layerX/0", - "data": 100, - }, - { - "target": "input/areas/test/ui/layerY/0", - "data": 200, - }, - { - "target": "input/areas/test/ui/layerColor/0", - "data": "255,0,100", - }, - ], - study_version=study_version, + link_manager.create_link( + study, + LinkDTO( + area1="test", + area2="test2", ), - ], - RequestParameters(DEFAULT_ADMIN_USER), - ) - - area_manager.create_area(study, AreaCreationDTO(name="test2", type=AreaType.AREA)) - link_manager.create_link( - study, - LinkDTO( - area1="test", - area2="test2", - ), - ) - variant_study_service.append_commands.assert_called_with( - variant_id, - [ - CommandDTO( - action=CommandName.CREATE_LINK.value, - args={ - "area1": "test", - "area2": "test2", - "parameters": { + ) + assert empty_study.config.areas["test"].links.get("test2") is not None + + link_manager.delete_link(study, "test", "test2") + assert empty_study.config.areas["test"].links.get("test2") is None + area_manager.delete_area(study, "test") + area_manager.delete_area(study, "test2") + assert len(empty_study.config.areas.keys()) == 0 + + # Check `AreaManager` behaviour with a variant study + variant_id = str(uuid.uuid4()) + # noinspection PyArgumentList + study = VariantStudy( + id=variant_id, + path=str(empty_study.config.study_path), + additional_data=StudyAdditionalData(), + version="820", + ) + variant_study_service.get_raw.return_value = empty_study + area_manager.create_area( + study, + AreaCreationDTO(name="test", type=AreaType.AREA, metadata=PatchArea(country="FR")), + ) + variant_study_service.append_commands.assert_called_with( + variant_id, + [CommandDTO(action=CommandName.CREATE_AREA.value, args={"area_name": "test"}, study_version=study_version)], + RequestParameters(DEFAULT_ADMIN_USER), + ) + assert (empty_study.config.study_path / "patch.json").exists() + assert json.loads((empty_study.config.study_path / "patch.json").read_text())["areas"]["test"]["country"] == "FR" + + area_manager.update_area_ui(study, "test", UpdateAreaUi(x=100, y=200, color_rgb=(255, 0, 100))) + variant_study_service.append_commands.assert_called_with( + variant_id, + [ + CommandDTO( + id=None, + action=CommandName.UPDATE_CONFIG.value, + args=[ + { + "target": "input/areas/test/ui/ui/x", + "data": 100, + }, + { + "target": "input/areas/test/ui/ui/y", + "data": 200, + }, + { + "target": "input/areas/test/ui/ui/color_r", + "data": 255, + }, + { + "target": "input/areas/test/ui/ui/color_g", + "data": 0, + }, + { + "target": "input/areas/test/ui/ui/color_b", + "data": 100, + }, + { + "target": "input/areas/test/ui/layerX/0", + "data": 100, + }, + { + "target": "input/areas/test/ui/layerY/0", + "data": 200, + }, + { + "target": "input/areas/test/ui/layerColor/0", + "data": "255,0,100", + }, + ], + study_version=study_version, + ), + ], + RequestParameters(DEFAULT_ADMIN_USER), + ) + + area_manager.create_area(study, AreaCreationDTO(name="test2", type=AreaType.AREA)) + link_manager.create_link( + study, + LinkDTO( + area1="test", + area2="test2", + ), + ) + variant_study_service.append_commands.assert_called_with( + variant_id, + [ + CommandDTO( + action=CommandName.CREATE_LINK.value, + args={ "area1": "test", "area2": "test2", - "hurdles_cost": False, - "loop_flow": False, - "use_phase_shifter": False, - "transmission_capacities": TransmissionCapacity.ENABLED, - "asset_type": AssetType.AC, - "display_comments": True, - "comments": "", - "colorr": 112, - "colorg": 112, - "colorb": 112, - "link_width": 1.0, - "link_style": LinkStyle.PLAIN, - "filter_synthesis": "hourly, daily, weekly, monthly, annual", - "filter_year_by_year": "hourly, daily, weekly, monthly, annual", + "parameters": { + "area1": "test", + "area2": "test2", + "hurdles_cost": False, + "loop_flow": False, + "use_phase_shifter": False, + "transmission_capacities": TransmissionCapacity.ENABLED, + "asset_type": AssetType.AC, + "display_comments": True, + "comments": "", + "colorr": 112, + "colorg": 112, + "colorb": 112, + "link_width": 1.0, + "link_style": LinkStyle.PLAIN, + "filter_synthesis": "hourly, daily, weekly, monthly, annual", + "filter_year_by_year": "hourly, daily, weekly, monthly, annual", + }, }, - }, - study_version=study_version, + study_version=study_version, + ), + ], + RequestParameters(DEFAULT_ADMIN_USER), + ) + + study.version = 810 + link_manager.create_link( + study, + LinkDTO( + area1="test", + area2="test2", ), - ], - RequestParameters(DEFAULT_ADMIN_USER), - ) - - study.version = 810 - link_manager.create_link( - study, - LinkDTO( - area1="test", - area2="test2", - ), - ) - variant_study_service.append_commands.assert_called_with( - variant_id, - [ - CommandDTO( - action=CommandName.CREATE_LINK.value, - args={ - "area1": "test", - "area2": "test2", - "parameters": { + ) + variant_study_service.append_commands.assert_called_with( + variant_id, + [ + CommandDTO( + action=CommandName.CREATE_LINK.value, + args={ "area1": "test", "area2": "test2", - "hurdles_cost": False, - "loop_flow": False, - "use_phase_shifter": False, - "transmission_capacities": TransmissionCapacity.ENABLED, - "asset_type": AssetType.AC, - "display_comments": True, - "comments": "", - "colorr": 112, - "colorg": 112, - "colorb": 112, - "link_width": 1.0, - "link_style": LinkStyle.PLAIN, + "parameters": { + "area1": "test", + "area2": "test2", + "hurdles_cost": False, + "loop_flow": False, + "use_phase_shifter": False, + "transmission_capacities": TransmissionCapacity.ENABLED, + "asset_type": AssetType.AC, + "display_comments": True, + "comments": "", + "colorr": 112, + "colorg": 112, + "colorb": 112, + "link_width": 1.0, + "link_style": LinkStyle.PLAIN, + }, }, - }, - study_version=study_version, - ), - ], - RequestParameters(DEFAULT_ADMIN_USER), - ) - link_manager.delete_link(study, "test", "test2") - variant_study_service.append_commands.assert_called_with( - variant_id, - [ - CommandDTO( - action=CommandName.REMOVE_LINK.value, - args={"area1": "test", "area2": "test2"}, - study_version=study_version, - ), - ], - RequestParameters(DEFAULT_ADMIN_USER), - ) - area_manager.delete_area(study, "test2") - variant_study_service.append_commands.assert_called_with( - variant_id, - [ - CommandDTO(action=CommandName.REMOVE_AREA.value, args={"id": "test2"}, study_version=study_version), - ], - RequestParameters(DEFAULT_ADMIN_USER), - ) + study_version=study_version, + ), + ], + RequestParameters(DEFAULT_ADMIN_USER), + ) + link_manager.delete_link(study, "test", "test2") + variant_study_service.append_commands.assert_called_with( + variant_id, + [ + CommandDTO( + action=CommandName.REMOVE_LINK.value, + args={"area1": "test", "area2": "test2"}, + study_version=study_version, + ), + ], + RequestParameters(DEFAULT_ADMIN_USER), + ) + area_manager.delete_area(study, "test2") + variant_study_service.append_commands.assert_called_with( + variant_id, + [ + CommandDTO(action=CommandName.REMOVE_AREA.value, args={"id": "test2"}, study_version=study_version), + ], + RequestParameters(DEFAULT_ADMIN_USER), + ) def test_get_all_area(): diff --git a/tests/study/storage/variantstudy/test_variant_study_service.py b/tests/study/storage/variantstudy/test_variant_study_service.py index d5c03f5224..9edf6241b1 100644 --- a/tests/study/storage/variantstudy/test_variant_study_service.py +++ b/tests/study/storage/variantstudy/test_variant_study_service.py @@ -14,7 +14,7 @@ import re import typing from pathlib import Path -from unittest.mock import Mock +from unittest.mock import Mock, patch import numpy as np import pytest @@ -212,12 +212,13 @@ def test_generate_task( study_version=study_version, ) - execute_or_add_commands( - variant_study, - file_study, - commands=[create_area_fr, create_st_storage], - storage_service=study_storage_service, - ) + with patch("antarest.study.business.utils.get_current_user", return_value=DEFAULT_ADMIN_USER): + execute_or_add_commands( + variant_study, + file_study, + commands=[create_area_fr, create_st_storage], + storage_service=study_storage_service, + ) ## Run the "generate" task actual_uui = variant_study_service.generate_task(