Skip to content

Commit

Permalink
fix(variant_study_service): correct comments, add middleware for curr…
Browse files Browse the repository at this point in the history
…ent user identity, update execute_or_add_commands function, update tests
  • Loading branch information
maugde committed Jan 7, 2025
1 parent e8e407d commit 6ee154b
Show file tree
Hide file tree
Showing 10 changed files with 279 additions and 219 deletions.
4 changes: 2 additions & 2 deletions antarest/login/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,14 +406,14 @@ def get_bot(self, id: int, params: RequestParameters) -> Bot:

def get_bot_info(self, id: int, params: RequestParameters) -> Optional[BotIdentityDTO]:
"""
Get user informations
Get user information
Permission: SADMIN, GADMIN (own group), USER (own user)
Args:
id: bot id
params: request parameters
Returns: bot informations and roles
Returns: bot information and roles
"""
bot = self.get_bot(id, params)
Expand Down
42 changes: 42 additions & 0 deletions antarest/login/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import typing as t
from contextvars import ContextVar
from optparse import Option
from typing import Optional

from starlette.requests import Request

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response
from starlette.types import ASGIApp

from antarest.core.jwt import JWTUser
from antarest.core.serialization import from_json
from antarest.fastapi_jwt_auth import AuthJWT

_current_user: ContextVar[t.Optional[AuthJWT]] = ContextVar("_current_user", default=None)

class CurrentUserMiddleware(BaseHTTPMiddleware):
def __init__(self, app: t.Optional[ASGIApp]) -> None:
super().__init__(app)

async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
global _current_user
auth_jwt = AuthJWT(Request(request.scope))
_current_user.set(auth_jwt)

response = await call_next(request)
return response


def get_current_user() -> Optional[JWTUser]:
auth_jwt = _current_user.get()
if auth_jwt:
json_data = from_json(auth_jwt.get_jwt_subject())
current_user = JWTUser.model_validate(json_data)
else:
current_user = None
return current_user
5 changes: 4 additions & 1 deletion antarest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from antarest.core.swagger import customize_openapi
from antarest.core.tasks.model import cancel_orphan_tasks
from antarest.core.utils.fastapi_sqlalchemy import DBSessionMiddleware
from antarest.login.utils import CurrentUserMiddleware
from antarest.core.utils.utils import get_local_path
from antarest.core.utils.web import tags_metadata
from antarest.fastapi_jwt_auth import AuthJWT
Expand Down Expand Up @@ -139,7 +140,7 @@ async def set_default_executor(app: FastAPI) -> AsyncGenerator[None, None]:
# Database
engine = init_db_engine(config_file, config, auto_upgrade_db)
application.add_middleware(DBSessionMiddleware, custom_engine=engine, session_args=SESSION_ARGS)
# Since Starlette Version 0.24.0, the middlewares are lazily build inside this function
# Since Starlette Version 0.24.0, the middlewares are lazily built inside this function
# But we need to instantiate this middleware as it's needed for the study service.
# So we manually instantiate it here.
DBSessionMiddleware(None, custom_engine=engine, session_args=cast(Dict[str, bool], SESSION_ARGS))
Expand Down Expand Up @@ -274,6 +275,8 @@ def handle_all_exception(request: Request, exc: Exception) -> Any:
config=RATE_LIMIT_CONFIG,
)

application.add_middleware(CurrentUserMiddleware)

init_admin_user(engine=engine, session_args=SESSION_ARGS, admin_password=config.security.admin_pwd)
services = create_services(config, app_ctxt)

Expand Down
19 changes: 12 additions & 7 deletions antarest/study/business/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from antares.study.version import StudyVersion

from antarest.core.exceptions import CommandApplicationError
from antarest.core.jwt import DEFAULT_ADMIN_USER
from antarest.core.jwt import JWTUser
from antarest.core.requests import RequestParameters
from antarest.core.serialization import AntaresBaseModel
from antarest.login.utils import get_current_user
from antarest.study.business.all_optional_meta import camel_case_model
from antarest.study.model import RawStudy, Study
from antarest.study.storage.rawstudy.model.filesystem.factory import FileStudy
Expand All @@ -32,12 +33,16 @@


def execute_or_add_commands(
study: Study,
file_study: FileStudy,
commands: t.Sequence[ICommand],
storage_service: StudyStorageService,
listener: t.Optional[ICommandListener] = None,
study: Study,
file_study: FileStudy,
commands: t.Sequence[ICommand],
storage_service: StudyStorageService,
listener: t.Optional[ICommandListener] = None,
jwt_user: JWTUser = None,
) -> None:
# get current user if not in session, otherwise get session user
current_user = get_current_user() or jwt_user

if isinstance(study, RawStudy):
executed_commands: t.MutableSequence[ICommand] = []
for command in commands:
Expand Down Expand Up @@ -68,7 +73,7 @@ def execute_or_add_commands(
storage_service.variant_study_service.append_commands(
study.id,
transform_command_to_dto(commands, force_aggregate=True),
RequestParameters(user=DEFAULT_ADMIN_USER),
RequestParameters(user=current_user),
)


Expand Down
12 changes: 11 additions & 1 deletion antarest/study/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,13 @@ def __init__(
repository: StudyMetadataRepository,
storage_service: StudyStorageService,
event_bus: IEventBus,
jwt_user: JWTUser,
):
self._study_id = _study_id
self.repository = repository
self.storage_service = storage_service
self.event_bus = event_bus
self.jwt_user = jwt_user

def _generate_timeseries(self, notifier: ITaskNotifier) -> None:
"""Run the task (lock the database)."""
Expand All @@ -245,7 +247,14 @@ def _generate_timeseries(self, notifier: ITaskNotifier) -> None:
command = GenerateThermalClusterTimeSeries(
command_context=command_context, study_version=file_study.config.version
)
execute_or_add_commands(study, file_study, [command], self.storage_service, listener)
execute_or_add_commands(
study,
file_study,
[command],
self.storage_service,
listener,
self.jwt_user
)

if isinstance(study, VariantStudy):
# In this case we only added the command to the list.
Expand Down Expand Up @@ -2581,6 +2590,7 @@ def generate_timeseries(self, study: Study, params: RequestParameters) -> str:
repository=self.repository,
storage_service=self.storage_service,
event_bus=self.event_bus,
jwt_user=params.user
)

return self.task_service.add_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_command(self, study_id: str, command_id: str, params: RequestParameters)

def get_commands(self, study_id: str, params: RequestParameters) -> t.List[CommandDTOAPI]:
"""
Get command lists
Get commands list
Args:
study_id: study id
params: request parameters
Expand Down
7 changes: 7 additions & 0 deletions tests/core/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,12 @@ def test_ts_generation_task(
raw_study_path = tmp_path / "study"

regular_user = User(id=99, name="regular")
jwt_user = Mock(
spec=JWTUser,
id=regular_user.id,
type="user",
impersonator=regular_user.id,
)
db.session.add(regular_user)
db.session.commit()

Expand Down Expand Up @@ -563,6 +569,7 @@ def test_ts_generation_task(
repository=study_service.repository,
storage_service=study_service.storage_service,
event_bus=study_service.event_bus,
jwt_user=jwt_user
)

task_id = study_service.task_service.add_task(
Expand Down
11 changes: 1 addition & 10 deletions tests/integration/test_integration_token_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#
# This file is part of the Antares project.

import datetime
import io
import typing as t
from unittest.mock import ANY
Expand Down Expand Up @@ -178,15 +177,7 @@ def test_nominal_case_of_an_api_user(client: TestClient, admin_access_token: str
commands_res = client.get(f"/v1/studies/{variant_id}/commands", headers=bot_headers)

for command in commands_res.json():
# FIXME: Some commands, such as those that modify study configurations, are run by admin user
# Thus the `user_name` for such type of command will be the admin's name
# Here we detect those commands by their `action` and their `target` values
if command["action"] == "update_playlist" or (
command["action"] == "update_config" and "settings/generaldata" in command["args"]["target"]
):
assert command["user_name"] == "admin"
else:
assert command["user_name"] == "admin_bot"
assert command["user_name"] == "admin_bot"
assert command["updated_at"]

# generate variant before running a simulation
Expand Down
Loading

0 comments on commit 6ee154b

Please sign in to comment.